Browse Source

Use Diesel's MultiConnection Derive

With this PR we remove almost all custom macros used to create the multiple database type code. This is now handled by Diesel itself.

This removes the need for the following functions/macros:
 - `db_object!`
 - `::to_db`
 - `.from_db()`

It is also possible to just use one schema instead of multiple per type.

Also done:
 - Refactored the SQLite backup function
 - Some formatting of queries so every call is on a separate line, this looks a bit better
 - Declare `conn` as mut inside each `db_run!` instead of having to declare it as `mut` in functions or calls
 - Added an `ACTIVE_DB_TYPE` static which holds the currently active database type
 - Removed `diesel_logger` crate and use Diesel's `set_default_instrumentation()`
   If you want debug queries you can now simply change the log level of `vaultwarden::db::query_logger`
 - Use PostgreSQL v17 in the Alpine images to match the Debian Trixie version
 - Optimized the Workflows since `diesel_logger` isn't needed anymore

And on the extra plus-side, this lowers the compile-time and binary size too.

Signed-off-by: BlackDex <black.dex@gmail.com>
pull/6279/head
BlackDex 1 month ago
parent
commit
03aa7e5090
No known key found for this signature in database GPG Key ID: 58C80A2AA6C765E1
  1. 22
      .github/workflows/build.yml
  2. 2
      .github/workflows/hadolint.yml
  3. 8
      .github/workflows/release.yml
  4. 4
      .github/workflows/trivy.yml
  5. 2
      .github/workflows/zizmor.yml
  6. 739
      Cargo.lock
  7. 38
      Cargo.toml
  8. 6
      build.rs
  9. 6
      docker/Dockerfile.alpine
  10. 6
      docker/Dockerfile.j2
  11. 2
      macros/Cargo.toml
  12. 189
      src/api/admin.rs
  13. 256
      src/api/core/accounts.rs
  14. 345
      src/api/core/ciphers.rs
  15. 172
      src/api/core/emergency_access.rs
  16. 41
      src/api/core/events.rs
  17. 35
      src/api/core/folders.rs
  18. 18
      src/api/core/mod.rs
  19. 659
      src/api/core/organizations.rs
  20. 53
      src/api/core/public.rs
  21. 129
      src/api/core/sends.rs
  22. 38
      src/api/core/two_factor/authenticator.rs
  23. 24
      src/api/core/two_factor/duo.rs
  24. 10
      src/api/core/two_factor/duo_oidc.rs
  25. 42
      src/api/core/two_factor/email.rs
  26. 56
      src/api/core/two_factor/mod.rs
  27. 15
      src/api/core/two_factor/protected_actions.rs
  28. 49
      src/api/core/two_factor/webauthn.rs
  29. 18
      src/api/core/two_factor/yubikey.rs
  30. 59
      src/api/identity.rs
  31. 2
      src/api/mod.rs
  32. 20
      src/api/notifications.rs
  33. 26
      src/api/push.rs
  34. 20
      src/auth.rs
  35. 12
      src/config.rs
  36. 379
      src/db/mod.rs
  37. 50
      src/db/models/attachment.rs
  38. 44
      src/db/models/auth_request.rs
  39. 133
      src/db/models/cipher.rs
  40. 128
      src/db/models/collection.rs
  41. 70
      src/db/models/device.rs
  42. 91
      src/db/models/emergency_access.rs
  43. 39
      src/db/models/event.rs
  44. 18
      src/db/models/favorite.rs
  45. 53
      src/db/models/folder.rs
  46. 95
      src/db/models/group.rs
  47. 56
      src/db/models/org_policy.rs
  48. 247
      src/db/models/organization.rs
  49. 56
      src/db/models/send.rs
  50. 16
      src/db/models/sso_nonce.rs
  51. 39
      src/db/models/two_factor.rs
  52. 44
      src/db/models/two_factor_duo_context.rs
  53. 32
      src/db/models/two_factor_incomplete.rs
  54. 89
      src/db/models/user.rs
  55. 57
      src/db/query_logger.rs
  56. 0
      src/db/schema.rs
  57. 395
      src/db/schemas/mysql/schema.rs
  58. 395
      src/db/schemas/sqlite/schema.rs
  59. 50
      src/main.rs
  60. 15
      src/sso.rs

22
.github/workflows/build.yml

@ -69,9 +69,9 @@ jobs:
CHANNEL: ${{ matrix.channel }} CHANNEL: ${{ matrix.channel }}
run: | run: |
if [[ "${CHANNEL}" == 'rust-toolchain' ]]; then if [[ "${CHANNEL}" == 'rust-toolchain' ]]; then
RUST_TOOLCHAIN="$(grep -oP 'channel.*"(\K.*?)(?=")' rust-toolchain.toml)" RUST_TOOLCHAIN="$(grep -m1 -oP 'channel.*"(\K.*?)(?=")' rust-toolchain.toml)"
elif [[ "${CHANNEL}" == 'msrv' ]]; then elif [[ "${CHANNEL}" == 'msrv' ]]; then
RUST_TOOLCHAIN="$(grep -oP 'rust-version.*"(\K.*?)(?=")' Cargo.toml)" RUST_TOOLCHAIN="$(grep -m1 -oP 'rust-version\s.*"(\K.*?)(?=")' Cargo.toml)"
else else
RUST_TOOLCHAIN="${CHANNEL}" RUST_TOOLCHAIN="${CHANNEL}"
fi fi
@ -116,7 +116,7 @@ jobs:
# Enable Rust Caching # Enable Rust Caching
- name: Rust Caching - name: Rust Caching
uses: Swatinem/rust-cache@98c8021b550208e191a6a3145459bfc9fb29c4c0 # v2.8.0 uses: Swatinem/rust-cache@f13886b937689c021905a6b90929199931d60db1 # v2.8.1
with: with:
# Use a custom prefix-key to force a fresh start. This is sometimes needed with bigger changes. # Use a custom prefix-key to force a fresh start. This is sometimes needed with bigger changes.
# Like changing the build host from Ubuntu 20.04 to 22.04 for example. # Like changing the build host from Ubuntu 20.04 to 22.04 for example.
@ -126,18 +126,6 @@ jobs:
# Run cargo tests # Run cargo tests
# First test all features together, afterwards test them separately. # First test all features together, afterwards test them separately.
- name: "test features: sqlite,mysql,postgresql,enable_mimalloc,query_logger"
id: test_sqlite_mysql_postgresql_mimalloc_logger
if: ${{ !cancelled() }}
run: |
cargo test --features sqlite,mysql,postgresql,enable_mimalloc,query_logger
- name: "test features: sqlite,mysql,postgresql,enable_mimalloc"
id: test_sqlite_mysql_postgresql_mimalloc
if: ${{ !cancelled() }}
run: |
cargo test --features sqlite,mysql,postgresql,enable_mimalloc
- name: "test features: sqlite,mysql,postgresql" - name: "test features: sqlite,mysql,postgresql"
id: test_sqlite_mysql_postgresql id: test_sqlite_mysql_postgresql
if: ${{ !cancelled() }} if: ${{ !cancelled() }}
@ -187,8 +175,6 @@ jobs:
- name: "Some checks failed" - name: "Some checks failed"
if: ${{ failure() }} if: ${{ failure() }}
env: env:
TEST_DB_M_L: ${{ steps.test_sqlite_mysql_postgresql_mimalloc_logger.outcome }}
TEST_DB_M: ${{ steps.test_sqlite_mysql_postgresql_mimalloc.outcome }}
TEST_DB: ${{ steps.test_sqlite_mysql_postgresql.outcome }} TEST_DB: ${{ steps.test_sqlite_mysql_postgresql.outcome }}
TEST_SQLITE: ${{ steps.test_sqlite.outcome }} TEST_SQLITE: ${{ steps.test_sqlite.outcome }}
TEST_MYSQL: ${{ steps.test_mysql.outcome }} TEST_MYSQL: ${{ steps.test_mysql.outcome }}
@ -200,8 +186,6 @@ jobs:
echo "" >> "${GITHUB_STEP_SUMMARY}" echo "" >> "${GITHUB_STEP_SUMMARY}"
echo "|Job|Status|" >> "${GITHUB_STEP_SUMMARY}" echo "|Job|Status|" >> "${GITHUB_STEP_SUMMARY}"
echo "|---|------|" >> "${GITHUB_STEP_SUMMARY}" echo "|---|------|" >> "${GITHUB_STEP_SUMMARY}"
echo "|test (sqlite,mysql,postgresql,enable_mimalloc,query_logger)|${TEST_DB_M_L}|" >> "${GITHUB_STEP_SUMMARY}"
echo "|test (sqlite,mysql,postgresql,enable_mimalloc)|${TEST_DB_M}|" >> "${GITHUB_STEP_SUMMARY}"
echo "|test (sqlite,mysql,postgresql)|${TEST_DB}|" >> "${GITHUB_STEP_SUMMARY}" echo "|test (sqlite,mysql,postgresql)|${TEST_DB}|" >> "${GITHUB_STEP_SUMMARY}"
echo "|test (sqlite)|${TEST_SQLITE}|" >> "${GITHUB_STEP_SUMMARY}" echo "|test (sqlite)|${TEST_SQLITE}|" >> "${GITHUB_STEP_SUMMARY}"
echo "|test (mysql)|${TEST_MYSQL}|" >> "${GITHUB_STEP_SUMMARY}" echo "|test (mysql)|${TEST_MYSQL}|" >> "${GITHUB_STEP_SUMMARY}"

2
.github/workflows/hadolint.yml

@ -31,7 +31,7 @@ jobs:
sudo curl -L https://github.com/hadolint/hadolint/releases/download/v${HADOLINT_VERSION}/hadolint-$(uname -s)-$(uname -m) -o /usr/local/bin/hadolint && \ sudo curl -L https://github.com/hadolint/hadolint/releases/download/v${HADOLINT_VERSION}/hadolint-$(uname -s)-$(uname -m) -o /usr/local/bin/hadolint && \
sudo chmod +x /usr/local/bin/hadolint sudo chmod +x /usr/local/bin/hadolint
env: env:
HADOLINT_VERSION: 2.12.0 HADOLINT_VERSION: 2.13.1
# End Download hadolint # End Download hadolint
# Checkout the repo # Checkout the repo
- name: Checkout - name: Checkout

8
.github/workflows/release.yml

@ -204,7 +204,7 @@ jobs:
# Attest container images # Attest container images
- name: Attest - docker.io - ${{ matrix.base_image }} - name: Attest - docker.io - ${{ matrix.base_image }}
if: ${{ env.HAVE_DOCKERHUB_LOGIN == 'true' && steps.bake_vw.outputs.metadata != ''}} if: ${{ env.HAVE_DOCKERHUB_LOGIN == 'true' && steps.bake_vw.outputs.metadata != ''}}
uses: actions/attest-build-provenance@e8998f949152b193b063cb0ec769d69d929409be # v2.4.0 uses: actions/attest-build-provenance@977bb373ede98d70efdf65b84cb5f73e068dcc2a # v3.0.0
with: with:
subject-name: ${{ vars.DOCKERHUB_REPO }} subject-name: ${{ vars.DOCKERHUB_REPO }}
subject-digest: ${{ env.DIGEST_SHA }} subject-digest: ${{ env.DIGEST_SHA }}
@ -212,7 +212,7 @@ jobs:
- name: Attest - ghcr.io - ${{ matrix.base_image }} - name: Attest - ghcr.io - ${{ matrix.base_image }}
if: ${{ env.HAVE_GHCR_LOGIN == 'true' && steps.bake_vw.outputs.metadata != ''}} if: ${{ env.HAVE_GHCR_LOGIN == 'true' && steps.bake_vw.outputs.metadata != ''}}
uses: actions/attest-build-provenance@e8998f949152b193b063cb0ec769d69d929409be # v2.4.0 uses: actions/attest-build-provenance@977bb373ede98d70efdf65b84cb5f73e068dcc2a # v3.0.0
with: with:
subject-name: ${{ vars.GHCR_REPO }} subject-name: ${{ vars.GHCR_REPO }}
subject-digest: ${{ env.DIGEST_SHA }} subject-digest: ${{ env.DIGEST_SHA }}
@ -220,7 +220,7 @@ jobs:
- name: Attest - quay.io - ${{ matrix.base_image }} - name: Attest - quay.io - ${{ matrix.base_image }}
if: ${{ env.HAVE_QUAY_LOGIN == 'true' && steps.bake_vw.outputs.metadata != ''}} if: ${{ env.HAVE_QUAY_LOGIN == 'true' && steps.bake_vw.outputs.metadata != ''}}
uses: actions/attest-build-provenance@e8998f949152b193b063cb0ec769d69d929409be # v2.4.0 uses: actions/attest-build-provenance@977bb373ede98d70efdf65b84cb5f73e068dcc2a # v3.0.0
with: with:
subject-name: ${{ vars.QUAY_REPO }} subject-name: ${{ vars.QUAY_REPO }}
subject-digest: ${{ env.DIGEST_SHA }} subject-digest: ${{ env.DIGEST_SHA }}
@ -299,7 +299,7 @@ jobs:
path: vaultwarden-armv6-${{ matrix.base_image }} path: vaultwarden-armv6-${{ matrix.base_image }}
- name: "Attest artifacts ${{ matrix.base_image }}" - name: "Attest artifacts ${{ matrix.base_image }}"
uses: actions/attest-build-provenance@e8998f949152b193b063cb0ec769d69d929409be # v2.4.0 uses: actions/attest-build-provenance@977bb373ede98d70efdf65b84cb5f73e068dcc2a # v3.0.0
with: with:
subject-path: vaultwarden-* subject-path: vaultwarden-*
# End Upload artifacts to Github Actions # End Upload artifacts to Github Actions

4
.github/workflows/trivy.yml

@ -36,7 +36,7 @@ jobs:
persist-credentials: false persist-credentials: false
- name: Run Trivy vulnerability scanner - name: Run Trivy vulnerability scanner
uses: aquasecurity/trivy-action@b6643a29fecd7f34b3597bc6acb0a98b03d33ff8 # v0.33.0 + b6643a2 uses: aquasecurity/trivy-action@b6643a29fecd7f34b3597bc6acb0a98b03d33ff8 # 0.33.1
env: env:
TRIVY_DB_REPOSITORY: docker.io/aquasec/trivy-db:2,public.ecr.aws/aquasecurity/trivy-db:2,ghcr.io/aquasecurity/trivy-db:2 TRIVY_DB_REPOSITORY: docker.io/aquasec/trivy-db:2,public.ecr.aws/aquasecurity/trivy-db:2,ghcr.io/aquasecurity/trivy-db:2
TRIVY_JAVA_DB_REPOSITORY: docker.io/aquasec/trivy-java-db:1,public.ecr.aws/aquasecurity/trivy-java-db:1,ghcr.io/aquasecurity/trivy-java-db:1 TRIVY_JAVA_DB_REPOSITORY: docker.io/aquasec/trivy-java-db:1,public.ecr.aws/aquasecurity/trivy-java-db:1,ghcr.io/aquasecurity/trivy-java-db:1
@ -48,6 +48,6 @@ jobs:
severity: CRITICAL,HIGH severity: CRITICAL,HIGH
- name: Upload Trivy scan results to GitHub Security tab - name: Upload Trivy scan results to GitHub Security tab
uses: github/codeql-action/upload-sarif@3c3833e0f8c1c83d449a7478aa59c036a9165498 # v3.29.11 uses: github/codeql-action/upload-sarif@192325c86100d080feab897ff886c34abd4c83a3 # v3.30.3
with: with:
sarif_file: 'trivy-results.sarif' sarif_file: 'trivy-results.sarif'

2
.github/workflows/zizmor.yml

@ -21,7 +21,7 @@ jobs:
persist-credentials: false persist-credentials: false
- name: Run zizmor - name: Run zizmor
uses: zizmorcore/zizmor-action@5ca5fc7a4779c5263a3ffa0e1f693009994446d1 # v0.1.2 uses: zizmorcore/zizmor-action@e673c3917a1aef3c65c972347ed84ccd013ecda4 # v0.2.0
with: with:
# intentionally not scanning the entire repository, # intentionally not scanning the entire repository,
# since it contains integration tests. # since it contains integration tests.

739
Cargo.lock

File diff suppressed because it is too large

38
Cargo.toml

@ -16,7 +16,11 @@ publish = false
build = "build.rs" build = "build.rs"
[features] [features]
# default = ["sqlite"] default = [
# "sqlite",
# "mysql",
# "postgresql",
]
# Empty to keep compatibility, prefer to set USE_SYSLOG=true # Empty to keep compatibility, prefer to set USE_SYSLOG=true
enable_syslog = [] enable_syslog = []
mysql = ["diesel/mysql", "diesel_migrations/mysql"] mysql = ["diesel/mysql", "diesel_migrations/mysql"]
@ -27,11 +31,6 @@ vendored_openssl = ["openssl/vendored"]
# Enable MiMalloc memory allocator to replace the default malloc # Enable MiMalloc memory allocator to replace the default malloc
# This can improve performance for Alpine builds # This can improve performance for Alpine builds
enable_mimalloc = ["dep:mimalloc"] enable_mimalloc = ["dep:mimalloc"]
# This is a development dependency, and should only be used during development!
# It enables the usage of the diesel_logger crate, which is able to output the generated queries.
# You also need to set an env variable `QUERY_LOGGER=1` to fully activate this so you do not have to re-compile
# if you want to turn off the logging for a specific run.
query_logger = ["dep:diesel_logger"]
s3 = ["opendal/services-s3", "dep:aws-config", "dep:aws-credential-types", "dep:aws-smithy-runtime-api", "dep:anyhow", "dep:http", "dep:reqsign"] s3 = ["opendal/services-s3", "dep:aws-config", "dep:aws-credential-types", "dep:aws-smithy-runtime-api", "dep:anyhow", "dep:http", "dep:reqsign"]
# OIDC specific features # OIDC specific features
@ -50,7 +49,7 @@ syslog = "7.0.0"
macros = { path = "./macros" } macros = { path = "./macros" }
# Logging # Logging
log = "0.4.27" log = "0.4.28"
fern = { version = "0.7.1", features = ["syslog-7", "reopen-1"] } fern = { version = "0.7.1", features = ["syslog-7", "reopen-1"] }
tracing = { version = "0.1.41", features = ["log"] } # Needed to have lettre and webauthn-rs trace logging to work tracing = { version = "0.1.41", features = ["log"] } # Needed to have lettre and webauthn-rs trace logging to work
@ -81,13 +80,12 @@ tokio = { version = "1.47.1", features = ["rt-multi-thread", "fs", "io-util", "p
tokio-util = { version = "0.7.16", features = ["compat"]} tokio-util = { version = "0.7.16", features = ["compat"]}
# A generic serialization/deserialization framework # A generic serialization/deserialization framework
serde = { version = "1.0.219", features = ["derive"] } serde = { version = "1.0.225", features = ["derive"] }
serde_json = "1.0.143" serde_json = "1.0.145"
# A safe, extensible ORM and Query builder # A safe, extensible ORM and Query builder
diesel = { version = "2.2.12", features = ["chrono", "r2d2", "numeric"] } diesel = { version = "2.3.2", features = ["chrono", "r2d2", "numeric"] }
diesel_migrations = "2.2.0" diesel_migrations = "2.3.0"
diesel_logger = { version = "0.4.0", optional = true }
derive_more = { version = "2.0.1", features = ["from", "into", "as_ref", "deref", "display"] } derive_more = { version = "2.0.1", features = ["from", "into", "as_ref", "deref", "display"] }
diesel-derive-newtype = "2.1.2" diesel-derive-newtype = "2.1.2"
@ -101,12 +99,12 @@ ring = "0.17.14"
subtle = "2.6.1" subtle = "2.6.1"
# UUID generation # UUID generation
uuid = { version = "1.18.0", features = ["v4"] } uuid = { version = "1.18.1", features = ["v4"] }
# Date and time libraries # Date and time libraries
chrono = { version = "0.4.41", features = ["clock", "serde"], default-features = false } chrono = { version = "0.4.42", features = ["clock", "serde"], default-features = false }
chrono-tz = "0.10.4" chrono-tz = "0.10.4"
time = "0.3.41" time = "0.3.44"
# Job scheduler # Job scheduler
job_scheduler_ng = "2.3.0" job_scheduler_ng = "2.3.0"
@ -157,7 +155,7 @@ cached = { version = "0.56.0", features = ["async"] }
# Used for custom short lived cookie jar during favicon extraction # Used for custom short lived cookie jar during favicon extraction
cookie = "0.18.1" cookie = "0.18.1"
cookie_store = "0.21.1" cookie_store = "0.22.0"
# Used by U2F, JWT and PostgreSQL # Used by U2F, JWT and PostgreSQL
openssl = "0.10.73" openssl = "0.10.73"
@ -174,7 +172,7 @@ openidconnect = { version = "4.0.1", features = ["reqwest", "native-tls"] }
mini-moka = "0.10.3" mini-moka = "0.10.3"
# Check client versions for specific features. # Check client versions for specific features.
semver = "1.0.26" semver = "1.0.27"
# Allow overriding the default memory allocator # Allow overriding the default memory allocator
# Mainly used for the musl builds, since the default musl malloc is very slow # Mainly used for the musl builds, since the default musl malloc is very slow
@ -195,9 +193,9 @@ grass_compiler = { version = "0.13.4", default-features = false }
opendal = { version = "0.54.0", features = ["services-fs"], default-features = false } opendal = { version = "0.54.0", features = ["services-fs"], default-features = false }
# For retrieving AWS credentials, including temporary SSO credentials # For retrieving AWS credentials, including temporary SSO credentials
anyhow = { version = "1.0.99", optional = true } anyhow = { version = "1.0.100", optional = true }
aws-config = { version = "1.8.5", features = ["behavior-version-latest", "rt-tokio", "credentials-process", "sso"], default-features = false, optional = true } aws-config = { version = "1.8.6", features = ["behavior-version-latest", "rt-tokio", "credentials-process", "sso"], default-features = false, optional = true }
aws-credential-types = { version = "1.2.5", optional = true } aws-credential-types = { version = "1.2.6", optional = true }
aws-smithy-runtime-api = { version = "1.9.0", optional = true } aws-smithy-runtime-api = { version = "1.9.0", optional = true }
http = { version = "1.3.1", optional = true } http = { version = "1.3.1", optional = true }
reqsign = { version = "0.16.5", optional = true } reqsign = { version = "0.16.5", optional = true }

6
build.rs

@ -9,8 +9,6 @@ fn main() {
println!("cargo:rustc-cfg=mysql"); println!("cargo:rustc-cfg=mysql");
#[cfg(feature = "postgresql")] #[cfg(feature = "postgresql")]
println!("cargo:rustc-cfg=postgresql"); println!("cargo:rustc-cfg=postgresql");
#[cfg(feature = "query_logger")]
println!("cargo:rustc-cfg=query_logger");
#[cfg(feature = "s3")] #[cfg(feature = "s3")]
println!("cargo:rustc-cfg=s3"); println!("cargo:rustc-cfg=s3");
@ -24,7 +22,6 @@ fn main() {
println!("cargo::rustc-check-cfg=cfg(sqlite)"); println!("cargo::rustc-check-cfg=cfg(sqlite)");
println!("cargo::rustc-check-cfg=cfg(mysql)"); println!("cargo::rustc-check-cfg=cfg(mysql)");
println!("cargo::rustc-check-cfg=cfg(postgresql)"); println!("cargo::rustc-check-cfg=cfg(postgresql)");
println!("cargo::rustc-check-cfg=cfg(query_logger)");
println!("cargo::rustc-check-cfg=cfg(s3)"); println!("cargo::rustc-check-cfg=cfg(s3)");
// Rerun when these paths are changed. // Rerun when these paths are changed.
@ -34,9 +31,6 @@ fn main() {
println!("cargo:rerun-if-changed=.git/index"); println!("cargo:rerun-if-changed=.git/index");
println!("cargo:rerun-if-changed=.git/refs/tags"); println!("cargo:rerun-if-changed=.git/refs/tags");
#[cfg(all(not(debug_assertions), feature = "query_logger"))]
compile_error!("Query Logging is only allowed during development, it is not intended for production usage!");
// Support $BWRS_VERSION for legacy compatibility, but default to $VW_VERSION. // Support $BWRS_VERSION for legacy compatibility, but default to $VW_VERSION.
// If neither exist, read from git. // If neither exist, read from git.
let maybe_vaultwarden_version = let maybe_vaultwarden_version =

6
docker/Dockerfile.alpine

@ -53,9 +53,9 @@ ENV DEBIAN_FRONTEND=noninteractive \
TERM=xterm-256color \ TERM=xterm-256color \
CARGO_HOME="/root/.cargo" \ CARGO_HOME="/root/.cargo" \
USER="root" \ USER="root" \
# Use PostgreSQL v15 during Alpine/MUSL builds instead of the default v11 # Use PostgreSQL v17 during Alpine/MUSL builds instead of the default v16
# Debian Bookworm already contains libpq v15 # Debian Trixie uses libpq v17
PQ_LIB_DIR="/usr/local/musl/pq15/lib" PQ_LIB_DIR="/usr/local/musl/pq17/lib"
# Create CARGO_HOME folder and don't download rust docs # Create CARGO_HOME folder and don't download rust docs

6
docker/Dockerfile.j2

@ -63,9 +63,9 @@ ENV DEBIAN_FRONTEND=noninteractive \
CARGO_HOME="/root/.cargo" \ CARGO_HOME="/root/.cargo" \
USER="root" USER="root"
{%- if base == "alpine" %} \ {%- if base == "alpine" %} \
# Use PostgreSQL v15 during Alpine/MUSL builds instead of the default v11 # Use PostgreSQL v17 during Alpine/MUSL builds instead of the default v16
# Debian Bookworm already contains libpq v15 # Debian Trixie uses libpq v17
PQ_LIB_DIR="/usr/local/musl/pq15/lib" PQ_LIB_DIR="/usr/local/musl/pq17/lib"
{% endif %} {% endif %}
{% if base == "debian" %} {% if base == "debian" %}

2
macros/Cargo.toml

@ -10,7 +10,7 @@ proc-macro = true
[dependencies] [dependencies]
quote = "1.0.40" quote = "1.0.40"
syn = "2.0.105" syn = "2.0.106"
[lints] [lints]
workspace = true workspace = true

189
src/api/admin.rs

@ -20,7 +20,14 @@ use crate::{
}, },
auth::{decode_admin, encode_jwt, generate_admin_claims, ClientIp, Secure}, auth::{decode_admin, encode_jwt, generate_admin_claims, ClientIp, Secure},
config::ConfigBuilder, config::ConfigBuilder,
db::{backup_database, get_sql_server_version, models::*, DbConn, DbConnType}, db::{
backup_sqlite, get_sql_server_version,
models::{
Attachment, Cipher, Collection, Device, Event, EventType, Group, Invitation, Membership, MembershipId,
MembershipType, OrgPolicy, OrgPolicyErr, Organization, OrganizationId, SsoUser, TwoFactor, User, UserId,
},
DbConn, DbConnType, ACTIVE_DB_TYPE,
},
error::{Error, MapResult}, error::{Error, MapResult},
http_client::make_http_request, http_client::make_http_request,
mail, mail,
@ -75,18 +82,20 @@ pub fn catchers() -> Vec<Catcher> {
} }
} }
static DB_TYPE: Lazy<&str> = Lazy::new(|| { static DB_TYPE: Lazy<&str> = Lazy::new(|| match ACTIVE_DB_TYPE.get() {
DbConnType::from_url(&CONFIG.database_url()) #[cfg(mysql)]
.map(|t| match t { Some(DbConnType::Mysql) => "MySQL",
DbConnType::sqlite => "SQLite", #[cfg(postgresql)]
DbConnType::mysql => "MySQL", Some(DbConnType::Postgresql) => "PostgreSQL",
DbConnType::postgresql => "PostgreSQL", #[cfg(sqlite)]
}) Some(DbConnType::Sqlite) => "SQLite",
.unwrap_or("Unknown") _ => "Unknown",
}); });
static CAN_BACKUP: Lazy<bool> = #[cfg(sqlite)]
Lazy::new(|| DbConnType::from_url(&CONFIG.database_url()).map(|t| t == DbConnType::sqlite).unwrap_or(false)); static CAN_BACKUP: Lazy<bool> = Lazy::new(|| ACTIVE_DB_TYPE.get().map(|t| *t == DbConnType::Sqlite).unwrap_or(false));
#[cfg(not(sqlite))]
static CAN_BACKUP: Lazy<bool> = Lazy::new(|| false);
#[get("/")] #[get("/")]
fn admin_disabled() -> &'static str { fn admin_disabled() -> &'static str {
@ -284,7 +293,7 @@ struct InviteData {
email: String, email: String,
} }
async fn get_user_or_404(user_id: &UserId, conn: &mut DbConn) -> ApiResult<User> { async fn get_user_or_404(user_id: &UserId, conn: &DbConn) -> ApiResult<User> {
if let Some(user) = User::find_by_uuid(user_id, conn).await { if let Some(user) = User::find_by_uuid(user_id, conn).await {
Ok(user) Ok(user)
} else { } else {
@ -293,15 +302,15 @@ async fn get_user_or_404(user_id: &UserId, conn: &mut DbConn) -> ApiResult<User>
} }
#[post("/invite", format = "application/json", data = "<data>")] #[post("/invite", format = "application/json", data = "<data>")]
async fn invite_user(data: Json<InviteData>, _token: AdminToken, mut conn: DbConn) -> JsonResult { async fn invite_user(data: Json<InviteData>, _token: AdminToken, conn: DbConn) -> JsonResult {
let data: InviteData = data.into_inner(); let data: InviteData = data.into_inner();
if User::find_by_mail(&data.email, &mut conn).await.is_some() { if User::find_by_mail(&data.email, &conn).await.is_some() {
err_code!("User already exists", Status::Conflict.code) err_code!("User already exists", Status::Conflict.code)
} }
let mut user = User::new(data.email, None); let mut user = User::new(data.email, None);
async fn _generate_invite(user: &User, conn: &mut DbConn) -> EmptyResult { async fn _generate_invite(user: &User, conn: &DbConn) -> EmptyResult {
if CONFIG.mail_enabled() { if CONFIG.mail_enabled() {
let org_id: OrganizationId = FAKE_ADMIN_UUID.to_string().into(); let org_id: OrganizationId = FAKE_ADMIN_UUID.to_string().into();
let member_id: MembershipId = FAKE_ADMIN_UUID.to_string().into(); let member_id: MembershipId = FAKE_ADMIN_UUID.to_string().into();
@ -312,10 +321,10 @@ async fn invite_user(data: Json<InviteData>, _token: AdminToken, mut conn: DbCon
} }
} }
_generate_invite(&user, &mut conn).await.map_err(|e| e.with_code(Status::InternalServerError.code))?; _generate_invite(&user, &conn).await.map_err(|e| e.with_code(Status::InternalServerError.code))?;
user.save(&mut conn).await.map_err(|e| e.with_code(Status::InternalServerError.code))?; user.save(&conn).await.map_err(|e| e.with_code(Status::InternalServerError.code))?;
Ok(Json(user.to_json(&mut conn).await)) Ok(Json(user.to_json(&conn).await))
} }
#[post("/test/smtp", format = "application/json", data = "<data>")] #[post("/test/smtp", format = "application/json", data = "<data>")]
@ -336,14 +345,14 @@ fn logout(cookies: &CookieJar<'_>) -> Redirect {
} }
#[get("/users")] #[get("/users")]
async fn get_users_json(_token: AdminToken, mut conn: DbConn) -> Json<Value> { async fn get_users_json(_token: AdminToken, conn: DbConn) -> Json<Value> {
let users = User::get_all(&mut conn).await; let users = User::get_all(&conn).await;
let mut users_json = Vec::with_capacity(users.len()); let mut users_json = Vec::with_capacity(users.len());
for (u, _) in users { for (u, _) in users {
let mut usr = u.to_json(&mut conn).await; let mut usr = u.to_json(&conn).await;
usr["userEnabled"] = json!(u.enabled); usr["userEnabled"] = json!(u.enabled);
usr["createdAt"] = json!(format_naive_datetime_local(&u.created_at, DT_FMT)); usr["createdAt"] = json!(format_naive_datetime_local(&u.created_at, DT_FMT));
usr["lastActive"] = match u.last_active(&mut conn).await { usr["lastActive"] = match u.last_active(&conn).await {
Some(dt) => json!(format_naive_datetime_local(&dt, DT_FMT)), Some(dt) => json!(format_naive_datetime_local(&dt, DT_FMT)),
None => json!(None::<String>), None => json!(None::<String>),
}; };
@ -354,17 +363,17 @@ async fn get_users_json(_token: AdminToken, mut conn: DbConn) -> Json<Value> {
} }
#[get("/users/overview")] #[get("/users/overview")]
async fn users_overview(_token: AdminToken, mut conn: DbConn) -> ApiResult<Html<String>> { async fn users_overview(_token: AdminToken, conn: DbConn) -> ApiResult<Html<String>> {
let users = User::get_all(&mut conn).await; let users = User::get_all(&conn).await;
let mut users_json = Vec::with_capacity(users.len()); let mut users_json = Vec::with_capacity(users.len());
for (u, sso_u) in users { for (u, sso_u) in users {
let mut usr = u.to_json(&mut conn).await; let mut usr = u.to_json(&conn).await;
usr["cipher_count"] = json!(Cipher::count_owned_by_user(&u.uuid, &mut conn).await); usr["cipher_count"] = json!(Cipher::count_owned_by_user(&u.uuid, &conn).await);
usr["attachment_count"] = json!(Attachment::count_by_user(&u.uuid, &mut conn).await); usr["attachment_count"] = json!(Attachment::count_by_user(&u.uuid, &conn).await);
usr["attachment_size"] = json!(get_display_size(Attachment::size_by_user(&u.uuid, &mut conn).await)); usr["attachment_size"] = json!(get_display_size(Attachment::size_by_user(&u.uuid, &conn).await));
usr["user_enabled"] = json!(u.enabled); usr["user_enabled"] = json!(u.enabled);
usr["created_at"] = json!(format_naive_datetime_local(&u.created_at, DT_FMT)); usr["created_at"] = json!(format_naive_datetime_local(&u.created_at, DT_FMT));
usr["last_active"] = match u.last_active(&mut conn).await { usr["last_active"] = match u.last_active(&conn).await {
Some(dt) => json!(format_naive_datetime_local(&dt, DT_FMT)), Some(dt) => json!(format_naive_datetime_local(&dt, DT_FMT)),
None => json!("Never"), None => json!("Never"),
}; };
@ -379,9 +388,9 @@ async fn users_overview(_token: AdminToken, mut conn: DbConn) -> ApiResult<Html<
} }
#[get("/users/by-mail/<mail>")] #[get("/users/by-mail/<mail>")]
async fn get_user_by_mail_json(mail: &str, _token: AdminToken, mut conn: DbConn) -> JsonResult { async fn get_user_by_mail_json(mail: &str, _token: AdminToken, conn: DbConn) -> JsonResult {
if let Some(u) = User::find_by_mail(mail, &mut conn).await { if let Some(u) = User::find_by_mail(mail, &conn).await {
let mut usr = u.to_json(&mut conn).await; let mut usr = u.to_json(&conn).await;
usr["userEnabled"] = json!(u.enabled); usr["userEnabled"] = json!(u.enabled);
usr["createdAt"] = json!(format_naive_datetime_local(&u.created_at, DT_FMT)); usr["createdAt"] = json!(format_naive_datetime_local(&u.created_at, DT_FMT));
Ok(Json(usr)) Ok(Json(usr))
@ -391,21 +400,21 @@ async fn get_user_by_mail_json(mail: &str, _token: AdminToken, mut conn: DbConn)
} }
#[get("/users/<user_id>")] #[get("/users/<user_id>")]
async fn get_user_json(user_id: UserId, _token: AdminToken, mut conn: DbConn) -> JsonResult { async fn get_user_json(user_id: UserId, _token: AdminToken, conn: DbConn) -> JsonResult {
let u = get_user_or_404(&user_id, &mut conn).await?; let u = get_user_or_404(&user_id, &conn).await?;
let mut usr = u.to_json(&mut conn).await; let mut usr = u.to_json(&conn).await;
usr["userEnabled"] = json!(u.enabled); usr["userEnabled"] = json!(u.enabled);
usr["createdAt"] = json!(format_naive_datetime_local(&u.created_at, DT_FMT)); usr["createdAt"] = json!(format_naive_datetime_local(&u.created_at, DT_FMT));
Ok(Json(usr)) Ok(Json(usr))
} }
#[post("/users/<user_id>/delete", format = "application/json")] #[post("/users/<user_id>/delete", format = "application/json")]
async fn delete_user(user_id: UserId, token: AdminToken, mut conn: DbConn) -> EmptyResult { async fn delete_user(user_id: UserId, token: AdminToken, conn: DbConn) -> EmptyResult {
let user = get_user_or_404(&user_id, &mut conn).await?; let user = get_user_or_404(&user_id, &conn).await?;
// Get the membership records before deleting the actual user // Get the membership records before deleting the actual user
let memberships = Membership::find_any_state_by_user(&user_id, &mut conn).await; let memberships = Membership::find_any_state_by_user(&user_id, &conn).await;
let res = user.delete(&mut conn).await; let res = user.delete(&conn).await;
for membership in memberships { for membership in memberships {
log_event( log_event(
@ -415,7 +424,7 @@ async fn delete_user(user_id: UserId, token: AdminToken, mut conn: DbConn) -> Em
&ACTING_ADMIN_USER.into(), &ACTING_ADMIN_USER.into(),
14, // Use UnknownBrowser type 14, // Use UnknownBrowser type
&token.ip.ip, &token.ip.ip,
&mut conn, &conn,
) )
.await; .await;
} }
@ -424,9 +433,9 @@ async fn delete_user(user_id: UserId, token: AdminToken, mut conn: DbConn) -> Em
} }
#[delete("/users/<user_id>/sso", format = "application/json")] #[delete("/users/<user_id>/sso", format = "application/json")]
async fn delete_sso_user(user_id: UserId, token: AdminToken, mut conn: DbConn) -> EmptyResult { async fn delete_sso_user(user_id: UserId, token: AdminToken, conn: DbConn) -> EmptyResult {
let memberships = Membership::find_any_state_by_user(&user_id, &mut conn).await; let memberships = Membership::find_any_state_by_user(&user_id, &conn).await;
let res = SsoUser::delete(&user_id, &mut conn).await; let res = SsoUser::delete(&user_id, &conn).await;
for membership in memberships { for membership in memberships {
log_event( log_event(
@ -436,7 +445,7 @@ async fn delete_sso_user(user_id: UserId, token: AdminToken, mut conn: DbConn) -
&ACTING_ADMIN_USER.into(), &ACTING_ADMIN_USER.into(),
14, // Use UnknownBrowser type 14, // Use UnknownBrowser type
&token.ip.ip, &token.ip.ip,
&mut conn, &conn,
) )
.await; .await;
} }
@ -445,13 +454,13 @@ async fn delete_sso_user(user_id: UserId, token: AdminToken, mut conn: DbConn) -
} }
#[post("/users/<user_id>/deauth", format = "application/json")] #[post("/users/<user_id>/deauth", format = "application/json")]
async fn deauth_user(user_id: UserId, _token: AdminToken, mut conn: DbConn, nt: Notify<'_>) -> EmptyResult { async fn deauth_user(user_id: UserId, _token: AdminToken, conn: DbConn, nt: Notify<'_>) -> EmptyResult {
let mut user = get_user_or_404(&user_id, &mut conn).await?; let mut user = get_user_or_404(&user_id, &conn).await?;
nt.send_logout(&user, None, &mut conn).await; nt.send_logout(&user, None, &conn).await;
if CONFIG.push_enabled() { if CONFIG.push_enabled() {
for device in Device::find_push_devices_by_user(&user.uuid, &mut conn).await { for device in Device::find_push_devices_by_user(&user.uuid, &conn).await {
match unregister_push_device(&device.push_uuid).await { match unregister_push_device(&device.push_uuid).await {
Ok(r) => r, Ok(r) => r,
Err(e) => error!("Unable to unregister devices from Bitwarden server: {e}"), Err(e) => error!("Unable to unregister devices from Bitwarden server: {e}"),
@ -459,46 +468,46 @@ async fn deauth_user(user_id: UserId, _token: AdminToken, mut conn: DbConn, nt:
} }
} }
Device::delete_all_by_user(&user.uuid, &mut conn).await?; Device::delete_all_by_user(&user.uuid, &conn).await?;
user.reset_security_stamp(); user.reset_security_stamp();
user.save(&mut conn).await user.save(&conn).await
} }
#[post("/users/<user_id>/disable", format = "application/json")] #[post("/users/<user_id>/disable", format = "application/json")]
async fn disable_user(user_id: UserId, _token: AdminToken, mut conn: DbConn, nt: Notify<'_>) -> EmptyResult { async fn disable_user(user_id: UserId, _token: AdminToken, conn: DbConn, nt: Notify<'_>) -> EmptyResult {
let mut user = get_user_or_404(&user_id, &mut conn).await?; let mut user = get_user_or_404(&user_id, &conn).await?;
Device::delete_all_by_user(&user.uuid, &mut conn).await?; Device::delete_all_by_user(&user.uuid, &conn).await?;
user.reset_security_stamp(); user.reset_security_stamp();
user.enabled = false; user.enabled = false;
let save_result = user.save(&mut conn).await; let save_result = user.save(&conn).await;
nt.send_logout(&user, None, &mut conn).await; nt.send_logout(&user, None, &conn).await;
save_result save_result
} }
#[post("/users/<user_id>/enable", format = "application/json")] #[post("/users/<user_id>/enable", format = "application/json")]
async fn enable_user(user_id: UserId, _token: AdminToken, mut conn: DbConn) -> EmptyResult { async fn enable_user(user_id: UserId, _token: AdminToken, conn: DbConn) -> EmptyResult {
let mut user = get_user_or_404(&user_id, &mut conn).await?; let mut user = get_user_or_404(&user_id, &conn).await?;
user.enabled = true; user.enabled = true;
user.save(&mut conn).await user.save(&conn).await
} }
#[post("/users/<user_id>/remove-2fa", format = "application/json")] #[post("/users/<user_id>/remove-2fa", format = "application/json")]
async fn remove_2fa(user_id: UserId, token: AdminToken, mut conn: DbConn) -> EmptyResult { async fn remove_2fa(user_id: UserId, token: AdminToken, conn: DbConn) -> EmptyResult {
let mut user = get_user_or_404(&user_id, &mut conn).await?; let mut user = get_user_or_404(&user_id, &conn).await?;
TwoFactor::delete_all_by_user(&user.uuid, &mut conn).await?; TwoFactor::delete_all_by_user(&user.uuid, &conn).await?;
two_factor::enforce_2fa_policy(&user, &ACTING_ADMIN_USER.into(), 14, &token.ip.ip, &mut conn).await?; two_factor::enforce_2fa_policy(&user, &ACTING_ADMIN_USER.into(), 14, &token.ip.ip, &conn).await?;
user.totp_recover = None; user.totp_recover = None;
user.save(&mut conn).await user.save(&conn).await
} }
#[post("/users/<user_id>/invite/resend", format = "application/json")] #[post("/users/<user_id>/invite/resend", format = "application/json")]
async fn resend_user_invite(user_id: UserId, _token: AdminToken, mut conn: DbConn) -> EmptyResult { async fn resend_user_invite(user_id: UserId, _token: AdminToken, conn: DbConn) -> EmptyResult {
if let Some(user) = User::find_by_uuid(&user_id, &mut conn).await { if let Some(user) = User::find_by_uuid(&user_id, &conn).await {
//TODO: replace this with user.status check when it will be available (PR#3397) //TODO: replace this with user.status check when it will be available (PR#3397)
if !user.password_hash.is_empty() { if !user.password_hash.is_empty() {
err_code!("User already accepted invitation", Status::BadRequest.code); err_code!("User already accepted invitation", Status::BadRequest.code);
@ -524,10 +533,10 @@ struct MembershipTypeData {
} }
#[post("/users/org_type", format = "application/json", data = "<data>")] #[post("/users/org_type", format = "application/json", data = "<data>")]
async fn update_membership_type(data: Json<MembershipTypeData>, token: AdminToken, mut conn: DbConn) -> EmptyResult { async fn update_membership_type(data: Json<MembershipTypeData>, token: AdminToken, conn: DbConn) -> EmptyResult {
let data: MembershipTypeData = data.into_inner(); let data: MembershipTypeData = data.into_inner();
let Some(mut member_to_edit) = Membership::find_by_user_and_org(&data.user_uuid, &data.org_uuid, &mut conn).await let Some(mut member_to_edit) = Membership::find_by_user_and_org(&data.user_uuid, &data.org_uuid, &conn).await
else { else {
err!("The specified user isn't member of the organization") err!("The specified user isn't member of the organization")
}; };
@ -539,7 +548,7 @@ async fn update_membership_type(data: Json<MembershipTypeData>, token: AdminToke
if member_to_edit.atype == MembershipType::Owner && new_type != MembershipType::Owner { if member_to_edit.atype == MembershipType::Owner && new_type != MembershipType::Owner {
// Removing owner permission, check that there is at least one other confirmed owner // Removing owner permission, check that there is at least one other confirmed owner
if Membership::count_confirmed_by_org_and_type(&data.org_uuid, MembershipType::Owner, &mut conn).await <= 1 { if Membership::count_confirmed_by_org_and_type(&data.org_uuid, MembershipType::Owner, &conn).await <= 1 {
err!("Can't change the type of the last owner") err!("Can't change the type of the last owner")
} }
} }
@ -547,11 +556,11 @@ async fn update_membership_type(data: Json<MembershipTypeData>, token: AdminToke
// This check is also done at api::organizations::{accept_invite, _confirm_invite, _activate_member, edit_member}, update_membership_type // This check is also done at api::organizations::{accept_invite, _confirm_invite, _activate_member, edit_member}, update_membership_type
// It returns different error messages per function. // It returns different error messages per function.
if new_type < MembershipType::Admin { if new_type < MembershipType::Admin {
match OrgPolicy::is_user_allowed(&member_to_edit.user_uuid, &member_to_edit.org_uuid, true, &mut conn).await { match OrgPolicy::is_user_allowed(&member_to_edit.user_uuid, &member_to_edit.org_uuid, true, &conn).await {
Ok(_) => {} Ok(_) => {}
Err(OrgPolicyErr::TwoFactorMissing) => { Err(OrgPolicyErr::TwoFactorMissing) => {
if CONFIG.email_2fa_auto_fallback() { if CONFIG.email_2fa_auto_fallback() {
two_factor::email::find_and_activate_email_2fa(&member_to_edit.user_uuid, &mut conn).await?; two_factor::email::find_and_activate_email_2fa(&member_to_edit.user_uuid, &conn).await?;
} else { } else {
err!("You cannot modify this user to this type because they have not setup 2FA"); err!("You cannot modify this user to this type because they have not setup 2FA");
} }
@ -569,32 +578,32 @@ async fn update_membership_type(data: Json<MembershipTypeData>, token: AdminToke
&ACTING_ADMIN_USER.into(), &ACTING_ADMIN_USER.into(),
14, // Use UnknownBrowser type 14, // Use UnknownBrowser type
&token.ip.ip, &token.ip.ip,
&mut conn, &conn,
) )
.await; .await;
member_to_edit.atype = new_type; member_to_edit.atype = new_type;
member_to_edit.save(&mut conn).await member_to_edit.save(&conn).await
} }
#[post("/users/update_revision", format = "application/json")] #[post("/users/update_revision", format = "application/json")]
async fn update_revision_users(_token: AdminToken, mut conn: DbConn) -> EmptyResult { async fn update_revision_users(_token: AdminToken, conn: DbConn) -> EmptyResult {
User::update_all_revisions(&mut conn).await User::update_all_revisions(&conn).await
} }
#[get("/organizations/overview")] #[get("/organizations/overview")]
async fn organizations_overview(_token: AdminToken, mut conn: DbConn) -> ApiResult<Html<String>> { async fn organizations_overview(_token: AdminToken, conn: DbConn) -> ApiResult<Html<String>> {
let organizations = Organization::get_all(&mut conn).await; let organizations = Organization::get_all(&conn).await;
let mut organizations_json = Vec::with_capacity(organizations.len()); let mut organizations_json = Vec::with_capacity(organizations.len());
for o in organizations { for o in organizations {
let mut org = o.to_json(); let mut org = o.to_json();
org["user_count"] = json!(Membership::count_by_org(&o.uuid, &mut conn).await); org["user_count"] = json!(Membership::count_by_org(&o.uuid, &conn).await);
org["cipher_count"] = json!(Cipher::count_by_org(&o.uuid, &mut conn).await); org["cipher_count"] = json!(Cipher::count_by_org(&o.uuid, &conn).await);
org["collection_count"] = json!(Collection::count_by_org(&o.uuid, &mut conn).await); org["collection_count"] = json!(Collection::count_by_org(&o.uuid, &conn).await);
org["group_count"] = json!(Group::count_by_org(&o.uuid, &mut conn).await); org["group_count"] = json!(Group::count_by_org(&o.uuid, &conn).await);
org["event_count"] = json!(Event::count_by_org(&o.uuid, &mut conn).await); org["event_count"] = json!(Event::count_by_org(&o.uuid, &conn).await);
org["attachment_count"] = json!(Attachment::count_by_org(&o.uuid, &mut conn).await); org["attachment_count"] = json!(Attachment::count_by_org(&o.uuid, &conn).await);
org["attachment_size"] = json!(get_display_size(Attachment::size_by_org(&o.uuid, &mut conn).await)); org["attachment_size"] = json!(get_display_size(Attachment::size_by_org(&o.uuid, &conn).await));
organizations_json.push(org); organizations_json.push(org);
} }
@ -603,9 +612,9 @@ async fn organizations_overview(_token: AdminToken, mut conn: DbConn) -> ApiResu
} }
#[post("/organizations/<org_id>/delete", format = "application/json")] #[post("/organizations/<org_id>/delete", format = "application/json")]
async fn delete_organization(org_id: OrganizationId, _token: AdminToken, mut conn: DbConn) -> EmptyResult { async fn delete_organization(org_id: OrganizationId, _token: AdminToken, conn: DbConn) -> EmptyResult {
let org = Organization::find_by_uuid(&org_id, &mut conn).await.map_res("Organization doesn't exist")?; let org = Organization::find_by_uuid(&org_id, &conn).await.map_res("Organization doesn't exist")?;
org.delete(&mut conn).await org.delete(&conn).await
} }
#[derive(Deserialize)] #[derive(Deserialize)]
@ -693,7 +702,7 @@ async fn get_ntp_time(has_http_access: bool) -> String {
} }
#[get("/diagnostics")] #[get("/diagnostics")]
async fn diagnostics(_token: AdminToken, ip_header: IpHeader, mut conn: DbConn) -> ApiResult<Html<String>> { async fn diagnostics(_token: AdminToken, ip_header: IpHeader, conn: DbConn) -> ApiResult<Html<String>> {
use chrono::prelude::*; use chrono::prelude::*;
use std::net::ToSocketAddrs; use std::net::ToSocketAddrs;
@ -747,7 +756,7 @@ async fn diagnostics(_token: AdminToken, ip_header: IpHeader, mut conn: DbConn)
"uses_proxy": uses_proxy, "uses_proxy": uses_proxy,
"enable_websocket": &CONFIG.enable_websocket(), "enable_websocket": &CONFIG.enable_websocket(),
"db_type": *DB_TYPE, "db_type": *DB_TYPE,
"db_version": get_sql_server_version(&mut conn).await, "db_version": get_sql_server_version(&conn).await,
"admin_url": format!("{}/diagnostics", admin_url()), "admin_url": format!("{}/diagnostics", admin_url()),
"overrides": &CONFIG.get_overrides().join(", "), "overrides": &CONFIG.get_overrides().join(", "),
"host_arch": env::consts::ARCH, "host_arch": env::consts::ARCH,
@ -791,9 +800,9 @@ async fn delete_config(_token: AdminToken) -> EmptyResult {
} }
#[post("/config/backup_db", format = "application/json")] #[post("/config/backup_db", format = "application/json")]
async fn backup_db(_token: AdminToken, mut conn: DbConn) -> ApiResult<String> { fn backup_db(_token: AdminToken) -> ApiResult<String> {
if *CAN_BACKUP { if *CAN_BACKUP {
match backup_database(&mut conn).await { match backup_sqlite() {
Ok(f) => Ok(format!("Backup to '{f}' was successful")), Ok(f) => Ok(format!("Backup to '{f}' was successful")),
Err(e) => err!(format!("Backup was unsuccessful {e}")), Err(e) => err!(format!("Backup was unsuccessful {e}")),
} }

256
src/api/core/accounts.rs

@ -13,7 +13,14 @@ use crate::{
}, },
auth::{decode_delete, decode_invite, decode_verify_email, ClientHeaders, Headers}, auth::{decode_delete, decode_invite, decode_verify_email, ClientHeaders, Headers},
crypto, crypto,
db::{models::*, DbConn}, db::{
models::{
AuthRequest, AuthRequestId, Cipher, CipherId, Device, DeviceId, DeviceType, EmergencyAccess,
EmergencyAccessId, EventType, Folder, FolderId, Invitation, Membership, MembershipId, OrgPolicy,
OrgPolicyType, Organization, OrganizationId, Send, SendId, User, UserId, UserKdfType,
},
DbConn,
},
mail, mail,
util::{format_date, NumberOrString}, util::{format_date, NumberOrString},
CONFIG, CONFIG,
@ -142,7 +149,7 @@ fn enforce_password_hint_setting(password_hint: &Option<String>) -> EmptyResult
} }
Ok(()) Ok(())
} }
async fn is_email_2fa_required(member_id: Option<MembershipId>, conn: &mut DbConn) -> bool { async fn is_email_2fa_required(member_id: Option<MembershipId>, conn: &DbConn) -> bool {
if !CONFIG._enable_email_2fa() { if !CONFIG._enable_email_2fa() {
return false; return false;
} }
@ -160,7 +167,7 @@ async fn register(data: Json<RegisterData>, conn: DbConn) -> JsonResult {
_register(data, false, conn).await _register(data, false, conn).await
} }
pub async fn _register(data: Json<RegisterData>, email_verification: bool, mut conn: DbConn) -> JsonResult { pub async fn _register(data: Json<RegisterData>, email_verification: bool, conn: DbConn) -> JsonResult {
let mut data: RegisterData = data.into_inner(); let mut data: RegisterData = data.into_inner();
let email = data.email.to_lowercase(); let email = data.email.to_lowercase();
@ -242,7 +249,7 @@ pub async fn _register(data: Json<RegisterData>, email_verification: bool, mut c
let password_hint = clean_password_hint(&data.master_password_hint); let password_hint = clean_password_hint(&data.master_password_hint);
enforce_password_hint_setting(&password_hint)?; enforce_password_hint_setting(&password_hint)?;
let mut user = match User::find_by_mail(&email, &mut conn).await { let mut user = match User::find_by_mail(&email, &conn).await {
Some(user) => { Some(user) => {
if !user.password_hash.is_empty() { if !user.password_hash.is_empty() {
err!("Registration not allowed or user already exists") err!("Registration not allowed or user already exists")
@ -257,12 +264,12 @@ pub async fn _register(data: Json<RegisterData>, email_verification: bool, mut c
} else { } else {
err!("Registration email does not match invite email") err!("Registration email does not match invite email")
} }
} else if Invitation::take(&email, &mut conn).await { } else if Invitation::take(&email, &conn).await {
Membership::accept_user_invitations(&user.uuid, &mut conn).await?; Membership::accept_user_invitations(&user.uuid, &conn).await?;
user user
} else if CONFIG.is_signup_allowed(&email) } else if CONFIG.is_signup_allowed(&email)
|| (CONFIG.emergency_access_allowed() || (CONFIG.emergency_access_allowed()
&& EmergencyAccess::find_invited_by_grantee_email(&email, &mut conn).await.is_some()) && EmergencyAccess::find_invited_by_grantee_email(&email, &conn).await.is_some())
{ {
user user
} else { } else {
@ -273,7 +280,7 @@ pub async fn _register(data: Json<RegisterData>, email_verification: bool, mut c
// Order is important here; the invitation check must come first // Order is important here; the invitation check must come first
// because the vaultwarden admin can invite anyone, regardless // because the vaultwarden admin can invite anyone, regardless
// of other signup restrictions. // of other signup restrictions.
if Invitation::take(&email, &mut conn).await if Invitation::take(&email, &conn).await
|| CONFIG.is_signup_allowed(&email) || CONFIG.is_signup_allowed(&email)
|| pending_emergency_access.is_some() || pending_emergency_access.is_some()
{ {
@ -285,7 +292,7 @@ pub async fn _register(data: Json<RegisterData>, email_verification: bool, mut c
}; };
// Make sure we don't leave a lingering invitation. // Make sure we don't leave a lingering invitation.
Invitation::take(&email, &mut conn).await; Invitation::take(&email, &conn).await;
set_kdf_data(&mut user, data.kdf)?; set_kdf_data(&mut user, data.kdf)?;
@ -316,17 +323,17 @@ pub async fn _register(data: Json<RegisterData>, email_verification: bool, mut c
error!("Error sending welcome email: {e:#?}"); error!("Error sending welcome email: {e:#?}");
} }
if email_verified && is_email_2fa_required(data.organization_user_id, &mut conn).await { if email_verified && is_email_2fa_required(data.organization_user_id, &conn).await {
email::activate_email_2fa(&user, &mut conn).await.ok(); email::activate_email_2fa(&user, &conn).await.ok();
} }
} }
user.save(&mut conn).await?; user.save(&conn).await?;
// accept any open emergency access invitations // accept any open emergency access invitations
if !CONFIG.mail_enabled() && CONFIG.emergency_access_allowed() { if !CONFIG.mail_enabled() && CONFIG.emergency_access_allowed() {
for mut emergency_invite in EmergencyAccess::find_all_invited_by_grantee_email(&user.email, &mut conn).await { for mut emergency_invite in EmergencyAccess::find_all_invited_by_grantee_email(&user.email, &conn).await {
emergency_invite.accept_invite(&user.uuid, &user.email, &mut conn).await.ok(); emergency_invite.accept_invite(&user.uuid, &user.email, &conn).await.ok();
} }
} }
@ -337,7 +344,7 @@ pub async fn _register(data: Json<RegisterData>, email_verification: bool, mut c
} }
#[post("/accounts/set-password", data = "<data>")] #[post("/accounts/set-password", data = "<data>")]
async fn post_set_password(data: Json<SetPasswordData>, headers: Headers, mut conn: DbConn) -> JsonResult { async fn post_set_password(data: Json<SetPasswordData>, headers: Headers, conn: DbConn) -> JsonResult {
let data: SetPasswordData = data.into_inner(); let data: SetPasswordData = data.into_inner();
let mut user = headers.user; let mut user = headers.user;
@ -367,30 +374,30 @@ async fn post_set_password(data: Json<SetPasswordData>, headers: Headers, mut co
if let Some(identifier) = data.org_identifier { if let Some(identifier) = data.org_identifier {
if identifier != crate::sso::FAKE_IDENTIFIER { if identifier != crate::sso::FAKE_IDENTIFIER {
let org = match Organization::find_by_name(&identifier, &mut conn).await { let org = match Organization::find_by_name(&identifier, &conn).await {
None => err!("Failed to retrieve the associated organization"), None => err!("Failed to retrieve the associated organization"),
Some(org) => org, Some(org) => org,
}; };
let membership = match Membership::find_by_user_and_org(&user.uuid, &org.uuid, &mut conn).await { let membership = match Membership::find_by_user_and_org(&user.uuid, &org.uuid, &conn).await {
None => err!("Failed to retrieve the invitation"), None => err!("Failed to retrieve the invitation"),
Some(org) => org, Some(org) => org,
}; };
accept_org_invite(&user, membership, None, &mut conn).await?; accept_org_invite(&user, membership, None, &conn).await?;
} }
} }
if CONFIG.mail_enabled() { if CONFIG.mail_enabled() {
mail::send_welcome(&user.email.to_lowercase()).await?; mail::send_welcome(&user.email.to_lowercase()).await?;
} else { } else {
Membership::accept_user_invitations(&user.uuid, &mut conn).await?; Membership::accept_user_invitations(&user.uuid, &conn).await?;
} }
log_user_event(EventType::UserChangedPassword as i32, &user.uuid, headers.device.atype, &headers.ip.ip, &mut conn) log_user_event(EventType::UserChangedPassword as i32, &user.uuid, headers.device.atype, &headers.ip.ip, &conn)
.await; .await;
user.save(&mut conn).await?; user.save(&conn).await?;
Ok(Json(json!({ Ok(Json(json!({
"Object": "set-password", "Object": "set-password",
@ -399,8 +406,8 @@ async fn post_set_password(data: Json<SetPasswordData>, headers: Headers, mut co
} }
#[get("/accounts/profile")] #[get("/accounts/profile")]
async fn profile(headers: Headers, mut conn: DbConn) -> Json<Value> { async fn profile(headers: Headers, conn: DbConn) -> Json<Value> {
Json(headers.user.to_json(&mut conn).await) Json(headers.user.to_json(&conn).await)
} }
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
@ -416,7 +423,7 @@ async fn put_profile(data: Json<ProfileData>, headers: Headers, conn: DbConn) ->
} }
#[post("/accounts/profile", data = "<data>")] #[post("/accounts/profile", data = "<data>")]
async fn post_profile(data: Json<ProfileData>, headers: Headers, mut conn: DbConn) -> JsonResult { async fn post_profile(data: Json<ProfileData>, headers: Headers, conn: DbConn) -> JsonResult {
let data: ProfileData = data.into_inner(); let data: ProfileData = data.into_inner();
// Check if the length of the username exceeds 50 characters (Same is Upstream Bitwarden) // Check if the length of the username exceeds 50 characters (Same is Upstream Bitwarden)
@ -428,8 +435,8 @@ async fn post_profile(data: Json<ProfileData>, headers: Headers, mut conn: DbCon
let mut user = headers.user; let mut user = headers.user;
user.name = data.name; user.name = data.name;
user.save(&mut conn).await?; user.save(&conn).await?;
Ok(Json(user.to_json(&mut conn).await)) Ok(Json(user.to_json(&conn).await))
} }
#[derive(Deserialize)] #[derive(Deserialize)]
@ -439,7 +446,7 @@ struct AvatarData {
} }
#[put("/accounts/avatar", data = "<data>")] #[put("/accounts/avatar", data = "<data>")]
async fn put_avatar(data: Json<AvatarData>, headers: Headers, mut conn: DbConn) -> JsonResult { async fn put_avatar(data: Json<AvatarData>, headers: Headers, conn: DbConn) -> JsonResult {
let data: AvatarData = data.into_inner(); let data: AvatarData = data.into_inner();
// It looks like it only supports the 6 hex color format. // It looks like it only supports the 6 hex color format.
@ -454,13 +461,13 @@ async fn put_avatar(data: Json<AvatarData>, headers: Headers, mut conn: DbConn)
let mut user = headers.user; let mut user = headers.user;
user.avatar_color = data.avatar_color; user.avatar_color = data.avatar_color;
user.save(&mut conn).await?; user.save(&conn).await?;
Ok(Json(user.to_json(&mut conn).await)) Ok(Json(user.to_json(&conn).await))
} }
#[get("/users/<user_id>/public-key")] #[get("/users/<user_id>/public-key")]
async fn get_public_keys(user_id: UserId, _headers: Headers, mut conn: DbConn) -> JsonResult { async fn get_public_keys(user_id: UserId, _headers: Headers, conn: DbConn) -> JsonResult {
let user = match User::find_by_uuid(&user_id, &mut conn).await { let user = match User::find_by_uuid(&user_id, &conn).await {
Some(user) if user.public_key.is_some() => user, Some(user) if user.public_key.is_some() => user,
Some(_) => err_code!("User has no public_key", Status::NotFound.code), Some(_) => err_code!("User has no public_key", Status::NotFound.code),
None => err_code!("User doesn't exist", Status::NotFound.code), None => err_code!("User doesn't exist", Status::NotFound.code),
@ -474,7 +481,7 @@ async fn get_public_keys(user_id: UserId, _headers: Headers, mut conn: DbConn) -
} }
#[post("/accounts/keys", data = "<data>")] #[post("/accounts/keys", data = "<data>")]
async fn post_keys(data: Json<KeysData>, headers: Headers, mut conn: DbConn) -> JsonResult { async fn post_keys(data: Json<KeysData>, headers: Headers, conn: DbConn) -> JsonResult {
let data: KeysData = data.into_inner(); let data: KeysData = data.into_inner();
let mut user = headers.user; let mut user = headers.user;
@ -482,7 +489,7 @@ async fn post_keys(data: Json<KeysData>, headers: Headers, mut conn: DbConn) ->
user.private_key = Some(data.encrypted_private_key); user.private_key = Some(data.encrypted_private_key);
user.public_key = Some(data.public_key); user.public_key = Some(data.public_key);
user.save(&mut conn).await?; user.save(&conn).await?;
Ok(Json(json!({ Ok(Json(json!({
"privateKey": user.private_key, "privateKey": user.private_key,
@ -501,7 +508,7 @@ struct ChangePassData {
} }
#[post("/accounts/password", data = "<data>")] #[post("/accounts/password", data = "<data>")]
async fn post_password(data: Json<ChangePassData>, headers: Headers, mut conn: DbConn, nt: Notify<'_>) -> EmptyResult { async fn post_password(data: Json<ChangePassData>, headers: Headers, conn: DbConn, nt: Notify<'_>) -> EmptyResult {
let data: ChangePassData = data.into_inner(); let data: ChangePassData = data.into_inner();
let mut user = headers.user; let mut user = headers.user;
@ -512,7 +519,7 @@ async fn post_password(data: Json<ChangePassData>, headers: Headers, mut conn: D
user.password_hint = clean_password_hint(&data.master_password_hint); user.password_hint = clean_password_hint(&data.master_password_hint);
enforce_password_hint_setting(&user.password_hint)?; enforce_password_hint_setting(&user.password_hint)?;
log_user_event(EventType::UserChangedPassword as i32, &user.uuid, headers.device.atype, &headers.ip.ip, &mut conn) log_user_event(EventType::UserChangedPassword as i32, &user.uuid, headers.device.atype, &headers.ip.ip, &conn)
.await; .await;
user.set_password( user.set_password(
@ -527,12 +534,12 @@ async fn post_password(data: Json<ChangePassData>, headers: Headers, mut conn: D
]), ]),
); );
let save_result = user.save(&mut conn).await; let save_result = user.save(&conn).await;
// Prevent logging out the client where the user requested this endpoint from. // Prevent logging out the client where the user requested this endpoint from.
// If you do logout the user it will causes issues at the client side. // If you do logout the user it will causes issues at the client side.
// Adding the device uuid will prevent this. // Adding the device uuid will prevent this.
nt.send_logout(&user, Some(headers.device.uuid.clone()), &mut conn).await; nt.send_logout(&user, Some(headers.device.uuid.clone()), &conn).await;
save_result save_result
} }
@ -584,7 +591,7 @@ fn set_kdf_data(user: &mut User, data: KDFData) -> EmptyResult {
} }
#[post("/accounts/kdf", data = "<data>")] #[post("/accounts/kdf", data = "<data>")]
async fn post_kdf(data: Json<ChangeKdfData>, headers: Headers, mut conn: DbConn, nt: Notify<'_>) -> EmptyResult { async fn post_kdf(data: Json<ChangeKdfData>, headers: Headers, conn: DbConn, nt: Notify<'_>) -> EmptyResult {
let data: ChangeKdfData = data.into_inner(); let data: ChangeKdfData = data.into_inner();
let mut user = headers.user; let mut user = headers.user;
@ -595,9 +602,9 @@ async fn post_kdf(data: Json<ChangeKdfData>, headers: Headers, mut conn: DbConn,
set_kdf_data(&mut user, data.kdf)?; set_kdf_data(&mut user, data.kdf)?;
user.set_password(&data.new_master_password_hash, Some(data.key), true, None); user.set_password(&data.new_master_password_hash, Some(data.key), true, None);
let save_result = user.save(&mut conn).await; let save_result = user.save(&conn).await;
nt.send_logout(&user, Some(headers.device.uuid.clone()), &mut conn).await; nt.send_logout(&user, Some(headers.device.uuid.clone()), &conn).await;
save_result save_result
} }
@ -752,7 +759,7 @@ fn validate_keydata(
} }
#[post("/accounts/key-management/rotate-user-account-keys", data = "<data>")] #[post("/accounts/key-management/rotate-user-account-keys", data = "<data>")]
async fn post_rotatekey(data: Json<KeyData>, headers: Headers, mut conn: DbConn, nt: Notify<'_>) -> EmptyResult { async fn post_rotatekey(data: Json<KeyData>, headers: Headers, conn: DbConn, nt: Notify<'_>) -> EmptyResult {
// TODO: See if we can wrap everything within a SQL Transaction. If something fails it should revert everything. // TODO: See if we can wrap everything within a SQL Transaction. If something fails it should revert everything.
let data: KeyData = data.into_inner(); let data: KeyData = data.into_inner();
@ -770,13 +777,13 @@ async fn post_rotatekey(data: Json<KeyData>, headers: Headers, mut conn: DbConn,
// TODO: Ideally we'd do everything after this point in a single transaction. // TODO: Ideally we'd do everything after this point in a single transaction.
let mut existing_ciphers = Cipher::find_owned_by_user(user_id, &mut conn).await; let mut existing_ciphers = Cipher::find_owned_by_user(user_id, &conn).await;
let mut existing_folders = Folder::find_by_user(user_id, &mut conn).await; let mut existing_folders = Folder::find_by_user(user_id, &conn).await;
let mut existing_emergency_access = EmergencyAccess::find_all_by_grantor_uuid(user_id, &mut conn).await; let mut existing_emergency_access = EmergencyAccess::find_all_by_grantor_uuid(user_id, &conn).await;
let mut existing_memberships = Membership::find_by_user(user_id, &mut conn).await; let mut existing_memberships = Membership::find_by_user(user_id, &conn).await;
// We only rotate the reset password key if it is set. // We only rotate the reset password key if it is set.
existing_memberships.retain(|m| m.reset_password_key.is_some()); existing_memberships.retain(|m| m.reset_password_key.is_some());
let mut existing_sends = Send::find_by_user(user_id, &mut conn).await; let mut existing_sends = Send::find_by_user(user_id, &conn).await;
validate_keydata( validate_keydata(
&data, &data,
@ -798,7 +805,7 @@ async fn post_rotatekey(data: Json<KeyData>, headers: Headers, mut conn: DbConn,
}; };
saved_folder.name = folder_data.name; saved_folder.name = folder_data.name;
saved_folder.save(&mut conn).await? saved_folder.save(&conn).await?
} }
} }
@ -811,7 +818,7 @@ async fn post_rotatekey(data: Json<KeyData>, headers: Headers, mut conn: DbConn,
}; };
saved_emergency_access.key_encrypted = Some(emergency_access_data.key_encrypted); saved_emergency_access.key_encrypted = Some(emergency_access_data.key_encrypted);
saved_emergency_access.save(&mut conn).await? saved_emergency_access.save(&conn).await?
} }
// Update reset password data // Update reset password data
@ -823,7 +830,7 @@ async fn post_rotatekey(data: Json<KeyData>, headers: Headers, mut conn: DbConn,
}; };
membership.reset_password_key = Some(reset_password_data.reset_password_key); membership.reset_password_key = Some(reset_password_data.reset_password_key);
membership.save(&mut conn).await? membership.save(&conn).await?
} }
// Update send data // Update send data
@ -832,7 +839,7 @@ async fn post_rotatekey(data: Json<KeyData>, headers: Headers, mut conn: DbConn,
err!("Send doesn't exist") err!("Send doesn't exist")
}; };
update_send_from_data(send, send_data, &headers, &mut conn, &nt, UpdateType::None).await?; update_send_from_data(send, send_data, &headers, &conn, &nt, UpdateType::None).await?;
} }
// Update cipher data // Update cipher data
@ -848,7 +855,7 @@ async fn post_rotatekey(data: Json<KeyData>, headers: Headers, mut conn: DbConn,
// Prevent triggering cipher updates via WebSockets by settings UpdateType::None // Prevent triggering cipher updates via WebSockets by settings UpdateType::None
// The user sessions are invalidated because all the ciphers were re-encrypted and thus triggering an update could cause issues. // The user sessions are invalidated because all the ciphers were re-encrypted and thus triggering an update could cause issues.
// We force the users to logout after the user has been saved to try and prevent these issues. // We force the users to logout after the user has been saved to try and prevent these issues.
update_cipher_from_data(saved_cipher, cipher_data, &headers, None, &mut conn, &nt, UpdateType::None).await? update_cipher_from_data(saved_cipher, cipher_data, &headers, None, &conn, &nt, UpdateType::None).await?
} }
} }
@ -863,28 +870,28 @@ async fn post_rotatekey(data: Json<KeyData>, headers: Headers, mut conn: DbConn,
None, None,
); );
let save_result = user.save(&mut conn).await; let save_result = user.save(&conn).await;
// Prevent logging out the client where the user requested this endpoint from. // Prevent logging out the client where the user requested this endpoint from.
// If you do logout the user it will causes issues at the client side. // If you do logout the user it will causes issues at the client side.
// Adding the device uuid will prevent this. // Adding the device uuid will prevent this.
nt.send_logout(&user, Some(headers.device.uuid.clone()), &mut conn).await; nt.send_logout(&user, Some(headers.device.uuid.clone()), &conn).await;
save_result save_result
} }
#[post("/accounts/security-stamp", data = "<data>")] #[post("/accounts/security-stamp", data = "<data>")]
async fn post_sstamp(data: Json<PasswordOrOtpData>, headers: Headers, mut conn: DbConn, nt: Notify<'_>) -> EmptyResult { async fn post_sstamp(data: Json<PasswordOrOtpData>, headers: Headers, conn: DbConn, nt: Notify<'_>) -> EmptyResult {
let data: PasswordOrOtpData = data.into_inner(); let data: PasswordOrOtpData = data.into_inner();
let mut user = headers.user; let mut user = headers.user;
data.validate(&user, true, &mut conn).await?; data.validate(&user, true, &conn).await?;
Device::delete_all_by_user(&user.uuid, &mut conn).await?; Device::delete_all_by_user(&user.uuid, &conn).await?;
user.reset_security_stamp(); user.reset_security_stamp();
let save_result = user.save(&mut conn).await; let save_result = user.save(&conn).await;
nt.send_logout(&user, None, &mut conn).await; nt.send_logout(&user, None, &conn).await;
save_result save_result
} }
@ -897,7 +904,7 @@ struct EmailTokenData {
} }
#[post("/accounts/email-token", data = "<data>")] #[post("/accounts/email-token", data = "<data>")]
async fn post_email_token(data: Json<EmailTokenData>, headers: Headers, mut conn: DbConn) -> EmptyResult { async fn post_email_token(data: Json<EmailTokenData>, headers: Headers, conn: DbConn) -> EmptyResult {
if !CONFIG.email_change_allowed() { if !CONFIG.email_change_allowed() {
err!("Email change is not allowed."); err!("Email change is not allowed.");
} }
@ -909,7 +916,7 @@ async fn post_email_token(data: Json<EmailTokenData>, headers: Headers, mut conn
err!("Invalid password") err!("Invalid password")
} }
if User::find_by_mail(&data.new_email, &mut conn).await.is_some() { if User::find_by_mail(&data.new_email, &conn).await.is_some() {
if CONFIG.mail_enabled() { if CONFIG.mail_enabled() {
if let Err(e) = mail::send_change_email_existing(&data.new_email, &user.email).await { if let Err(e) = mail::send_change_email_existing(&data.new_email, &user.email).await {
error!("Error sending change-email-existing email: {e:#?}"); error!("Error sending change-email-existing email: {e:#?}");
@ -934,7 +941,7 @@ async fn post_email_token(data: Json<EmailTokenData>, headers: Headers, mut conn
user.email_new = Some(data.new_email); user.email_new = Some(data.new_email);
user.email_new_token = Some(token); user.email_new_token = Some(token);
user.save(&mut conn).await user.save(&conn).await
} }
#[derive(Deserialize)] #[derive(Deserialize)]
@ -949,7 +956,7 @@ struct ChangeEmailData {
} }
#[post("/accounts/email", data = "<data>")] #[post("/accounts/email", data = "<data>")]
async fn post_email(data: Json<ChangeEmailData>, headers: Headers, mut conn: DbConn, nt: Notify<'_>) -> EmptyResult { async fn post_email(data: Json<ChangeEmailData>, headers: Headers, conn: DbConn, nt: Notify<'_>) -> EmptyResult {
if !CONFIG.email_change_allowed() { if !CONFIG.email_change_allowed() {
err!("Email change is not allowed."); err!("Email change is not allowed.");
} }
@ -961,7 +968,7 @@ async fn post_email(data: Json<ChangeEmailData>, headers: Headers, mut conn: DbC
err!("Invalid password") err!("Invalid password")
} }
if User::find_by_mail(&data.new_email, &mut conn).await.is_some() { if User::find_by_mail(&data.new_email, &conn).await.is_some() {
err!("Email already in use"); err!("Email already in use");
} }
@ -995,9 +1002,9 @@ async fn post_email(data: Json<ChangeEmailData>, headers: Headers, mut conn: DbC
user.set_password(&data.new_master_password_hash, Some(data.key), true, None); user.set_password(&data.new_master_password_hash, Some(data.key), true, None);
let save_result = user.save(&mut conn).await; let save_result = user.save(&conn).await;
nt.send_logout(&user, None, &mut conn).await; nt.send_logout(&user, None, &conn).await;
save_result save_result
} }
@ -1025,10 +1032,10 @@ struct VerifyEmailTokenData {
} }
#[post("/accounts/verify-email-token", data = "<data>")] #[post("/accounts/verify-email-token", data = "<data>")]
async fn post_verify_email_token(data: Json<VerifyEmailTokenData>, mut conn: DbConn) -> EmptyResult { async fn post_verify_email_token(data: Json<VerifyEmailTokenData>, conn: DbConn) -> EmptyResult {
let data: VerifyEmailTokenData = data.into_inner(); let data: VerifyEmailTokenData = data.into_inner();
let Some(mut user) = User::find_by_uuid(&data.user_id, &mut conn).await else { let Some(mut user) = User::find_by_uuid(&data.user_id, &conn).await else {
err!("User doesn't exist") err!("User doesn't exist")
}; };
@ -1041,7 +1048,7 @@ async fn post_verify_email_token(data: Json<VerifyEmailTokenData>, mut conn: DbC
user.verified_at = Some(Utc::now().naive_utc()); user.verified_at = Some(Utc::now().naive_utc());
user.last_verifying_at = None; user.last_verifying_at = None;
user.login_verify_count = 0; user.login_verify_count = 0;
if let Err(e) = user.save(&mut conn).await { if let Err(e) = user.save(&conn).await {
error!("Error saving email verification: {e:#?}"); error!("Error saving email verification: {e:#?}");
} }
@ -1055,11 +1062,11 @@ struct DeleteRecoverData {
} }
#[post("/accounts/delete-recover", data = "<data>")] #[post("/accounts/delete-recover", data = "<data>")]
async fn post_delete_recover(data: Json<DeleteRecoverData>, mut conn: DbConn) -> EmptyResult { async fn post_delete_recover(data: Json<DeleteRecoverData>, conn: DbConn) -> EmptyResult {
let data: DeleteRecoverData = data.into_inner(); let data: DeleteRecoverData = data.into_inner();
if CONFIG.mail_enabled() { if CONFIG.mail_enabled() {
if let Some(user) = User::find_by_mail(&data.email, &mut conn).await { if let Some(user) = User::find_by_mail(&data.email, &conn).await {
if let Err(e) = mail::send_delete_account(&user.email, &user.uuid).await { if let Err(e) = mail::send_delete_account(&user.email, &user.uuid).await {
error!("Error sending delete account email: {e:#?}"); error!("Error sending delete account email: {e:#?}");
} }
@ -1082,21 +1089,21 @@ struct DeleteRecoverTokenData {
} }
#[post("/accounts/delete-recover-token", data = "<data>")] #[post("/accounts/delete-recover-token", data = "<data>")]
async fn post_delete_recover_token(data: Json<DeleteRecoverTokenData>, mut conn: DbConn) -> EmptyResult { async fn post_delete_recover_token(data: Json<DeleteRecoverTokenData>, conn: DbConn) -> EmptyResult {
let data: DeleteRecoverTokenData = data.into_inner(); let data: DeleteRecoverTokenData = data.into_inner();
let Ok(claims) = decode_delete(&data.token) else { let Ok(claims) = decode_delete(&data.token) else {
err!("Invalid claim") err!("Invalid claim")
}; };
let Some(user) = User::find_by_uuid(&data.user_id, &mut conn).await else { let Some(user) = User::find_by_uuid(&data.user_id, &conn).await else {
err!("User doesn't exist") err!("User doesn't exist")
}; };
if claims.sub != *user.uuid { if claims.sub != *user.uuid {
err!("Invalid claim"); err!("Invalid claim");
} }
user.delete(&mut conn).await user.delete(&conn).await
} }
#[post("/accounts/delete", data = "<data>")] #[post("/accounts/delete", data = "<data>")]
@ -1105,13 +1112,13 @@ async fn post_delete_account(data: Json<PasswordOrOtpData>, headers: Headers, co
} }
#[delete("/accounts", data = "<data>")] #[delete("/accounts", data = "<data>")]
async fn delete_account(data: Json<PasswordOrOtpData>, headers: Headers, mut conn: DbConn) -> EmptyResult { async fn delete_account(data: Json<PasswordOrOtpData>, headers: Headers, conn: DbConn) -> EmptyResult {
let data: PasswordOrOtpData = data.into_inner(); let data: PasswordOrOtpData = data.into_inner();
let user = headers.user; let user = headers.user;
data.validate(&user, true, &mut conn).await?; data.validate(&user, true, &conn).await?;
user.delete(&mut conn).await user.delete(&conn).await
} }
#[get("/accounts/revision-date")] #[get("/accounts/revision-date")]
@ -1127,7 +1134,7 @@ struct PasswordHintData {
} }
#[post("/accounts/password-hint", data = "<data>")] #[post("/accounts/password-hint", data = "<data>")]
async fn password_hint(data: Json<PasswordHintData>, mut conn: DbConn) -> EmptyResult { async fn password_hint(data: Json<PasswordHintData>, conn: DbConn) -> EmptyResult {
if !CONFIG.password_hints_allowed() || (!CONFIG.mail_enabled() && !CONFIG.show_password_hint()) { if !CONFIG.password_hints_allowed() || (!CONFIG.mail_enabled() && !CONFIG.show_password_hint()) {
err!("This server is not configured to provide password hints."); err!("This server is not configured to provide password hints.");
} }
@ -1137,7 +1144,7 @@ async fn password_hint(data: Json<PasswordHintData>, mut conn: DbConn) -> EmptyR
let data: PasswordHintData = data.into_inner(); let data: PasswordHintData = data.into_inner();
let email = &data.email; let email = &data.email;
match User::find_by_mail(email, &mut conn).await { match User::find_by_mail(email, &conn).await {
None => { None => {
// To prevent user enumeration, act as if the user exists. // To prevent user enumeration, act as if the user exists.
if CONFIG.mail_enabled() { if CONFIG.mail_enabled() {
@ -1179,10 +1186,10 @@ async fn prelogin(data: Json<PreloginData>, conn: DbConn) -> Json<Value> {
_prelogin(data, conn).await _prelogin(data, conn).await
} }
pub async fn _prelogin(data: Json<PreloginData>, mut conn: DbConn) -> Json<Value> { pub async fn _prelogin(data: Json<PreloginData>, conn: DbConn) -> Json<Value> {
let data: PreloginData = data.into_inner(); let data: PreloginData = data.into_inner();
let (kdf_type, kdf_iter, kdf_mem, kdf_para) = match User::find_by_mail(&data.email, &mut conn).await { let (kdf_type, kdf_iter, kdf_mem, kdf_para) = match User::find_by_mail(&data.email, &conn).await {
Some(user) => (user.client_kdf_type, user.client_kdf_iter, user.client_kdf_memory, user.client_kdf_parallelism), Some(user) => (user.client_kdf_type, user.client_kdf_iter, user.client_kdf_memory, user.client_kdf_parallelism),
None => (User::CLIENT_KDF_TYPE_DEFAULT, User::CLIENT_KDF_ITER_DEFAULT, None, None), None => (User::CLIENT_KDF_TYPE_DEFAULT, User::CLIENT_KDF_ITER_DEFAULT, None, None),
}; };
@ -1203,7 +1210,7 @@ struct SecretVerificationRequest {
} }
// Change the KDF Iterations if necessary // Change the KDF Iterations if necessary
pub async fn kdf_upgrade(user: &mut User, pwd_hash: &str, conn: &mut DbConn) -> ApiResult<()> { pub async fn kdf_upgrade(user: &mut User, pwd_hash: &str, conn: &DbConn) -> ApiResult<()> {
if user.password_iterations < CONFIG.password_iterations() { if user.password_iterations < CONFIG.password_iterations() {
user.password_iterations = CONFIG.password_iterations(); user.password_iterations = CONFIG.password_iterations();
user.set_password(pwd_hash, None, false, None); user.set_password(pwd_hash, None, false, None);
@ -1216,7 +1223,7 @@ pub async fn kdf_upgrade(user: &mut User, pwd_hash: &str, conn: &mut DbConn) ->
} }
#[post("/accounts/verify-password", data = "<data>")] #[post("/accounts/verify-password", data = "<data>")]
async fn verify_password(data: Json<SecretVerificationRequest>, headers: Headers, mut conn: DbConn) -> JsonResult { async fn verify_password(data: Json<SecretVerificationRequest>, headers: Headers, conn: DbConn) -> JsonResult {
let data: SecretVerificationRequest = data.into_inner(); let data: SecretVerificationRequest = data.into_inner();
let mut user = headers.user; let mut user = headers.user;
@ -1224,22 +1231,22 @@ async fn verify_password(data: Json<SecretVerificationRequest>, headers: Headers
err!("Invalid password") err!("Invalid password")
} }
kdf_upgrade(&mut user, &data.master_password_hash, &mut conn).await?; kdf_upgrade(&mut user, &data.master_password_hash, &conn).await?;
Ok(Json(master_password_policy(&user, &conn).await)) Ok(Json(master_password_policy(&user, &conn).await))
} }
async fn _api_key(data: Json<PasswordOrOtpData>, rotate: bool, headers: Headers, mut conn: DbConn) -> JsonResult { async fn _api_key(data: Json<PasswordOrOtpData>, rotate: bool, headers: Headers, conn: DbConn) -> JsonResult {
use crate::util::format_date; use crate::util::format_date;
let data: PasswordOrOtpData = data.into_inner(); let data: PasswordOrOtpData = data.into_inner();
let mut user = headers.user; let mut user = headers.user;
data.validate(&user, true, &mut conn).await?; data.validate(&user, true, &conn).await?;
if rotate || user.api_key.is_none() { if rotate || user.api_key.is_none() {
user.api_key = Some(crypto::generate_api_key()); user.api_key = Some(crypto::generate_api_key());
user.save(&mut conn).await.expect("Error saving API key"); user.save(&conn).await.expect("Error saving API key");
} }
Ok(Json(json!({ Ok(Json(json!({
@ -1260,10 +1267,10 @@ async fn rotate_api_key(data: Json<PasswordOrOtpData>, headers: Headers, conn: D
} }
#[get("/devices/knowndevice")] #[get("/devices/knowndevice")]
async fn get_known_device(device: KnownDevice, mut conn: DbConn) -> JsonResult { async fn get_known_device(device: KnownDevice, conn: DbConn) -> JsonResult {
let mut result = false; let mut result = false;
if let Some(user) = User::find_by_mail(&device.email, &mut conn).await { if let Some(user) = User::find_by_mail(&device.email, &conn).await {
result = Device::find_by_uuid_and_user(&device.uuid, &user.uuid, &mut conn).await.is_some(); result = Device::find_by_uuid_and_user(&device.uuid, &user.uuid, &conn).await.is_some();
} }
Ok(Json(json!(result))) Ok(Json(json!(result)))
} }
@ -1306,8 +1313,8 @@ impl<'r> FromRequest<'r> for KnownDevice {
} }
#[get("/devices")] #[get("/devices")]
async fn get_all_devices(headers: Headers, mut conn: DbConn) -> JsonResult { async fn get_all_devices(headers: Headers, conn: DbConn) -> JsonResult {
let devices = Device::find_with_auth_request_by_user(&headers.user.uuid, &mut conn).await; let devices = Device::find_with_auth_request_by_user(&headers.user.uuid, &conn).await;
let devices = devices.iter().map(|device| device.to_json()).collect::<Vec<Value>>(); let devices = devices.iter().map(|device| device.to_json()).collect::<Vec<Value>>();
Ok(Json(json!({ Ok(Json(json!({
@ -1318,8 +1325,8 @@ async fn get_all_devices(headers: Headers, mut conn: DbConn) -> JsonResult {
} }
#[get("/devices/identifier/<device_id>")] #[get("/devices/identifier/<device_id>")]
async fn get_device(device_id: DeviceId, headers: Headers, mut conn: DbConn) -> JsonResult { async fn get_device(device_id: DeviceId, headers: Headers, conn: DbConn) -> JsonResult {
let Some(device) = Device::find_by_uuid_and_user(&device_id, &headers.user.uuid, &mut conn).await else { let Some(device) = Device::find_by_uuid_and_user(&device_id, &headers.user.uuid, &conn).await else {
err!("No device found"); err!("No device found");
}; };
Ok(Json(device.to_json())) Ok(Json(device.to_json()))
@ -1337,17 +1344,11 @@ async fn post_device_token(device_id: DeviceId, data: Json<PushToken>, headers:
} }
#[put("/devices/identifier/<device_id>/token", data = "<data>")] #[put("/devices/identifier/<device_id>/token", data = "<data>")]
async fn put_device_token( async fn put_device_token(device_id: DeviceId, data: Json<PushToken>, headers: Headers, conn: DbConn) -> EmptyResult {
device_id: DeviceId,
data: Json<PushToken>,
headers: Headers,
mut conn: DbConn,
) -> EmptyResult {
let data = data.into_inner(); let data = data.into_inner();
let token = data.push_token; let token = data.push_token;
let Some(mut device) = Device::find_by_uuid_and_user(&headers.device.uuid, &headers.user.uuid, &mut conn).await let Some(mut device) = Device::find_by_uuid_and_user(&headers.device.uuid, &headers.user.uuid, &conn).await else {
else {
err!(format!("Error: device {device_id} should be present before a token can be assigned")) err!(format!("Error: device {device_id} should be present before a token can be assigned"))
}; };
@ -1360,17 +1361,17 @@ async fn put_device_token(
} }
device.push_token = Some(token); device.push_token = Some(token);
if let Err(e) = device.save(&mut conn).await { if let Err(e) = device.save(&conn).await {
err!(format!("An error occurred while trying to save the device push token: {e}")); err!(format!("An error occurred while trying to save the device push token: {e}"));
} }
register_push_device(&mut device, &mut conn).await?; register_push_device(&mut device, &conn).await?;
Ok(()) Ok(())
} }
#[put("/devices/identifier/<device_id>/clear-token")] #[put("/devices/identifier/<device_id>/clear-token")]
async fn put_clear_device_token(device_id: DeviceId, mut conn: DbConn) -> EmptyResult { async fn put_clear_device_token(device_id: DeviceId, conn: DbConn) -> EmptyResult {
// This only clears push token // This only clears push token
// https://github.com/bitwarden/server/blob/9ebe16587175b1c0e9208f84397bb75d0d595510/src/Api/Controllers/DevicesController.cs#L215 // https://github.com/bitwarden/server/blob/9ebe16587175b1c0e9208f84397bb75d0d595510/src/Api/Controllers/DevicesController.cs#L215
// https://github.com/bitwarden/server/blob/9ebe16587175b1c0e9208f84397bb75d0d595510/src/Core/Services/Implementations/DeviceService.cs#L37 // https://github.com/bitwarden/server/blob/9ebe16587175b1c0e9208f84397bb75d0d595510/src/Core/Services/Implementations/DeviceService.cs#L37
@ -1382,8 +1383,8 @@ async fn put_clear_device_token(device_id: DeviceId, mut conn: DbConn) -> EmptyR
return Ok(()); return Ok(());
} }
if let Some(device) = Device::find_by_uuid(&device_id, &mut conn).await { if let Some(device) = Device::find_by_uuid(&device_id, &conn).await {
Device::clear_push_token_by_uuid(&device_id, &mut conn).await?; Device::clear_push_token_by_uuid(&device_id, &conn).await?;
unregister_push_device(&device.push_uuid).await?; unregister_push_device(&device.push_uuid).await?;
} }
@ -1412,17 +1413,17 @@ struct AuthRequestRequest {
async fn post_auth_request( async fn post_auth_request(
data: Json<AuthRequestRequest>, data: Json<AuthRequestRequest>,
client_headers: ClientHeaders, client_headers: ClientHeaders,
mut conn: DbConn, conn: DbConn,
nt: Notify<'_>, nt: Notify<'_>,
) -> JsonResult { ) -> JsonResult {
let data = data.into_inner(); let data = data.into_inner();
let Some(user) = User::find_by_mail(&data.email, &mut conn).await else { let Some(user) = User::find_by_mail(&data.email, &conn).await else {
err!("AuthRequest doesn't exist", "User not found") err!("AuthRequest doesn't exist", "User not found")
}; };
// Validate device uuid and type // Validate device uuid and type
let device = match Device::find_by_uuid_and_user(&data.device_identifier, &user.uuid, &mut conn).await { let device = match Device::find_by_uuid_and_user(&data.device_identifier, &user.uuid, &conn).await {
Some(device) if device.atype == client_headers.device_type => device, Some(device) if device.atype == client_headers.device_type => device,
_ => err!("AuthRequest doesn't exist", "Device verification failed"), _ => err!("AuthRequest doesn't exist", "Device verification failed"),
}; };
@ -1435,16 +1436,16 @@ async fn post_auth_request(
data.access_code, data.access_code,
data.public_key, data.public_key,
); );
auth_request.save(&mut conn).await?; auth_request.save(&conn).await?;
nt.send_auth_request(&user.uuid, &auth_request.uuid, &device, &mut conn).await; nt.send_auth_request(&user.uuid, &auth_request.uuid, &device, &conn).await;
log_user_event( log_user_event(
EventType::UserRequestedDeviceApproval as i32, EventType::UserRequestedDeviceApproval as i32,
&user.uuid, &user.uuid,
client_headers.device_type, client_headers.device_type,
&client_headers.ip.ip, &client_headers.ip.ip,
&mut conn, &conn,
) )
.await; .await;
@ -1464,8 +1465,8 @@ async fn post_auth_request(
} }
#[get("/auth-requests/<auth_request_id>")] #[get("/auth-requests/<auth_request_id>")]
async fn get_auth_request(auth_request_id: AuthRequestId, headers: Headers, mut conn: DbConn) -> JsonResult { async fn get_auth_request(auth_request_id: AuthRequestId, headers: Headers, conn: DbConn) -> JsonResult {
let Some(auth_request) = AuthRequest::find_by_uuid_and_user(&auth_request_id, &headers.user.uuid, &mut conn).await let Some(auth_request) = AuthRequest::find_by_uuid_and_user(&auth_request_id, &headers.user.uuid, &conn).await
else { else {
err!("AuthRequest doesn't exist", "Record not found or user uuid does not match") err!("AuthRequest doesn't exist", "Record not found or user uuid does not match")
}; };
@ -1501,13 +1502,12 @@ async fn put_auth_request(
auth_request_id: AuthRequestId, auth_request_id: AuthRequestId,
data: Json<AuthResponseRequest>, data: Json<AuthResponseRequest>,
headers: Headers, headers: Headers,
mut conn: DbConn, conn: DbConn,
ant: AnonymousNotify<'_>, ant: AnonymousNotify<'_>,
nt: Notify<'_>, nt: Notify<'_>,
) -> JsonResult { ) -> JsonResult {
let data = data.into_inner(); let data = data.into_inner();
let Some(mut auth_request) = let Some(mut auth_request) = AuthRequest::find_by_uuid_and_user(&auth_request_id, &headers.user.uuid, &conn).await
AuthRequest::find_by_uuid_and_user(&auth_request_id, &headers.user.uuid, &mut conn).await
else { else {
err!("AuthRequest doesn't exist", "Record not found or user uuid does not match") err!("AuthRequest doesn't exist", "Record not found or user uuid does not match")
}; };
@ -1529,28 +1529,28 @@ async fn put_auth_request(
auth_request.master_password_hash = data.master_password_hash; auth_request.master_password_hash = data.master_password_hash;
auth_request.response_device_id = Some(data.device_identifier.clone()); auth_request.response_device_id = Some(data.device_identifier.clone());
auth_request.response_date = Some(response_date); auth_request.response_date = Some(response_date);
auth_request.save(&mut conn).await?; auth_request.save(&conn).await?;
ant.send_auth_response(&auth_request.user_uuid, &auth_request.uuid).await; ant.send_auth_response(&auth_request.user_uuid, &auth_request.uuid).await;
nt.send_auth_response(&auth_request.user_uuid, &auth_request.uuid, &headers.device, &mut conn).await; nt.send_auth_response(&auth_request.user_uuid, &auth_request.uuid, &headers.device, &conn).await;
log_user_event( log_user_event(
EventType::OrganizationUserApprovedAuthRequest as i32, EventType::OrganizationUserApprovedAuthRequest as i32,
&headers.user.uuid, &headers.user.uuid,
headers.device.atype, headers.device.atype,
&headers.ip.ip, &headers.ip.ip,
&mut conn, &conn,
) )
.await; .await;
} else { } else {
// If denied, there's no reason to keep the request // If denied, there's no reason to keep the request
auth_request.delete(&mut conn).await?; auth_request.delete(&conn).await?;
log_user_event( log_user_event(
EventType::OrganizationUserRejectedAuthRequest as i32, EventType::OrganizationUserRejectedAuthRequest as i32,
&headers.user.uuid, &headers.user.uuid,
headers.device.atype, headers.device.atype,
&headers.ip.ip, &headers.ip.ip,
&mut conn, &conn,
) )
.await; .await;
} }
@ -1575,9 +1575,9 @@ async fn get_auth_request_response(
auth_request_id: AuthRequestId, auth_request_id: AuthRequestId,
code: &str, code: &str,
client_headers: ClientHeaders, client_headers: ClientHeaders,
mut conn: DbConn, conn: DbConn,
) -> JsonResult { ) -> JsonResult {
let Some(auth_request) = AuthRequest::find_by_uuid(&auth_request_id, &mut conn).await else { let Some(auth_request) = AuthRequest::find_by_uuid(&auth_request_id, &conn).await else {
err!("AuthRequest doesn't exist", "User not found") err!("AuthRequest doesn't exist", "User not found")
}; };
@ -1606,8 +1606,8 @@ async fn get_auth_request_response(
} }
#[get("/auth-requests")] #[get("/auth-requests")]
async fn get_auth_requests(headers: Headers, mut conn: DbConn) -> JsonResult { async fn get_auth_requests(headers: Headers, conn: DbConn) -> JsonResult {
let auth_requests = AuthRequest::find_by_user(&headers.user.uuid, &mut conn).await; let auth_requests = AuthRequest::find_by_user(&headers.user.uuid, &conn).await;
Ok(Json(json!({ Ok(Json(json!({
"data": auth_requests "data": auth_requests
@ -1637,8 +1637,8 @@ async fn get_auth_requests(headers: Headers, mut conn: DbConn) -> JsonResult {
pub async fn purge_auth_requests(pool: DbPool) { pub async fn purge_auth_requests(pool: DbPool) {
debug!("Purging auth requests"); debug!("Purging auth requests");
if let Ok(mut conn) = pool.get().await { if let Ok(conn) = pool.get().await {
AuthRequest::purge_expired_auth_requests(&mut conn).await; AuthRequest::purge_expired_auth_requests(&conn).await;
} else { } else {
error!("Failed to get DB connection while purging trashed ciphers") error!("Failed to get DB connection while purging trashed ciphers")
} }

345
src/api/core/ciphers.rs

@ -17,7 +17,14 @@ use crate::{
auth::Headers, auth::Headers,
config::PathType, config::PathType,
crypto, crypto,
db::{models::*, DbConn, DbPool}, db::{
models::{
Attachment, AttachmentId, Cipher, CipherId, Collection, CollectionCipher, CollectionGroup, CollectionId,
CollectionUser, EventType, Favorite, Folder, FolderCipher, FolderId, Group, Membership, MembershipType,
OrgPolicy, OrgPolicyType, OrganizationId, RepromptType, Send, UserId,
},
DbConn, DbPool,
},
CONFIG, CONFIG,
}; };
@ -93,8 +100,8 @@ pub fn routes() -> Vec<Route> {
pub async fn purge_trashed_ciphers(pool: DbPool) { pub async fn purge_trashed_ciphers(pool: DbPool) {
debug!("Purging trashed ciphers"); debug!("Purging trashed ciphers");
if let Ok(mut conn) = pool.get().await { if let Ok(conn) = pool.get().await {
Cipher::purge_trash(&mut conn).await; Cipher::purge_trash(&conn).await;
} else { } else {
error!("Failed to get DB connection while purging trashed ciphers") error!("Failed to get DB connection while purging trashed ciphers")
} }
@ -107,11 +114,11 @@ struct SyncData {
} }
#[get("/sync?<data..>")] #[get("/sync?<data..>")]
async fn sync(data: SyncData, headers: Headers, client_version: Option<ClientVersion>, mut conn: DbConn) -> JsonResult { async fn sync(data: SyncData, headers: Headers, client_version: Option<ClientVersion>, conn: DbConn) -> JsonResult {
let user_json = headers.user.to_json(&mut conn).await; let user_json = headers.user.to_json(&conn).await;
// Get all ciphers which are visible by the user // Get all ciphers which are visible by the user
let mut ciphers = Cipher::find_by_user_visible(&headers.user.uuid, &mut conn).await; let mut ciphers = Cipher::find_by_user_visible(&headers.user.uuid, &conn).await;
// Filter out SSH keys if the client version is less than 2024.12.0 // Filter out SSH keys if the client version is less than 2024.12.0
let show_ssh_keys = if let Some(client_version) = client_version { let show_ssh_keys = if let Some(client_version) = client_version {
@ -124,31 +131,30 @@ async fn sync(data: SyncData, headers: Headers, client_version: Option<ClientVer
ciphers.retain(|c| c.atype != 5); ciphers.retain(|c| c.atype != 5);
} }
let cipher_sync_data = CipherSyncData::new(&headers.user.uuid, CipherSyncType::User, &mut conn).await; let cipher_sync_data = CipherSyncData::new(&headers.user.uuid, CipherSyncType::User, &conn).await;
// Lets generate the ciphers_json using all the gathered info // Lets generate the ciphers_json using all the gathered info
let mut ciphers_json = Vec::with_capacity(ciphers.len()); let mut ciphers_json = Vec::with_capacity(ciphers.len());
for c in ciphers { for c in ciphers {
ciphers_json.push( ciphers_json.push(
c.to_json(&headers.host, &headers.user.uuid, Some(&cipher_sync_data), CipherSyncType::User, &mut conn) c.to_json(&headers.host, &headers.user.uuid, Some(&cipher_sync_data), CipherSyncType::User, &conn).await?,
.await?,
); );
} }
let collections = Collection::find_by_user_uuid(headers.user.uuid.clone(), &mut conn).await; let collections = Collection::find_by_user_uuid(headers.user.uuid.clone(), &conn).await;
let mut collections_json = Vec::with_capacity(collections.len()); let mut collections_json = Vec::with_capacity(collections.len());
for c in collections { for c in collections {
collections_json.push(c.to_json_details(&headers.user.uuid, Some(&cipher_sync_data), &mut conn).await); collections_json.push(c.to_json_details(&headers.user.uuid, Some(&cipher_sync_data), &conn).await);
} }
let folders_json: Vec<Value> = let folders_json: Vec<Value> =
Folder::find_by_user(&headers.user.uuid, &mut conn).await.iter().map(Folder::to_json).collect(); Folder::find_by_user(&headers.user.uuid, &conn).await.iter().map(Folder::to_json).collect();
let sends_json: Vec<Value> = let sends_json: Vec<Value> =
Send::find_by_user(&headers.user.uuid, &mut conn).await.iter().map(Send::to_json).collect(); Send::find_by_user(&headers.user.uuid, &conn).await.iter().map(Send::to_json).collect();
let policies_json: Vec<Value> = let policies_json: Vec<Value> =
OrgPolicy::find_confirmed_by_user(&headers.user.uuid, &mut conn).await.iter().map(OrgPolicy::to_json).collect(); OrgPolicy::find_confirmed_by_user(&headers.user.uuid, &conn).await.iter().map(OrgPolicy::to_json).collect();
let domains_json = if data.exclude_domains { let domains_json = if data.exclude_domains {
Value::Null Value::Null
@ -169,15 +175,14 @@ async fn sync(data: SyncData, headers: Headers, client_version: Option<ClientVer
} }
#[get("/ciphers")] #[get("/ciphers")]
async fn get_ciphers(headers: Headers, mut conn: DbConn) -> JsonResult { async fn get_ciphers(headers: Headers, conn: DbConn) -> JsonResult {
let ciphers = Cipher::find_by_user_visible(&headers.user.uuid, &mut conn).await; let ciphers = Cipher::find_by_user_visible(&headers.user.uuid, &conn).await;
let cipher_sync_data = CipherSyncData::new(&headers.user.uuid, CipherSyncType::User, &mut conn).await; let cipher_sync_data = CipherSyncData::new(&headers.user.uuid, CipherSyncType::User, &conn).await;
let mut ciphers_json = Vec::with_capacity(ciphers.len()); let mut ciphers_json = Vec::with_capacity(ciphers.len());
for c in ciphers { for c in ciphers {
ciphers_json.push( ciphers_json.push(
c.to_json(&headers.host, &headers.user.uuid, Some(&cipher_sync_data), CipherSyncType::User, &mut conn) c.to_json(&headers.host, &headers.user.uuid, Some(&cipher_sync_data), CipherSyncType::User, &conn).await?,
.await?,
); );
} }
@ -189,16 +194,16 @@ async fn get_ciphers(headers: Headers, mut conn: DbConn) -> JsonResult {
} }
#[get("/ciphers/<cipher_id>")] #[get("/ciphers/<cipher_id>")]
async fn get_cipher(cipher_id: CipherId, headers: Headers, mut conn: DbConn) -> JsonResult { async fn get_cipher(cipher_id: CipherId, headers: Headers, conn: DbConn) -> JsonResult {
let Some(cipher) = Cipher::find_by_uuid(&cipher_id, &mut conn).await else { let Some(cipher) = Cipher::find_by_uuid(&cipher_id, &conn).await else {
err!("Cipher doesn't exist") err!("Cipher doesn't exist")
}; };
if !cipher.is_accessible_to_user(&headers.user.uuid, &mut conn).await { if !cipher.is_accessible_to_user(&headers.user.uuid, &conn).await {
err!("Cipher is not owned by user") err!("Cipher is not owned by user")
} }
Ok(Json(cipher.to_json(&headers.host, &headers.user.uuid, None, CipherSyncType::User, &mut conn).await?)) Ok(Json(cipher.to_json(&headers.host, &headers.user.uuid, None, CipherSyncType::User, &conn).await?))
} }
#[get("/ciphers/<cipher_id>/admin")] #[get("/ciphers/<cipher_id>/admin")]
@ -291,7 +296,7 @@ async fn post_ciphers_admin(data: Json<ShareCipherData>, headers: Headers, conn:
async fn post_ciphers_create( async fn post_ciphers_create(
data: Json<ShareCipherData>, data: Json<ShareCipherData>,
headers: Headers, headers: Headers,
mut conn: DbConn, conn: DbConn,
nt: Notify<'_>, nt: Notify<'_>,
) -> JsonResult { ) -> JsonResult {
let mut data: ShareCipherData = data.into_inner(); let mut data: ShareCipherData = data.into_inner();
@ -305,11 +310,11 @@ async fn post_ciphers_create(
// This check is usually only needed in update_cipher_from_data(), but we // This check is usually only needed in update_cipher_from_data(), but we
// need it here as well to avoid creating an empty cipher in the call to // need it here as well to avoid creating an empty cipher in the call to
// cipher.save() below. // cipher.save() below.
enforce_personal_ownership_policy(Some(&data.cipher), &headers, &mut conn).await?; enforce_personal_ownership_policy(Some(&data.cipher), &headers, &conn).await?;
let mut cipher = Cipher::new(data.cipher.r#type, data.cipher.name.clone()); let mut cipher = Cipher::new(data.cipher.r#type, data.cipher.name.clone());
cipher.user_uuid = Some(headers.user.uuid.clone()); cipher.user_uuid = Some(headers.user.uuid.clone());
cipher.save(&mut conn).await?; cipher.save(&conn).await?;
// When cloning a cipher, the Bitwarden clients seem to set this field // When cloning a cipher, the Bitwarden clients seem to set this field
// based on the cipher being cloned (when creating a new cipher, it's set // based on the cipher being cloned (when creating a new cipher, it's set
@ -319,12 +324,12 @@ async fn post_ciphers_create(
// or otherwise), we can just ignore this field entirely. // or otherwise), we can just ignore this field entirely.
data.cipher.last_known_revision_date = None; data.cipher.last_known_revision_date = None;
share_cipher_by_uuid(&cipher.uuid, data, &headers, &mut conn, &nt, None).await share_cipher_by_uuid(&cipher.uuid, data, &headers, &conn, &nt, None).await
} }
/// Called when creating a new user-owned cipher. /// Called when creating a new user-owned cipher.
#[post("/ciphers", data = "<data>")] #[post("/ciphers", data = "<data>")]
async fn post_ciphers(data: Json<CipherData>, headers: Headers, mut conn: DbConn, nt: Notify<'_>) -> JsonResult { async fn post_ciphers(data: Json<CipherData>, headers: Headers, conn: DbConn, nt: Notify<'_>) -> JsonResult {
let mut data: CipherData = data.into_inner(); let mut data: CipherData = data.into_inner();
// The web/browser clients set this field to null as expected, but the // The web/browser clients set this field to null as expected, but the
@ -334,9 +339,9 @@ async fn post_ciphers(data: Json<CipherData>, headers: Headers, mut conn: DbConn
data.last_known_revision_date = None; data.last_known_revision_date = None;
let mut cipher = Cipher::new(data.r#type, data.name.clone()); let mut cipher = Cipher::new(data.r#type, data.name.clone());
update_cipher_from_data(&mut cipher, data, &headers, None, &mut conn, &nt, UpdateType::SyncCipherCreate).await?; update_cipher_from_data(&mut cipher, data, &headers, None, &conn, &nt, UpdateType::SyncCipherCreate).await?;
Ok(Json(cipher.to_json(&headers.host, &headers.user.uuid, None, CipherSyncType::User, &mut conn).await?)) Ok(Json(cipher.to_json(&headers.host, &headers.user.uuid, None, CipherSyncType::User, &conn).await?))
} }
/// Enforces the personal ownership policy on user-owned ciphers, if applicable. /// Enforces the personal ownership policy on user-owned ciphers, if applicable.
@ -346,11 +351,7 @@ async fn post_ciphers(data: Json<CipherData>, headers: Headers, mut conn: DbConn
/// allowed to delete or share such ciphers to an org, however. /// allowed to delete or share such ciphers to an org, however.
/// ///
/// Ref: https://bitwarden.com/help/article/policies/#personal-ownership /// Ref: https://bitwarden.com/help/article/policies/#personal-ownership
async fn enforce_personal_ownership_policy( async fn enforce_personal_ownership_policy(data: Option<&CipherData>, headers: &Headers, conn: &DbConn) -> EmptyResult {
data: Option<&CipherData>,
headers: &Headers,
conn: &mut DbConn,
) -> EmptyResult {
if data.is_none() || data.unwrap().organization_id.is_none() { if data.is_none() || data.unwrap().organization_id.is_none() {
let user_id = &headers.user.uuid; let user_id = &headers.user.uuid;
let policy_type = OrgPolicyType::PersonalOwnership; let policy_type = OrgPolicyType::PersonalOwnership;
@ -366,7 +367,7 @@ pub async fn update_cipher_from_data(
data: CipherData, data: CipherData,
headers: &Headers, headers: &Headers,
shared_to_collections: Option<Vec<CollectionId>>, shared_to_collections: Option<Vec<CollectionId>>,
conn: &mut DbConn, conn: &DbConn,
nt: &Notify<'_>, nt: &Notify<'_>,
ut: UpdateType, ut: UpdateType,
) -> EmptyResult { ) -> EmptyResult {
@ -559,13 +560,8 @@ struct RelationsData {
} }
#[post("/ciphers/import", data = "<data>")] #[post("/ciphers/import", data = "<data>")]
async fn post_ciphers_import( async fn post_ciphers_import(data: Json<ImportData>, headers: Headers, conn: DbConn, nt: Notify<'_>) -> EmptyResult {
data: Json<ImportData>, enforce_personal_ownership_policy(None, &headers, &conn).await?;
headers: Headers,
mut conn: DbConn,
nt: Notify<'_>,
) -> EmptyResult {
enforce_personal_ownership_policy(None, &headers, &mut conn).await?;
let data: ImportData = data.into_inner(); let data: ImportData = data.into_inner();
@ -577,14 +573,14 @@ async fn post_ciphers_import(
// Read and create the folders // Read and create the folders
let existing_folders: HashSet<Option<FolderId>> = let existing_folders: HashSet<Option<FolderId>> =
Folder::find_by_user(&headers.user.uuid, &mut conn).await.into_iter().map(|f| Some(f.uuid)).collect(); Folder::find_by_user(&headers.user.uuid, &conn).await.into_iter().map(|f| Some(f.uuid)).collect();
let mut folders: Vec<FolderId> = Vec::with_capacity(data.folders.len()); let mut folders: Vec<FolderId> = Vec::with_capacity(data.folders.len());
for folder in data.folders.into_iter() { for folder in data.folders.into_iter() {
let folder_id = if existing_folders.contains(&folder.id) { let folder_id = if existing_folders.contains(&folder.id) {
folder.id.unwrap() folder.id.unwrap()
} else { } else {
let mut new_folder = Folder::new(headers.user.uuid.clone(), folder.name); let mut new_folder = Folder::new(headers.user.uuid.clone(), folder.name);
new_folder.save(&mut conn).await?; new_folder.save(&conn).await?;
new_folder.uuid new_folder.uuid
}; };
@ -604,12 +600,12 @@ async fn post_ciphers_import(
cipher_data.folder_id = folder_id; cipher_data.folder_id = folder_id;
let mut cipher = Cipher::new(cipher_data.r#type, cipher_data.name.clone()); let mut cipher = Cipher::new(cipher_data.r#type, cipher_data.name.clone());
update_cipher_from_data(&mut cipher, cipher_data, &headers, None, &mut conn, &nt, UpdateType::None).await?; update_cipher_from_data(&mut cipher, cipher_data, &headers, None, &conn, &nt, UpdateType::None).await?;
} }
let mut user = headers.user; let mut user = headers.user;
user.update_revision(&mut conn).await?; user.update_revision(&conn).await?;
nt.send_user_update(UpdateType::SyncVault, &user, &headers.device.push_uuid, &mut conn).await; nt.send_user_update(UpdateType::SyncVault, &user, &headers.device.push_uuid, &conn).await;
Ok(()) Ok(())
} }
@ -653,12 +649,12 @@ async fn put_cipher(
cipher_id: CipherId, cipher_id: CipherId,
data: Json<CipherData>, data: Json<CipherData>,
headers: Headers, headers: Headers,
mut conn: DbConn, conn: DbConn,
nt: Notify<'_>, nt: Notify<'_>,
) -> JsonResult { ) -> JsonResult {
let data: CipherData = data.into_inner(); let data: CipherData = data.into_inner();
let Some(mut cipher) = Cipher::find_by_uuid(&cipher_id, &mut conn).await else { let Some(mut cipher) = Cipher::find_by_uuid(&cipher_id, &conn).await else {
err!("Cipher doesn't exist") err!("Cipher doesn't exist")
}; };
@ -667,13 +663,13 @@ async fn put_cipher(
// cipher itself, so the user shouldn't need write access to change these. // cipher itself, so the user shouldn't need write access to change these.
// Interestingly, upstream Bitwarden doesn't properly handle this either. // Interestingly, upstream Bitwarden doesn't properly handle this either.
if !cipher.is_write_accessible_to_user(&headers.user.uuid, &mut conn).await { if !cipher.is_write_accessible_to_user(&headers.user.uuid, &conn).await {
err!("Cipher is not write accessible") err!("Cipher is not write accessible")
} }
update_cipher_from_data(&mut cipher, data, &headers, None, &mut conn, &nt, UpdateType::SyncCipherUpdate).await?; update_cipher_from_data(&mut cipher, data, &headers, None, &conn, &nt, UpdateType::SyncCipherUpdate).await?;
Ok(Json(cipher.to_json(&headers.host, &headers.user.uuid, None, CipherSyncType::User, &mut conn).await?)) Ok(Json(cipher.to_json(&headers.host, &headers.user.uuid, None, CipherSyncType::User, &conn).await?))
} }
#[post("/ciphers/<cipher_id>/partial", data = "<data>")] #[post("/ciphers/<cipher_id>/partial", data = "<data>")]
@ -692,26 +688,26 @@ async fn put_cipher_partial(
cipher_id: CipherId, cipher_id: CipherId,
data: Json<PartialCipherData>, data: Json<PartialCipherData>,
headers: Headers, headers: Headers,
mut conn: DbConn, conn: DbConn,
) -> JsonResult { ) -> JsonResult {
let data: PartialCipherData = data.into_inner(); let data: PartialCipherData = data.into_inner();
let Some(cipher) = Cipher::find_by_uuid(&cipher_id, &mut conn).await else { let Some(cipher) = Cipher::find_by_uuid(&cipher_id, &conn).await else {
err!("Cipher doesn't exist") err!("Cipher doesn't exist")
}; };
if let Some(ref folder_id) = data.folder_id { if let Some(ref folder_id) = data.folder_id {
if Folder::find_by_uuid_and_user(folder_id, &headers.user.uuid, &mut conn).await.is_none() { if Folder::find_by_uuid_and_user(folder_id, &headers.user.uuid, &conn).await.is_none() {
err!("Invalid folder", "Folder does not exist or belongs to another user"); err!("Invalid folder", "Folder does not exist or belongs to another user");
} }
} }
// Move cipher // Move cipher
cipher.move_to_folder(data.folder_id.clone(), &headers.user.uuid, &mut conn).await?; cipher.move_to_folder(data.folder_id.clone(), &headers.user.uuid, &conn).await?;
// Update favorite // Update favorite
cipher.set_favorite(Some(data.favorite), &headers.user.uuid, &mut conn).await?; cipher.set_favorite(Some(data.favorite), &headers.user.uuid, &conn).await?;
Ok(Json(cipher.to_json(&headers.host, &headers.user.uuid, None, CipherSyncType::User, &mut conn).await?)) Ok(Json(cipher.to_json(&headers.host, &headers.user.uuid, None, CipherSyncType::User, &conn).await?))
} }
#[derive(Deserialize)] #[derive(Deserialize)]
@ -764,35 +760,34 @@ async fn post_collections_update(
cipher_id: CipherId, cipher_id: CipherId,
data: Json<CollectionsAdminData>, data: Json<CollectionsAdminData>,
headers: Headers, headers: Headers,
mut conn: DbConn, conn: DbConn,
nt: Notify<'_>, nt: Notify<'_>,
) -> JsonResult { ) -> JsonResult {
let data: CollectionsAdminData = data.into_inner(); let data: CollectionsAdminData = data.into_inner();
let Some(cipher) = Cipher::find_by_uuid(&cipher_id, &mut conn).await else { let Some(cipher) = Cipher::find_by_uuid(&cipher_id, &conn).await else {
err!("Cipher doesn't exist") err!("Cipher doesn't exist")
}; };
if !cipher.is_write_accessible_to_user(&headers.user.uuid, &mut conn).await { if !cipher.is_write_accessible_to_user(&headers.user.uuid, &conn).await {
err!("Cipher is not write accessible") err!("Cipher is not write accessible")
} }
let posted_collections = HashSet::<CollectionId>::from_iter(data.collection_ids); let posted_collections = HashSet::<CollectionId>::from_iter(data.collection_ids);
let current_collections = let current_collections =
HashSet::<CollectionId>::from_iter(cipher.get_collections(headers.user.uuid.clone(), &mut conn).await); HashSet::<CollectionId>::from_iter(cipher.get_collections(headers.user.uuid.clone(), &conn).await);
for collection in posted_collections.symmetric_difference(&current_collections) { for collection in posted_collections.symmetric_difference(&current_collections) {
match Collection::find_by_uuid_and_org(collection, cipher.organization_uuid.as_ref().unwrap(), &mut conn).await match Collection::find_by_uuid_and_org(collection, cipher.organization_uuid.as_ref().unwrap(), &conn).await {
{
None => err!("Invalid collection ID provided"), None => err!("Invalid collection ID provided"),
Some(collection) => { Some(collection) => {
if collection.is_writable_by_user(&headers.user.uuid, &mut conn).await { if collection.is_writable_by_user(&headers.user.uuid, &conn).await {
if posted_collections.contains(&collection.uuid) { if posted_collections.contains(&collection.uuid) {
// Add to collection // Add to collection
CollectionCipher::save(&cipher.uuid, &collection.uuid, &mut conn).await?; CollectionCipher::save(&cipher.uuid, &collection.uuid, &conn).await?;
} else { } else {
// Remove from collection // Remove from collection
CollectionCipher::delete(&cipher.uuid, &collection.uuid, &mut conn).await?; CollectionCipher::delete(&cipher.uuid, &collection.uuid, &conn).await?;
} }
} else { } else {
err!("No rights to modify the collection") err!("No rights to modify the collection")
@ -804,10 +799,10 @@ async fn post_collections_update(
nt.send_cipher_update( nt.send_cipher_update(
UpdateType::SyncCipherUpdate, UpdateType::SyncCipherUpdate,
&cipher, &cipher,
&cipher.update_users_revision(&mut conn).await, &cipher.update_users_revision(&conn).await,
&headers.device, &headers.device,
Some(Vec::from_iter(posted_collections)), Some(Vec::from_iter(posted_collections)),
&mut conn, &conn,
) )
.await; .await;
@ -818,11 +813,11 @@ async fn post_collections_update(
&headers.user.uuid, &headers.user.uuid,
headers.device.atype, headers.device.atype,
&headers.ip.ip, &headers.ip.ip,
&mut conn, &conn,
) )
.await; .await;
Ok(Json(cipher.to_json(&headers.host, &headers.user.uuid, None, CipherSyncType::User, &mut conn).await?)) Ok(Json(cipher.to_json(&headers.host, &headers.user.uuid, None, CipherSyncType::User, &conn).await?))
} }
#[put("/ciphers/<cipher_id>/collections-admin", data = "<data>")] #[put("/ciphers/<cipher_id>/collections-admin", data = "<data>")]
@ -841,35 +836,34 @@ async fn post_collections_admin(
cipher_id: CipherId, cipher_id: CipherId,
data: Json<CollectionsAdminData>, data: Json<CollectionsAdminData>,
headers: Headers, headers: Headers,
mut conn: DbConn, conn: DbConn,
nt: Notify<'_>, nt: Notify<'_>,
) -> EmptyResult { ) -> EmptyResult {
let data: CollectionsAdminData = data.into_inner(); let data: CollectionsAdminData = data.into_inner();
let Some(cipher) = Cipher::find_by_uuid(&cipher_id, &mut conn).await else { let Some(cipher) = Cipher::find_by_uuid(&cipher_id, &conn).await else {
err!("Cipher doesn't exist") err!("Cipher doesn't exist")
}; };
if !cipher.is_write_accessible_to_user(&headers.user.uuid, &mut conn).await { if !cipher.is_write_accessible_to_user(&headers.user.uuid, &conn).await {
err!("Cipher is not write accessible") err!("Cipher is not write accessible")
} }
let posted_collections = HashSet::<CollectionId>::from_iter(data.collection_ids); let posted_collections = HashSet::<CollectionId>::from_iter(data.collection_ids);
let current_collections = let current_collections =
HashSet::<CollectionId>::from_iter(cipher.get_admin_collections(headers.user.uuid.clone(), &mut conn).await); HashSet::<CollectionId>::from_iter(cipher.get_admin_collections(headers.user.uuid.clone(), &conn).await);
for collection in posted_collections.symmetric_difference(&current_collections) { for collection in posted_collections.symmetric_difference(&current_collections) {
match Collection::find_by_uuid_and_org(collection, cipher.organization_uuid.as_ref().unwrap(), &mut conn).await match Collection::find_by_uuid_and_org(collection, cipher.organization_uuid.as_ref().unwrap(), &conn).await {
{
None => err!("Invalid collection ID provided"), None => err!("Invalid collection ID provided"),
Some(collection) => { Some(collection) => {
if collection.is_writable_by_user(&headers.user.uuid, &mut conn).await { if collection.is_writable_by_user(&headers.user.uuid, &conn).await {
if posted_collections.contains(&collection.uuid) { if posted_collections.contains(&collection.uuid) {
// Add to collection // Add to collection
CollectionCipher::save(&cipher.uuid, &collection.uuid, &mut conn).await?; CollectionCipher::save(&cipher.uuid, &collection.uuid, &conn).await?;
} else { } else {
// Remove from collection // Remove from collection
CollectionCipher::delete(&cipher.uuid, &collection.uuid, &mut conn).await?; CollectionCipher::delete(&cipher.uuid, &collection.uuid, &conn).await?;
} }
} else { } else {
err!("No rights to modify the collection") err!("No rights to modify the collection")
@ -881,10 +875,10 @@ async fn post_collections_admin(
nt.send_cipher_update( nt.send_cipher_update(
UpdateType::SyncCipherUpdate, UpdateType::SyncCipherUpdate,
&cipher, &cipher,
&cipher.update_users_revision(&mut conn).await, &cipher.update_users_revision(&conn).await,
&headers.device, &headers.device,
Some(Vec::from_iter(posted_collections)), Some(Vec::from_iter(posted_collections)),
&mut conn, &conn,
) )
.await; .await;
@ -895,7 +889,7 @@ async fn post_collections_admin(
&headers.user.uuid, &headers.user.uuid,
headers.device.atype, headers.device.atype,
&headers.ip.ip, &headers.ip.ip,
&mut conn, &conn,
) )
.await; .await;
@ -916,12 +910,12 @@ async fn post_cipher_share(
cipher_id: CipherId, cipher_id: CipherId,
data: Json<ShareCipherData>, data: Json<ShareCipherData>,
headers: Headers, headers: Headers,
mut conn: DbConn, conn: DbConn,
nt: Notify<'_>, nt: Notify<'_>,
) -> JsonResult { ) -> JsonResult {
let data: ShareCipherData = data.into_inner(); let data: ShareCipherData = data.into_inner();
share_cipher_by_uuid(&cipher_id, data, &headers, &mut conn, &nt, None).await share_cipher_by_uuid(&cipher_id, data, &headers, &conn, &nt, None).await
} }
#[put("/ciphers/<cipher_id>/share", data = "<data>")] #[put("/ciphers/<cipher_id>/share", data = "<data>")]
@ -929,12 +923,12 @@ async fn put_cipher_share(
cipher_id: CipherId, cipher_id: CipherId,
data: Json<ShareCipherData>, data: Json<ShareCipherData>,
headers: Headers, headers: Headers,
mut conn: DbConn, conn: DbConn,
nt: Notify<'_>, nt: Notify<'_>,
) -> JsonResult { ) -> JsonResult {
let data: ShareCipherData = data.into_inner(); let data: ShareCipherData = data.into_inner();
share_cipher_by_uuid(&cipher_id, data, &headers, &mut conn, &nt, None).await share_cipher_by_uuid(&cipher_id, data, &headers, &conn, &nt, None).await
} }
#[derive(Deserialize)] #[derive(Deserialize)]
@ -948,7 +942,7 @@ struct ShareSelectedCipherData {
async fn put_cipher_share_selected( async fn put_cipher_share_selected(
data: Json<ShareSelectedCipherData>, data: Json<ShareSelectedCipherData>,
headers: Headers, headers: Headers,
mut conn: DbConn, conn: DbConn,
nt: Notify<'_>, nt: Notify<'_>,
) -> EmptyResult { ) -> EmptyResult {
let mut data: ShareSelectedCipherData = data.into_inner(); let mut data: ShareSelectedCipherData = data.into_inner();
@ -975,14 +969,14 @@ async fn put_cipher_share_selected(
match shared_cipher_data.cipher.id.take() { match shared_cipher_data.cipher.id.take() {
Some(id) => { Some(id) => {
share_cipher_by_uuid(&id, shared_cipher_data, &headers, &mut conn, &nt, Some(UpdateType::None)).await? share_cipher_by_uuid(&id, shared_cipher_data, &headers, &conn, &nt, Some(UpdateType::None)).await?
} }
None => err!("Request missing ids field"), None => err!("Request missing ids field"),
}; };
} }
// Multi share actions do not send out a push for each cipher, we need to send a general sync here // Multi share actions do not send out a push for each cipher, we need to send a general sync here
nt.send_user_update(UpdateType::SyncCiphers, &headers.user, &headers.device.push_uuid, &mut conn).await; nt.send_user_update(UpdateType::SyncCiphers, &headers.user, &headers.device.push_uuid, &conn).await;
Ok(()) Ok(())
} }
@ -991,7 +985,7 @@ async fn share_cipher_by_uuid(
cipher_id: &CipherId, cipher_id: &CipherId,
data: ShareCipherData, data: ShareCipherData,
headers: &Headers, headers: &Headers,
conn: &mut DbConn, conn: &DbConn,
nt: &Notify<'_>, nt: &Notify<'_>,
override_ut: Option<UpdateType>, override_ut: Option<UpdateType>,
) -> JsonResult { ) -> JsonResult {
@ -1050,17 +1044,17 @@ async fn get_attachment(
cipher_id: CipherId, cipher_id: CipherId,
attachment_id: AttachmentId, attachment_id: AttachmentId,
headers: Headers, headers: Headers,
mut conn: DbConn, conn: DbConn,
) -> JsonResult { ) -> JsonResult {
let Some(cipher) = Cipher::find_by_uuid(&cipher_id, &mut conn).await else { let Some(cipher) = Cipher::find_by_uuid(&cipher_id, &conn).await else {
err!("Cipher doesn't exist") err!("Cipher doesn't exist")
}; };
if !cipher.is_accessible_to_user(&headers.user.uuid, &mut conn).await { if !cipher.is_accessible_to_user(&headers.user.uuid, &conn).await {
err!("Cipher is not accessible") err!("Cipher is not accessible")
} }
match Attachment::find_by_id(&attachment_id, &mut conn).await { match Attachment::find_by_id(&attachment_id, &conn).await {
Some(attachment) if cipher_id == attachment.cipher_uuid => Ok(Json(attachment.to_json(&headers.host).await?)), Some(attachment) if cipher_id == attachment.cipher_uuid => Ok(Json(attachment.to_json(&headers.host).await?)),
Some(_) => err!("Attachment doesn't belong to cipher"), Some(_) => err!("Attachment doesn't belong to cipher"),
None => err!("Attachment doesn't exist"), None => err!("Attachment doesn't exist"),
@ -1090,13 +1084,13 @@ async fn post_attachment_v2(
cipher_id: CipherId, cipher_id: CipherId,
data: Json<AttachmentRequestData>, data: Json<AttachmentRequestData>,
headers: Headers, headers: Headers,
mut conn: DbConn, conn: DbConn,
) -> JsonResult { ) -> JsonResult {
let Some(cipher) = Cipher::find_by_uuid(&cipher_id, &mut conn).await else { let Some(cipher) = Cipher::find_by_uuid(&cipher_id, &conn).await else {
err!("Cipher doesn't exist") err!("Cipher doesn't exist")
}; };
if !cipher.is_write_accessible_to_user(&headers.user.uuid, &mut conn).await { if !cipher.is_write_accessible_to_user(&headers.user.uuid, &conn).await {
err!("Cipher is not write accessible") err!("Cipher is not write accessible")
} }
@ -1109,7 +1103,7 @@ async fn post_attachment_v2(
let attachment_id = crypto::generate_attachment_id(); let attachment_id = crypto::generate_attachment_id();
let attachment = let attachment =
Attachment::new(attachment_id.clone(), cipher.uuid.clone(), data.file_name, file_size, Some(data.key)); Attachment::new(attachment_id.clone(), cipher.uuid.clone(), data.file_name, file_size, Some(data.key));
attachment.save(&mut conn).await.expect("Error saving attachment"); attachment.save(&conn).await.expect("Error saving attachment");
let url = format!("/ciphers/{}/attachment/{attachment_id}", cipher.uuid); let url = format!("/ciphers/{}/attachment/{attachment_id}", cipher.uuid);
let response_key = match data.admin_request { let response_key = match data.admin_request {
@ -1122,7 +1116,7 @@ async fn post_attachment_v2(
"attachmentId": attachment_id, "attachmentId": attachment_id,
"url": url, "url": url,
"fileUploadType": FileUploadType::Direct as i32, "fileUploadType": FileUploadType::Direct as i32,
response_key: cipher.to_json(&headers.host, &headers.user.uuid, None, CipherSyncType::User, &mut conn).await?, response_key: cipher.to_json(&headers.host, &headers.user.uuid, None, CipherSyncType::User, &conn).await?,
}))) })))
} }
@ -1145,7 +1139,7 @@ async fn save_attachment(
cipher_id: CipherId, cipher_id: CipherId,
data: Form<UploadData<'_>>, data: Form<UploadData<'_>>,
headers: &Headers, headers: &Headers,
mut conn: DbConn, conn: DbConn,
nt: Notify<'_>, nt: Notify<'_>,
) -> Result<(Cipher, DbConn), crate::error::Error> { ) -> Result<(Cipher, DbConn), crate::error::Error> {
let data = data.into_inner(); let data = data.into_inner();
@ -1157,11 +1151,11 @@ async fn save_attachment(
err!("Attachment size can't be negative") err!("Attachment size can't be negative")
} }
let Some(cipher) = Cipher::find_by_uuid(&cipher_id, &mut conn).await else { let Some(cipher) = Cipher::find_by_uuid(&cipher_id, &conn).await else {
err!("Cipher doesn't exist") err!("Cipher doesn't exist")
}; };
if !cipher.is_write_accessible_to_user(&headers.user.uuid, &mut conn).await { if !cipher.is_write_accessible_to_user(&headers.user.uuid, &conn).await {
err!("Cipher is not write accessible") err!("Cipher is not write accessible")
} }
@ -1176,7 +1170,7 @@ async fn save_attachment(
match CONFIG.user_attachment_limit() { match CONFIG.user_attachment_limit() {
Some(0) => err!("Attachments are disabled"), Some(0) => err!("Attachments are disabled"),
Some(limit_kb) => { Some(limit_kb) => {
let already_used = Attachment::size_by_user(user_id, &mut conn).await; let already_used = Attachment::size_by_user(user_id, &conn).await;
let left = limit_kb let left = limit_kb
.checked_mul(1024) .checked_mul(1024)
.and_then(|l| l.checked_sub(already_used)) .and_then(|l| l.checked_sub(already_used))
@ -1198,7 +1192,7 @@ async fn save_attachment(
match CONFIG.org_attachment_limit() { match CONFIG.org_attachment_limit() {
Some(0) => err!("Attachments are disabled"), Some(0) => err!("Attachments are disabled"),
Some(limit_kb) => { Some(limit_kb) => {
let already_used = Attachment::size_by_org(org_id, &mut conn).await; let already_used = Attachment::size_by_org(org_id, &conn).await;
let left = limit_kb let left = limit_kb
.checked_mul(1024) .checked_mul(1024)
.and_then(|l| l.checked_sub(already_used)) .and_then(|l| l.checked_sub(already_used))
@ -1249,10 +1243,10 @@ async fn save_attachment(
if size != attachment.file_size { if size != attachment.file_size {
// Update the attachment with the actual file size. // Update the attachment with the actual file size.
attachment.file_size = size; attachment.file_size = size;
attachment.save(&mut conn).await.expect("Error updating attachment"); attachment.save(&conn).await.expect("Error updating attachment");
} }
} else { } else {
attachment.delete(&mut conn).await.ok(); attachment.delete(&conn).await.ok();
err!(format!("Attachment size mismatch (expected within [{min_size}, {max_size}], got {size})")); err!(format!("Attachment size mismatch (expected within [{min_size}, {max_size}], got {size})"));
} }
@ -1272,7 +1266,7 @@ async fn save_attachment(
} }
let attachment = let attachment =
Attachment::new(file_id.clone(), cipher_id.clone(), encrypted_filename.unwrap(), size, data.key); Attachment::new(file_id.clone(), cipher_id.clone(), encrypted_filename.unwrap(), size, data.key);
attachment.save(&mut conn).await.expect("Error saving attachment"); attachment.save(&conn).await.expect("Error saving attachment");
} }
save_temp_file(PathType::Attachments, &format!("{cipher_id}/{file_id}"), data.data, true).await?; save_temp_file(PathType::Attachments, &format!("{cipher_id}/{file_id}"), data.data, true).await?;
@ -1280,10 +1274,10 @@ async fn save_attachment(
nt.send_cipher_update( nt.send_cipher_update(
UpdateType::SyncCipherUpdate, UpdateType::SyncCipherUpdate,
&cipher, &cipher,
&cipher.update_users_revision(&mut conn).await, &cipher.update_users_revision(&conn).await,
&headers.device, &headers.device,
None, None,
&mut conn, &conn,
) )
.await; .await;
@ -1295,7 +1289,7 @@ async fn save_attachment(
&headers.user.uuid, &headers.user.uuid,
headers.device.atype, headers.device.atype,
&headers.ip.ip, &headers.ip.ip,
&mut conn, &conn,
) )
.await; .await;
} }
@ -1313,10 +1307,10 @@ async fn post_attachment_v2_data(
attachment_id: AttachmentId, attachment_id: AttachmentId,
data: Form<UploadData<'_>>, data: Form<UploadData<'_>>,
headers: Headers, headers: Headers,
mut conn: DbConn, conn: DbConn,
nt: Notify<'_>, nt: Notify<'_>,
) -> EmptyResult { ) -> EmptyResult {
let attachment = match Attachment::find_by_id(&attachment_id, &mut conn).await { let attachment = match Attachment::find_by_id(&attachment_id, &conn).await {
Some(attachment) if cipher_id == attachment.cipher_uuid => Some(attachment), Some(attachment) if cipher_id == attachment.cipher_uuid => Some(attachment),
Some(_) => err!("Attachment doesn't belong to cipher"), Some(_) => err!("Attachment doesn't belong to cipher"),
None => err!("Attachment doesn't exist"), None => err!("Attachment doesn't exist"),
@ -1340,9 +1334,9 @@ async fn post_attachment(
// the attachment database record as well as saving the data to disk. // the attachment database record as well as saving the data to disk.
let attachment = None; let attachment = None;
let (cipher, mut conn) = save_attachment(attachment, cipher_id, data, &headers, conn, nt).await?; let (cipher, conn) = save_attachment(attachment, cipher_id, data, &headers, conn, nt).await?;
Ok(Json(cipher.to_json(&headers.host, &headers.user.uuid, None, CipherSyncType::User, &mut conn).await?)) Ok(Json(cipher.to_json(&headers.host, &headers.user.uuid, None, CipherSyncType::User, &conn).await?))
} }
#[post("/ciphers/<cipher_id>/attachment-admin", format = "multipart/form-data", data = "<data>")] #[post("/ciphers/<cipher_id>/attachment-admin", format = "multipart/form-data", data = "<data>")]
@ -1362,10 +1356,10 @@ async fn post_attachment_share(
attachment_id: AttachmentId, attachment_id: AttachmentId,
data: Form<UploadData<'_>>, data: Form<UploadData<'_>>,
headers: Headers, headers: Headers,
mut conn: DbConn, conn: DbConn,
nt: Notify<'_>, nt: Notify<'_>,
) -> JsonResult { ) -> JsonResult {
_delete_cipher_attachment_by_id(&cipher_id, &attachment_id, &headers, &mut conn, &nt).await?; _delete_cipher_attachment_by_id(&cipher_id, &attachment_id, &headers, &conn, &nt).await?;
post_attachment(cipher_id, data, headers, conn, nt).await post_attachment(cipher_id, data, headers, conn, nt).await
} }
@ -1396,10 +1390,10 @@ async fn delete_attachment(
cipher_id: CipherId, cipher_id: CipherId,
attachment_id: AttachmentId, attachment_id: AttachmentId,
headers: Headers, headers: Headers,
mut conn: DbConn, conn: DbConn,
nt: Notify<'_>, nt: Notify<'_>,
) -> JsonResult { ) -> JsonResult {
_delete_cipher_attachment_by_id(&cipher_id, &attachment_id, &headers, &mut conn, &nt).await _delete_cipher_attachment_by_id(&cipher_id, &attachment_id, &headers, &conn, &nt).await
} }
#[delete("/ciphers/<cipher_id>/attachment/<attachment_id>/admin")] #[delete("/ciphers/<cipher_id>/attachment/<attachment_id>/admin")]
@ -1407,55 +1401,45 @@ async fn delete_attachment_admin(
cipher_id: CipherId, cipher_id: CipherId,
attachment_id: AttachmentId, attachment_id: AttachmentId,
headers: Headers, headers: Headers,
mut conn: DbConn, conn: DbConn,
nt: Notify<'_>, nt: Notify<'_>,
) -> JsonResult { ) -> JsonResult {
_delete_cipher_attachment_by_id(&cipher_id, &attachment_id, &headers, &mut conn, &nt).await _delete_cipher_attachment_by_id(&cipher_id, &attachment_id, &headers, &conn, &nt).await
} }
#[post("/ciphers/<cipher_id>/delete")] #[post("/ciphers/<cipher_id>/delete")]
async fn delete_cipher_post(cipher_id: CipherId, headers: Headers, mut conn: DbConn, nt: Notify<'_>) -> EmptyResult { async fn delete_cipher_post(cipher_id: CipherId, headers: Headers, conn: DbConn, nt: Notify<'_>) -> EmptyResult {
_delete_cipher_by_uuid(&cipher_id, &headers, &mut conn, &CipherDeleteOptions::HardSingle, &nt).await _delete_cipher_by_uuid(&cipher_id, &headers, &conn, &CipherDeleteOptions::HardSingle, &nt).await
// permanent delete // permanent delete
} }
#[post("/ciphers/<cipher_id>/delete-admin")] #[post("/ciphers/<cipher_id>/delete-admin")]
async fn delete_cipher_post_admin( async fn delete_cipher_post_admin(cipher_id: CipherId, headers: Headers, conn: DbConn, nt: Notify<'_>) -> EmptyResult {
cipher_id: CipherId, _delete_cipher_by_uuid(&cipher_id, &headers, &conn, &CipherDeleteOptions::HardSingle, &nt).await
headers: Headers,
mut conn: DbConn,
nt: Notify<'_>,
) -> EmptyResult {
_delete_cipher_by_uuid(&cipher_id, &headers, &mut conn, &CipherDeleteOptions::HardSingle, &nt).await
// permanent delete // permanent delete
} }
#[put("/ciphers/<cipher_id>/delete")] #[put("/ciphers/<cipher_id>/delete")]
async fn delete_cipher_put(cipher_id: CipherId, headers: Headers, mut conn: DbConn, nt: Notify<'_>) -> EmptyResult { async fn delete_cipher_put(cipher_id: CipherId, headers: Headers, conn: DbConn, nt: Notify<'_>) -> EmptyResult {
_delete_cipher_by_uuid(&cipher_id, &headers, &mut conn, &CipherDeleteOptions::SoftSingle, &nt).await _delete_cipher_by_uuid(&cipher_id, &headers, &conn, &CipherDeleteOptions::SoftSingle, &nt).await
// soft delete // soft delete
} }
#[put("/ciphers/<cipher_id>/delete-admin")] #[put("/ciphers/<cipher_id>/delete-admin")]
async fn delete_cipher_put_admin( async fn delete_cipher_put_admin(cipher_id: CipherId, headers: Headers, conn: DbConn, nt: Notify<'_>) -> EmptyResult {
cipher_id: CipherId, _delete_cipher_by_uuid(&cipher_id, &headers, &conn, &CipherDeleteOptions::SoftSingle, &nt).await
headers: Headers,
mut conn: DbConn,
nt: Notify<'_>,
) -> EmptyResult {
_delete_cipher_by_uuid(&cipher_id, &headers, &mut conn, &CipherDeleteOptions::SoftSingle, &nt).await
// soft delete // soft delete
} }
#[delete("/ciphers/<cipher_id>")] #[delete("/ciphers/<cipher_id>")]
async fn delete_cipher(cipher_id: CipherId, headers: Headers, mut conn: DbConn, nt: Notify<'_>) -> EmptyResult { async fn delete_cipher(cipher_id: CipherId, headers: Headers, conn: DbConn, nt: Notify<'_>) -> EmptyResult {
_delete_cipher_by_uuid(&cipher_id, &headers, &mut conn, &CipherDeleteOptions::HardSingle, &nt).await _delete_cipher_by_uuid(&cipher_id, &headers, &conn, &CipherDeleteOptions::HardSingle, &nt).await
// permanent delete // permanent delete
} }
#[delete("/ciphers/<cipher_id>/admin")] #[delete("/ciphers/<cipher_id>/admin")]
async fn delete_cipher_admin(cipher_id: CipherId, headers: Headers, mut conn: DbConn, nt: Notify<'_>) -> EmptyResult { async fn delete_cipher_admin(cipher_id: CipherId, headers: Headers, conn: DbConn, nt: Notify<'_>) -> EmptyResult {
_delete_cipher_by_uuid(&cipher_id, &headers, &mut conn, &CipherDeleteOptions::HardSingle, &nt).await _delete_cipher_by_uuid(&cipher_id, &headers, &conn, &CipherDeleteOptions::HardSingle, &nt).await
// permanent delete // permanent delete
} }
@ -1526,38 +1510,33 @@ async fn delete_cipher_selected_put_admin(
} }
#[put("/ciphers/<cipher_id>/restore")] #[put("/ciphers/<cipher_id>/restore")]
async fn restore_cipher_put(cipher_id: CipherId, headers: Headers, mut conn: DbConn, nt: Notify<'_>) -> JsonResult { async fn restore_cipher_put(cipher_id: CipherId, headers: Headers, conn: DbConn, nt: Notify<'_>) -> JsonResult {
_restore_cipher_by_uuid(&cipher_id, &headers, false, &mut conn, &nt).await _restore_cipher_by_uuid(&cipher_id, &headers, false, &conn, &nt).await
} }
#[put("/ciphers/<cipher_id>/restore-admin")] #[put("/ciphers/<cipher_id>/restore-admin")]
async fn restore_cipher_put_admin( async fn restore_cipher_put_admin(cipher_id: CipherId, headers: Headers, conn: DbConn, nt: Notify<'_>) -> JsonResult {
cipher_id: CipherId, _restore_cipher_by_uuid(&cipher_id, &headers, false, &conn, &nt).await
headers: Headers,
mut conn: DbConn,
nt: Notify<'_>,
) -> JsonResult {
_restore_cipher_by_uuid(&cipher_id, &headers, false, &mut conn, &nt).await
} }
#[put("/ciphers/restore-admin", data = "<data>")] #[put("/ciphers/restore-admin", data = "<data>")]
async fn restore_cipher_selected_admin( async fn restore_cipher_selected_admin(
data: Json<CipherIdsData>, data: Json<CipherIdsData>,
headers: Headers, headers: Headers,
mut conn: DbConn, conn: DbConn,
nt: Notify<'_>, nt: Notify<'_>,
) -> JsonResult { ) -> JsonResult {
_restore_multiple_ciphers(data, &headers, &mut conn, &nt).await _restore_multiple_ciphers(data, &headers, &conn, &nt).await
} }
#[put("/ciphers/restore", data = "<data>")] #[put("/ciphers/restore", data = "<data>")]
async fn restore_cipher_selected( async fn restore_cipher_selected(
data: Json<CipherIdsData>, data: Json<CipherIdsData>,
headers: Headers, headers: Headers,
mut conn: DbConn, conn: DbConn,
nt: Notify<'_>, nt: Notify<'_>,
) -> JsonResult { ) -> JsonResult {
_restore_multiple_ciphers(data, &headers, &mut conn, &nt).await _restore_multiple_ciphers(data, &headers, &conn, &nt).await
} }
#[derive(Deserialize)] #[derive(Deserialize)]
@ -1571,14 +1550,14 @@ struct MoveCipherData {
async fn move_cipher_selected( async fn move_cipher_selected(
data: Json<MoveCipherData>, data: Json<MoveCipherData>,
headers: Headers, headers: Headers,
mut conn: DbConn, conn: DbConn,
nt: Notify<'_>, nt: Notify<'_>,
) -> EmptyResult { ) -> EmptyResult {
let data = data.into_inner(); let data = data.into_inner();
let user_id = &headers.user.uuid; let user_id = &headers.user.uuid;
if let Some(ref folder_id) = data.folder_id { if let Some(ref folder_id) = data.folder_id {
if Folder::find_by_uuid_and_user(folder_id, user_id, &mut conn).await.is_none() { if Folder::find_by_uuid_and_user(folder_id, user_id, &conn).await.is_none() {
err!("Invalid folder", "Folder does not exist or belongs to another user"); err!("Invalid folder", "Folder does not exist or belongs to another user");
} }
} }
@ -1588,10 +1567,10 @@ async fn move_cipher_selected(
// TODO: Convert this to use a single query (or at least less) to update all items // TODO: Convert this to use a single query (or at least less) to update all items
// Find all ciphers a user has access to, all others will be ignored // Find all ciphers a user has access to, all others will be ignored
let accessible_ciphers = Cipher::find_by_user_and_ciphers(user_id, &data.ids, &mut conn).await; let accessible_ciphers = Cipher::find_by_user_and_ciphers(user_id, &data.ids, &conn).await;
let accessible_ciphers_count = accessible_ciphers.len(); let accessible_ciphers_count = accessible_ciphers.len();
for cipher in accessible_ciphers { for cipher in accessible_ciphers {
cipher.move_to_folder(data.folder_id.clone(), user_id, &mut conn).await?; cipher.move_to_folder(data.folder_id.clone(), user_id, &conn).await?;
if cipher_count == 1 { if cipher_count == 1 {
single_cipher = Some(cipher); single_cipher = Some(cipher);
} }
@ -1604,12 +1583,12 @@ async fn move_cipher_selected(
std::slice::from_ref(user_id), std::slice::from_ref(user_id),
&headers.device, &headers.device,
None, None,
&mut conn, &conn,
) )
.await; .await;
} else { } else {
// Multi move actions do not send out a push for each cipher, we need to send a general sync here // Multi move actions do not send out a push for each cipher, we need to send a general sync here
nt.send_user_update(UpdateType::SyncCiphers, &headers.user, &headers.device.push_uuid, &mut conn).await; nt.send_user_update(UpdateType::SyncCiphers, &headers.user, &headers.device.push_uuid, &conn).await;
} }
if cipher_count != accessible_ciphers_count { if cipher_count != accessible_ciphers_count {
@ -1642,23 +1621,23 @@ async fn delete_all(
organization: Option<OrganizationIdData>, organization: Option<OrganizationIdData>,
data: Json<PasswordOrOtpData>, data: Json<PasswordOrOtpData>,
headers: Headers, headers: Headers,
mut conn: DbConn, conn: DbConn,
nt: Notify<'_>, nt: Notify<'_>,
) -> EmptyResult { ) -> EmptyResult {
let data: PasswordOrOtpData = data.into_inner(); let data: PasswordOrOtpData = data.into_inner();
let mut user = headers.user; let mut user = headers.user;
data.validate(&user, true, &mut conn).await?; data.validate(&user, true, &conn).await?;
match organization { match organization {
Some(org_data) => { Some(org_data) => {
// Organization ID in query params, purging organization vault // Organization ID in query params, purging organization vault
match Membership::find_by_user_and_org(&user.uuid, &org_data.org_id, &mut conn).await { match Membership::find_by_user_and_org(&user.uuid, &org_data.org_id, &conn).await {
None => err!("You don't have permission to purge the organization vault"), None => err!("You don't have permission to purge the organization vault"),
Some(member) => { Some(member) => {
if member.atype == MembershipType::Owner { if member.atype == MembershipType::Owner {
Cipher::delete_all_by_organization(&org_data.org_id, &mut conn).await?; Cipher::delete_all_by_organization(&org_data.org_id, &conn).await?;
nt.send_user_update(UpdateType::SyncVault, &user, &headers.device.push_uuid, &mut conn).await; nt.send_user_update(UpdateType::SyncVault, &user, &headers.device.push_uuid, &conn).await;
log_event( log_event(
EventType::OrganizationPurgedVault as i32, EventType::OrganizationPurgedVault as i32,
@ -1667,7 +1646,7 @@ async fn delete_all(
&user.uuid, &user.uuid,
headers.device.atype, headers.device.atype,
&headers.ip.ip, &headers.ip.ip,
&mut conn, &conn,
) )
.await; .await;
@ -1681,17 +1660,17 @@ async fn delete_all(
None => { None => {
// No organization ID in query params, purging user vault // No organization ID in query params, purging user vault
// Delete ciphers and their attachments // Delete ciphers and their attachments
for cipher in Cipher::find_owned_by_user(&user.uuid, &mut conn).await { for cipher in Cipher::find_owned_by_user(&user.uuid, &conn).await {
cipher.delete(&mut conn).await?; cipher.delete(&conn).await?;
} }
// Delete folders // Delete folders
for f in Folder::find_by_user(&user.uuid, &mut conn).await { for f in Folder::find_by_user(&user.uuid, &conn).await {
f.delete(&mut conn).await?; f.delete(&conn).await?;
} }
user.update_revision(&mut conn).await?; user.update_revision(&conn).await?;
nt.send_user_update(UpdateType::SyncVault, &user, &headers.device.push_uuid, &mut conn).await; nt.send_user_update(UpdateType::SyncVault, &user, &headers.device.push_uuid, &conn).await;
Ok(()) Ok(())
} }
@ -1709,7 +1688,7 @@ pub enum CipherDeleteOptions {
async fn _delete_cipher_by_uuid( async fn _delete_cipher_by_uuid(
cipher_id: &CipherId, cipher_id: &CipherId,
headers: &Headers, headers: &Headers,
conn: &mut DbConn, conn: &DbConn,
delete_options: &CipherDeleteOptions, delete_options: &CipherDeleteOptions,
nt: &Notify<'_>, nt: &Notify<'_>,
) -> EmptyResult { ) -> EmptyResult {
@ -1775,20 +1754,20 @@ struct CipherIdsData {
async fn _delete_multiple_ciphers( async fn _delete_multiple_ciphers(
data: Json<CipherIdsData>, data: Json<CipherIdsData>,
headers: Headers, headers: Headers,
mut conn: DbConn, conn: DbConn,
delete_options: CipherDeleteOptions, delete_options: CipherDeleteOptions,
nt: Notify<'_>, nt: Notify<'_>,
) -> EmptyResult { ) -> EmptyResult {
let data = data.into_inner(); let data = data.into_inner();
for cipher_id in data.ids { for cipher_id in data.ids {
if let error @ Err(_) = _delete_cipher_by_uuid(&cipher_id, &headers, &mut conn, &delete_options, &nt).await { if let error @ Err(_) = _delete_cipher_by_uuid(&cipher_id, &headers, &conn, &delete_options, &nt).await {
return error; return error;
}; };
} }
// Multi delete actions do not send out a push for each cipher, we need to send a general sync here // Multi delete actions do not send out a push for each cipher, we need to send a general sync here
nt.send_user_update(UpdateType::SyncCiphers, &headers.user, &headers.device.push_uuid, &mut conn).await; nt.send_user_update(UpdateType::SyncCiphers, &headers.user, &headers.device.push_uuid, &conn).await;
Ok(()) Ok(())
} }
@ -1797,7 +1776,7 @@ async fn _restore_cipher_by_uuid(
cipher_id: &CipherId, cipher_id: &CipherId,
headers: &Headers, headers: &Headers,
multi_restore: bool, multi_restore: bool,
conn: &mut DbConn, conn: &DbConn,
nt: &Notify<'_>, nt: &Notify<'_>,
) -> JsonResult { ) -> JsonResult {
let Some(mut cipher) = Cipher::find_by_uuid(cipher_id, conn).await else { let Some(mut cipher) = Cipher::find_by_uuid(cipher_id, conn).await else {
@ -1842,7 +1821,7 @@ async fn _restore_cipher_by_uuid(
async fn _restore_multiple_ciphers( async fn _restore_multiple_ciphers(
data: Json<CipherIdsData>, data: Json<CipherIdsData>,
headers: &Headers, headers: &Headers,
conn: &mut DbConn, conn: &DbConn,
nt: &Notify<'_>, nt: &Notify<'_>,
) -> JsonResult { ) -> JsonResult {
let data = data.into_inner(); let data = data.into_inner();
@ -1869,7 +1848,7 @@ async fn _delete_cipher_attachment_by_id(
cipher_id: &CipherId, cipher_id: &CipherId,
attachment_id: &AttachmentId, attachment_id: &AttachmentId,
headers: &Headers, headers: &Headers,
conn: &mut DbConn, conn: &DbConn,
nt: &Notify<'_>, nt: &Notify<'_>,
) -> JsonResult { ) -> JsonResult {
let Some(attachment) = Attachment::find_by_id(attachment_id, conn).await else { let Some(attachment) = Attachment::find_by_id(attachment_id, conn).await else {
@ -1938,7 +1917,7 @@ pub enum CipherSyncType {
} }
impl CipherSyncData { impl CipherSyncData {
pub async fn new(user_id: &UserId, sync_type: CipherSyncType, conn: &mut DbConn) -> Self { pub async fn new(user_id: &UserId, sync_type: CipherSyncType, conn: &DbConn) -> Self {
let cipher_folders: HashMap<CipherId, FolderId>; let cipher_folders: HashMap<CipherId, FolderId>;
let cipher_favorites: HashSet<CipherId>; let cipher_favorites: HashSet<CipherId>;
match sync_type { match sync_type {

172
src/api/core/emergency_access.rs

@ -8,7 +8,13 @@ use crate::{
EmptyResult, JsonResult, EmptyResult, JsonResult,
}, },
auth::{decode_emergency_access_invite, Headers}, auth::{decode_emergency_access_invite, Headers},
db::{models::*, DbConn, DbPool}, db::{
models::{
Cipher, EmergencyAccess, EmergencyAccessId, EmergencyAccessStatus, EmergencyAccessType, Invitation,
Membership, MembershipType, OrgPolicy, TwoFactor, User, UserId,
},
DbConn, DbPool,
},
mail, mail,
util::NumberOrString, util::NumberOrString,
CONFIG, CONFIG,
@ -40,7 +46,7 @@ pub fn routes() -> Vec<Route> {
// region get // region get
#[get("/emergency-access/trusted")] #[get("/emergency-access/trusted")]
async fn get_contacts(headers: Headers, mut conn: DbConn) -> Json<Value> { async fn get_contacts(headers: Headers, conn: DbConn) -> Json<Value> {
if !CONFIG.emergency_access_allowed() { if !CONFIG.emergency_access_allowed() {
return Json(json!({ return Json(json!({
"data": [{ "data": [{
@ -58,10 +64,10 @@ async fn get_contacts(headers: Headers, mut conn: DbConn) -> Json<Value> {
"continuationToken": null "continuationToken": null
})); }));
} }
let emergency_access_list = EmergencyAccess::find_all_by_grantor_uuid(&headers.user.uuid, &mut conn).await; let emergency_access_list = EmergencyAccess::find_all_by_grantor_uuid(&headers.user.uuid, &conn).await;
let mut emergency_access_list_json = Vec::with_capacity(emergency_access_list.len()); let mut emergency_access_list_json = Vec::with_capacity(emergency_access_list.len());
for ea in emergency_access_list { for ea in emergency_access_list {
if let Some(grantee) = ea.to_json_grantee_details(&mut conn).await { if let Some(grantee) = ea.to_json_grantee_details(&conn).await {
emergency_access_list_json.push(grantee) emergency_access_list_json.push(grantee)
} }
} }
@ -74,15 +80,15 @@ async fn get_contacts(headers: Headers, mut conn: DbConn) -> Json<Value> {
} }
#[get("/emergency-access/granted")] #[get("/emergency-access/granted")]
async fn get_grantees(headers: Headers, mut conn: DbConn) -> Json<Value> { async fn get_grantees(headers: Headers, conn: DbConn) -> Json<Value> {
let emergency_access_list = if CONFIG.emergency_access_allowed() { let emergency_access_list = if CONFIG.emergency_access_allowed() {
EmergencyAccess::find_all_by_grantee_uuid(&headers.user.uuid, &mut conn).await EmergencyAccess::find_all_by_grantee_uuid(&headers.user.uuid, &conn).await
} else { } else {
Vec::new() Vec::new()
}; };
let mut emergency_access_list_json = Vec::with_capacity(emergency_access_list.len()); let mut emergency_access_list_json = Vec::with_capacity(emergency_access_list.len());
for ea in emergency_access_list { for ea in emergency_access_list {
emergency_access_list_json.push(ea.to_json_grantor_details(&mut conn).await); emergency_access_list_json.push(ea.to_json_grantor_details(&conn).await);
} }
Json(json!({ Json(json!({
@ -93,12 +99,12 @@ async fn get_grantees(headers: Headers, mut conn: DbConn) -> Json<Value> {
} }
#[get("/emergency-access/<emer_id>")] #[get("/emergency-access/<emer_id>")]
async fn get_emergency_access(emer_id: EmergencyAccessId, headers: Headers, mut conn: DbConn) -> JsonResult { async fn get_emergency_access(emer_id: EmergencyAccessId, headers: Headers, conn: DbConn) -> JsonResult {
check_emergency_access_enabled()?; check_emergency_access_enabled()?;
match EmergencyAccess::find_by_uuid_and_grantor_uuid(&emer_id, &headers.user.uuid, &mut conn).await { match EmergencyAccess::find_by_uuid_and_grantor_uuid(&emer_id, &headers.user.uuid, &conn).await {
Some(emergency_access) => Ok(Json( Some(emergency_access) => Ok(Json(
emergency_access.to_json_grantee_details(&mut conn).await.expect("Grantee user should exist but does not!"), emergency_access.to_json_grantee_details(&conn).await.expect("Grantee user should exist but does not!"),
)), )),
None => err!("Emergency access not valid."), None => err!("Emergency access not valid."),
} }
@ -131,14 +137,14 @@ async fn post_emergency_access(
emer_id: EmergencyAccessId, emer_id: EmergencyAccessId,
data: Json<EmergencyAccessUpdateData>, data: Json<EmergencyAccessUpdateData>,
headers: Headers, headers: Headers,
mut conn: DbConn, conn: DbConn,
) -> JsonResult { ) -> JsonResult {
check_emergency_access_enabled()?; check_emergency_access_enabled()?;
let data: EmergencyAccessUpdateData = data.into_inner(); let data: EmergencyAccessUpdateData = data.into_inner();
let Some(mut emergency_access) = let Some(mut emergency_access) =
EmergencyAccess::find_by_uuid_and_grantor_uuid(&emer_id, &headers.user.uuid, &mut conn).await EmergencyAccess::find_by_uuid_and_grantor_uuid(&emer_id, &headers.user.uuid, &conn).await
else { else {
err!("Emergency access not valid.") err!("Emergency access not valid.")
}; };
@ -154,7 +160,7 @@ async fn post_emergency_access(
emergency_access.key_encrypted = data.key_encrypted; emergency_access.key_encrypted = data.key_encrypted;
} }
emergency_access.save(&mut conn).await?; emergency_access.save(&conn).await?;
Ok(Json(emergency_access.to_json())) Ok(Json(emergency_access.to_json()))
} }
@ -163,12 +169,12 @@ async fn post_emergency_access(
// region delete // region delete
#[delete("/emergency-access/<emer_id>")] #[delete("/emergency-access/<emer_id>")]
async fn delete_emergency_access(emer_id: EmergencyAccessId, headers: Headers, mut conn: DbConn) -> EmptyResult { async fn delete_emergency_access(emer_id: EmergencyAccessId, headers: Headers, conn: DbConn) -> EmptyResult {
check_emergency_access_enabled()?; check_emergency_access_enabled()?;
let emergency_access = match ( let emergency_access = match (
EmergencyAccess::find_by_uuid_and_grantor_uuid(&emer_id, &headers.user.uuid, &mut conn).await, EmergencyAccess::find_by_uuid_and_grantor_uuid(&emer_id, &headers.user.uuid, &conn).await,
EmergencyAccess::find_by_uuid_and_grantee_uuid(&emer_id, &headers.user.uuid, &mut conn).await, EmergencyAccess::find_by_uuid_and_grantee_uuid(&emer_id, &headers.user.uuid, &conn).await,
) { ) {
(Some(grantor_emer), None) => { (Some(grantor_emer), None) => {
info!("Grantor deleted emergency access {emer_id}"); info!("Grantor deleted emergency access {emer_id}");
@ -181,7 +187,7 @@ async fn delete_emergency_access(emer_id: EmergencyAccessId, headers: Headers, m
_ => err!("Emergency access not valid."), _ => err!("Emergency access not valid."),
}; };
emergency_access.delete(&mut conn).await?; emergency_access.delete(&conn).await?;
Ok(()) Ok(())
} }
@ -203,7 +209,7 @@ struct EmergencyAccessInviteData {
} }
#[post("/emergency-access/invite", data = "<data>")] #[post("/emergency-access/invite", data = "<data>")]
async fn send_invite(data: Json<EmergencyAccessInviteData>, headers: Headers, mut conn: DbConn) -> EmptyResult { async fn send_invite(data: Json<EmergencyAccessInviteData>, headers: Headers, conn: DbConn) -> EmptyResult {
check_emergency_access_enabled()?; check_emergency_access_enabled()?;
let data: EmergencyAccessInviteData = data.into_inner(); let data: EmergencyAccessInviteData = data.into_inner();
@ -224,7 +230,7 @@ async fn send_invite(data: Json<EmergencyAccessInviteData>, headers: Headers, mu
err!("You can not set yourself as an emergency contact.") err!("You can not set yourself as an emergency contact.")
} }
let (grantee_user, new_user) = match User::find_by_mail(&email, &mut conn).await { let (grantee_user, new_user) = match User::find_by_mail(&email, &conn).await {
None => { None => {
if !CONFIG.invitations_allowed() { if !CONFIG.invitations_allowed() {
err!(format!("Grantee user does not exist: {email}")) err!(format!("Grantee user does not exist: {email}"))
@ -236,11 +242,11 @@ async fn send_invite(data: Json<EmergencyAccessInviteData>, headers: Headers, mu
if !CONFIG.mail_enabled() { if !CONFIG.mail_enabled() {
let invitation = Invitation::new(&email); let invitation = Invitation::new(&email);
invitation.save(&mut conn).await?; invitation.save(&conn).await?;
} }
let mut user = User::new(email.clone(), None); let mut user = User::new(email.clone(), None);
user.save(&mut conn).await?; user.save(&conn).await?;
(user, true) (user, true)
} }
Some(user) if user.password_hash.is_empty() => (user, true), Some(user) if user.password_hash.is_empty() => (user, true),
@ -251,7 +257,7 @@ async fn send_invite(data: Json<EmergencyAccessInviteData>, headers: Headers, mu
&grantor_user.uuid, &grantor_user.uuid,
&grantee_user.uuid, &grantee_user.uuid,
&grantee_user.email, &grantee_user.email,
&mut conn, &conn,
) )
.await .await
.is_some() .is_some()
@ -261,7 +267,7 @@ async fn send_invite(data: Json<EmergencyAccessInviteData>, headers: Headers, mu
let mut new_emergency_access = let mut new_emergency_access =
EmergencyAccess::new(grantor_user.uuid, grantee_user.email, emergency_access_status, new_type, wait_time_days); EmergencyAccess::new(grantor_user.uuid, grantee_user.email, emergency_access_status, new_type, wait_time_days);
new_emergency_access.save(&mut conn).await?; new_emergency_access.save(&conn).await?;
if CONFIG.mail_enabled() { if CONFIG.mail_enabled() {
mail::send_emergency_access_invite( mail::send_emergency_access_invite(
@ -274,18 +280,18 @@ async fn send_invite(data: Json<EmergencyAccessInviteData>, headers: Headers, mu
.await?; .await?;
} else if !new_user { } else if !new_user {
// if mail is not enabled immediately accept the invitation for existing users // if mail is not enabled immediately accept the invitation for existing users
new_emergency_access.accept_invite(&grantee_user.uuid, &email, &mut conn).await?; new_emergency_access.accept_invite(&grantee_user.uuid, &email, &conn).await?;
} }
Ok(()) Ok(())
} }
#[post("/emergency-access/<emer_id>/reinvite")] #[post("/emergency-access/<emer_id>/reinvite")]
async fn resend_invite(emer_id: EmergencyAccessId, headers: Headers, mut conn: DbConn) -> EmptyResult { async fn resend_invite(emer_id: EmergencyAccessId, headers: Headers, conn: DbConn) -> EmptyResult {
check_emergency_access_enabled()?; check_emergency_access_enabled()?;
let Some(mut emergency_access) = let Some(mut emergency_access) =
EmergencyAccess::find_by_uuid_and_grantor_uuid(&emer_id, &headers.user.uuid, &mut conn).await EmergencyAccess::find_by_uuid_and_grantor_uuid(&emer_id, &headers.user.uuid, &conn).await
else { else {
err!("Emergency access not valid.") err!("Emergency access not valid.")
}; };
@ -298,7 +304,7 @@ async fn resend_invite(emer_id: EmergencyAccessId, headers: Headers, mut conn: D
err!("Email not valid.") err!("Email not valid.")
}; };
let Some(grantee_user) = User::find_by_mail(&email, &mut conn).await else { let Some(grantee_user) = User::find_by_mail(&email, &conn).await else {
err!("Grantee user not found.") err!("Grantee user not found.")
}; };
@ -315,10 +321,10 @@ async fn resend_invite(emer_id: EmergencyAccessId, headers: Headers, mut conn: D
.await?; .await?;
} else if !grantee_user.password_hash.is_empty() { } else if !grantee_user.password_hash.is_empty() {
// accept the invitation for existing user // accept the invitation for existing user
emergency_access.accept_invite(&grantee_user.uuid, &email, &mut conn).await?; emergency_access.accept_invite(&grantee_user.uuid, &email, &conn).await?;
} else if CONFIG.invitations_allowed() && Invitation::find_by_mail(&email, &mut conn).await.is_none() { } else if CONFIG.invitations_allowed() && Invitation::find_by_mail(&email, &conn).await.is_none() {
let invitation = Invitation::new(&email); let invitation = Invitation::new(&email);
invitation.save(&mut conn).await?; invitation.save(&conn).await?;
} }
Ok(()) Ok(())
@ -335,7 +341,7 @@ async fn accept_invite(
emer_id: EmergencyAccessId, emer_id: EmergencyAccessId,
data: Json<AcceptData>, data: Json<AcceptData>,
headers: Headers, headers: Headers,
mut conn: DbConn, conn: DbConn,
) -> EmptyResult { ) -> EmptyResult {
check_emergency_access_enabled()?; check_emergency_access_enabled()?;
@ -349,9 +355,9 @@ async fn accept_invite(
err!("Claim email does not match current users email") err!("Claim email does not match current users email")
} }
let grantee_user = match User::find_by_mail(&claims.email, &mut conn).await { let grantee_user = match User::find_by_mail(&claims.email, &conn).await {
Some(user) => { Some(user) => {
Invitation::take(&claims.email, &mut conn).await; Invitation::take(&claims.email, &conn).await;
user user
} }
None => err!("Invited user not found"), None => err!("Invited user not found"),
@ -360,13 +366,13 @@ async fn accept_invite(
// We need to search for the uuid in combination with the email, since we do not yet store the uuid of the grantee in the database. // We need to search for the uuid in combination with the email, since we do not yet store the uuid of the grantee in the database.
// The uuid of the grantee gets stored once accepted. // The uuid of the grantee gets stored once accepted.
let Some(mut emergency_access) = let Some(mut emergency_access) =
EmergencyAccess::find_by_uuid_and_grantee_email(&emer_id, &headers.user.email, &mut conn).await EmergencyAccess::find_by_uuid_and_grantee_email(&emer_id, &headers.user.email, &conn).await
else { else {
err!("Emergency access not valid.") err!("Emergency access not valid.")
}; };
// get grantor user to send Accepted email // get grantor user to send Accepted email
let Some(grantor_user) = User::find_by_uuid(&emergency_access.grantor_uuid, &mut conn).await else { let Some(grantor_user) = User::find_by_uuid(&emergency_access.grantor_uuid, &conn).await else {
err!("Grantor user not found.") err!("Grantor user not found.")
}; };
@ -374,7 +380,7 @@ async fn accept_invite(
&& grantor_user.name == claims.grantor_name && grantor_user.name == claims.grantor_name
&& grantor_user.email == claims.grantor_email && grantor_user.email == claims.grantor_email
{ {
emergency_access.accept_invite(&grantee_user.uuid, &grantee_user.email, &mut conn).await?; emergency_access.accept_invite(&grantee_user.uuid, &grantee_user.email, &conn).await?;
if CONFIG.mail_enabled() { if CONFIG.mail_enabled() {
mail::send_emergency_access_invite_accepted(&grantor_user.email, &grantee_user.email).await?; mail::send_emergency_access_invite_accepted(&grantor_user.email, &grantee_user.email).await?;
@ -397,7 +403,7 @@ async fn confirm_emergency_access(
emer_id: EmergencyAccessId, emer_id: EmergencyAccessId,
data: Json<ConfirmData>, data: Json<ConfirmData>,
headers: Headers, headers: Headers,
mut conn: DbConn, conn: DbConn,
) -> JsonResult { ) -> JsonResult {
check_emergency_access_enabled()?; check_emergency_access_enabled()?;
@ -406,7 +412,7 @@ async fn confirm_emergency_access(
let key = data.key; let key = data.key;
let Some(mut emergency_access) = let Some(mut emergency_access) =
EmergencyAccess::find_by_uuid_and_grantor_uuid(&emer_id, &confirming_user.uuid, &mut conn).await EmergencyAccess::find_by_uuid_and_grantor_uuid(&emer_id, &confirming_user.uuid, &conn).await
else { else {
err!("Emergency access not valid.") err!("Emergency access not valid.")
}; };
@ -417,12 +423,12 @@ async fn confirm_emergency_access(
err!("Emergency access not valid.") err!("Emergency access not valid.")
} }
let Some(grantor_user) = User::find_by_uuid(&confirming_user.uuid, &mut conn).await else { let Some(grantor_user) = User::find_by_uuid(&confirming_user.uuid, &conn).await else {
err!("Grantor user not found.") err!("Grantor user not found.")
}; };
if let Some(grantee_uuid) = emergency_access.grantee_uuid.as_ref() { if let Some(grantee_uuid) = emergency_access.grantee_uuid.as_ref() {
let Some(grantee_user) = User::find_by_uuid(grantee_uuid, &mut conn).await else { let Some(grantee_user) = User::find_by_uuid(grantee_uuid, &conn).await else {
err!("Grantee user not found.") err!("Grantee user not found.")
}; };
@ -430,7 +436,7 @@ async fn confirm_emergency_access(
emergency_access.key_encrypted = Some(key); emergency_access.key_encrypted = Some(key);
emergency_access.email = None; emergency_access.email = None;
emergency_access.save(&mut conn).await?; emergency_access.save(&conn).await?;
if CONFIG.mail_enabled() { if CONFIG.mail_enabled() {
mail::send_emergency_access_invite_confirmed(&grantee_user.email, &grantor_user.name).await?; mail::send_emergency_access_invite_confirmed(&grantee_user.email, &grantor_user.name).await?;
@ -446,12 +452,12 @@ async fn confirm_emergency_access(
// region access emergency access // region access emergency access
#[post("/emergency-access/<emer_id>/initiate")] #[post("/emergency-access/<emer_id>/initiate")]
async fn initiate_emergency_access(emer_id: EmergencyAccessId, headers: Headers, mut conn: DbConn) -> JsonResult { async fn initiate_emergency_access(emer_id: EmergencyAccessId, headers: Headers, conn: DbConn) -> JsonResult {
check_emergency_access_enabled()?; check_emergency_access_enabled()?;
let initiating_user = headers.user; let initiating_user = headers.user;
let Some(mut emergency_access) = let Some(mut emergency_access) =
EmergencyAccess::find_by_uuid_and_grantee_uuid(&emer_id, &initiating_user.uuid, &mut conn).await EmergencyAccess::find_by_uuid_and_grantee_uuid(&emer_id, &initiating_user.uuid, &conn).await
else { else {
err!("Emergency access not valid.") err!("Emergency access not valid.")
}; };
@ -460,7 +466,7 @@ async fn initiate_emergency_access(emer_id: EmergencyAccessId, headers: Headers,
err!("Emergency access not valid.") err!("Emergency access not valid.")
} }
let Some(grantor_user) = User::find_by_uuid(&emergency_access.grantor_uuid, &mut conn).await else { let Some(grantor_user) = User::find_by_uuid(&emergency_access.grantor_uuid, &conn).await else {
err!("Grantor user not found.") err!("Grantor user not found.")
}; };
@ -469,7 +475,7 @@ async fn initiate_emergency_access(emer_id: EmergencyAccessId, headers: Headers,
emergency_access.updated_at = now; emergency_access.updated_at = now;
emergency_access.recovery_initiated_at = Some(now); emergency_access.recovery_initiated_at = Some(now);
emergency_access.last_notification_at = Some(now); emergency_access.last_notification_at = Some(now);
emergency_access.save(&mut conn).await?; emergency_access.save(&conn).await?;
if CONFIG.mail_enabled() { if CONFIG.mail_enabled() {
mail::send_emergency_access_recovery_initiated( mail::send_emergency_access_recovery_initiated(
@ -484,11 +490,11 @@ async fn initiate_emergency_access(emer_id: EmergencyAccessId, headers: Headers,
} }
#[post("/emergency-access/<emer_id>/approve")] #[post("/emergency-access/<emer_id>/approve")]
async fn approve_emergency_access(emer_id: EmergencyAccessId, headers: Headers, mut conn: DbConn) -> JsonResult { async fn approve_emergency_access(emer_id: EmergencyAccessId, headers: Headers, conn: DbConn) -> JsonResult {
check_emergency_access_enabled()?; check_emergency_access_enabled()?;
let Some(mut emergency_access) = let Some(mut emergency_access) =
EmergencyAccess::find_by_uuid_and_grantor_uuid(&emer_id, &headers.user.uuid, &mut conn).await EmergencyAccess::find_by_uuid_and_grantor_uuid(&emer_id, &headers.user.uuid, &conn).await
else { else {
err!("Emergency access not valid.") err!("Emergency access not valid.")
}; };
@ -497,17 +503,17 @@ async fn approve_emergency_access(emer_id: EmergencyAccessId, headers: Headers,
err!("Emergency access not valid.") err!("Emergency access not valid.")
} }
let Some(grantor_user) = User::find_by_uuid(&headers.user.uuid, &mut conn).await else { let Some(grantor_user) = User::find_by_uuid(&headers.user.uuid, &conn).await else {
err!("Grantor user not found.") err!("Grantor user not found.")
}; };
if let Some(grantee_uuid) = emergency_access.grantee_uuid.as_ref() { if let Some(grantee_uuid) = emergency_access.grantee_uuid.as_ref() {
let Some(grantee_user) = User::find_by_uuid(grantee_uuid, &mut conn).await else { let Some(grantee_user) = User::find_by_uuid(grantee_uuid, &conn).await else {
err!("Grantee user not found.") err!("Grantee user not found.")
}; };
emergency_access.status = EmergencyAccessStatus::RecoveryApproved as i32; emergency_access.status = EmergencyAccessStatus::RecoveryApproved as i32;
emergency_access.save(&mut conn).await?; emergency_access.save(&conn).await?;
if CONFIG.mail_enabled() { if CONFIG.mail_enabled() {
mail::send_emergency_access_recovery_approved(&grantee_user.email, &grantor_user.name).await?; mail::send_emergency_access_recovery_approved(&grantee_user.email, &grantor_user.name).await?;
@ -519,11 +525,11 @@ async fn approve_emergency_access(emer_id: EmergencyAccessId, headers: Headers,
} }
#[post("/emergency-access/<emer_id>/reject")] #[post("/emergency-access/<emer_id>/reject")]
async fn reject_emergency_access(emer_id: EmergencyAccessId, headers: Headers, mut conn: DbConn) -> JsonResult { async fn reject_emergency_access(emer_id: EmergencyAccessId, headers: Headers, conn: DbConn) -> JsonResult {
check_emergency_access_enabled()?; check_emergency_access_enabled()?;
let Some(mut emergency_access) = let Some(mut emergency_access) =
EmergencyAccess::find_by_uuid_and_grantor_uuid(&emer_id, &headers.user.uuid, &mut conn).await EmergencyAccess::find_by_uuid_and_grantor_uuid(&emer_id, &headers.user.uuid, &conn).await
else { else {
err!("Emergency access not valid.") err!("Emergency access not valid.")
}; };
@ -535,12 +541,12 @@ async fn reject_emergency_access(emer_id: EmergencyAccessId, headers: Headers, m
} }
if let Some(grantee_uuid) = emergency_access.grantee_uuid.as_ref() { if let Some(grantee_uuid) = emergency_access.grantee_uuid.as_ref() {
let Some(grantee_user) = User::find_by_uuid(grantee_uuid, &mut conn).await else { let Some(grantee_user) = User::find_by_uuid(grantee_uuid, &conn).await else {
err!("Grantee user not found.") err!("Grantee user not found.")
}; };
emergency_access.status = EmergencyAccessStatus::Confirmed as i32; emergency_access.status = EmergencyAccessStatus::Confirmed as i32;
emergency_access.save(&mut conn).await?; emergency_access.save(&conn).await?;
if CONFIG.mail_enabled() { if CONFIG.mail_enabled() {
mail::send_emergency_access_recovery_rejected(&grantee_user.email, &headers.user.name).await?; mail::send_emergency_access_recovery_rejected(&grantee_user.email, &headers.user.name).await?;
@ -556,11 +562,11 @@ async fn reject_emergency_access(emer_id: EmergencyAccessId, headers: Headers, m
// region action // region action
#[post("/emergency-access/<emer_id>/view")] #[post("/emergency-access/<emer_id>/view")]
async fn view_emergency_access(emer_id: EmergencyAccessId, headers: Headers, mut conn: DbConn) -> JsonResult { async fn view_emergency_access(emer_id: EmergencyAccessId, headers: Headers, conn: DbConn) -> JsonResult {
check_emergency_access_enabled()?; check_emergency_access_enabled()?;
let Some(emergency_access) = let Some(emergency_access) =
EmergencyAccess::find_by_uuid_and_grantee_uuid(&emer_id, &headers.user.uuid, &mut conn).await EmergencyAccess::find_by_uuid_and_grantee_uuid(&emer_id, &headers.user.uuid, &conn).await
else { else {
err!("Emergency access not valid.") err!("Emergency access not valid.")
}; };
@ -569,8 +575,8 @@ async fn view_emergency_access(emer_id: EmergencyAccessId, headers: Headers, mut
err!("Emergency access not valid.") err!("Emergency access not valid.")
} }
let ciphers = Cipher::find_owned_by_user(&emergency_access.grantor_uuid, &mut conn).await; let ciphers = Cipher::find_owned_by_user(&emergency_access.grantor_uuid, &conn).await;
let cipher_sync_data = CipherSyncData::new(&emergency_access.grantor_uuid, CipherSyncType::User, &mut conn).await; let cipher_sync_data = CipherSyncData::new(&emergency_access.grantor_uuid, CipherSyncType::User, &conn).await;
let mut ciphers_json = Vec::with_capacity(ciphers.len()); let mut ciphers_json = Vec::with_capacity(ciphers.len());
for c in ciphers { for c in ciphers {
@ -580,7 +586,7 @@ async fn view_emergency_access(emer_id: EmergencyAccessId, headers: Headers, mut
&emergency_access.grantor_uuid, &emergency_access.grantor_uuid,
Some(&cipher_sync_data), Some(&cipher_sync_data),
CipherSyncType::User, CipherSyncType::User,
&mut conn, &conn,
) )
.await?, .await?,
); );
@ -594,12 +600,12 @@ async fn view_emergency_access(emer_id: EmergencyAccessId, headers: Headers, mut
} }
#[post("/emergency-access/<emer_id>/takeover")] #[post("/emergency-access/<emer_id>/takeover")]
async fn takeover_emergency_access(emer_id: EmergencyAccessId, headers: Headers, mut conn: DbConn) -> JsonResult { async fn takeover_emergency_access(emer_id: EmergencyAccessId, headers: Headers, conn: DbConn) -> JsonResult {
check_emergency_access_enabled()?; check_emergency_access_enabled()?;
let requesting_user = headers.user; let requesting_user = headers.user;
let Some(emergency_access) = let Some(emergency_access) =
EmergencyAccess::find_by_uuid_and_grantee_uuid(&emer_id, &requesting_user.uuid, &mut conn).await EmergencyAccess::find_by_uuid_and_grantee_uuid(&emer_id, &requesting_user.uuid, &conn).await
else { else {
err!("Emergency access not valid.") err!("Emergency access not valid.")
}; };
@ -608,7 +614,7 @@ async fn takeover_emergency_access(emer_id: EmergencyAccessId, headers: Headers,
err!("Emergency access not valid.") err!("Emergency access not valid.")
} }
let Some(grantor_user) = User::find_by_uuid(&emergency_access.grantor_uuid, &mut conn).await else { let Some(grantor_user) = User::find_by_uuid(&emergency_access.grantor_uuid, &conn).await else {
err!("Grantor user not found.") err!("Grantor user not found.")
}; };
@ -636,7 +642,7 @@ async fn password_emergency_access(
emer_id: EmergencyAccessId, emer_id: EmergencyAccessId,
data: Json<EmergencyAccessPasswordData>, data: Json<EmergencyAccessPasswordData>,
headers: Headers, headers: Headers,
mut conn: DbConn, conn: DbConn,
) -> EmptyResult { ) -> EmptyResult {
check_emergency_access_enabled()?; check_emergency_access_enabled()?;
@ -646,7 +652,7 @@ async fn password_emergency_access(
let requesting_user = headers.user; let requesting_user = headers.user;
let Some(emergency_access) = let Some(emergency_access) =
EmergencyAccess::find_by_uuid_and_grantee_uuid(&emer_id, &requesting_user.uuid, &mut conn).await EmergencyAccess::find_by_uuid_and_grantee_uuid(&emer_id, &requesting_user.uuid, &conn).await
else { else {
err!("Emergency access not valid.") err!("Emergency access not valid.")
}; };
@ -655,21 +661,21 @@ async fn password_emergency_access(
err!("Emergency access not valid.") err!("Emergency access not valid.")
} }
let Some(mut grantor_user) = User::find_by_uuid(&emergency_access.grantor_uuid, &mut conn).await else { let Some(mut grantor_user) = User::find_by_uuid(&emergency_access.grantor_uuid, &conn).await else {
err!("Grantor user not found.") err!("Grantor user not found.")
}; };
// change grantor_user password // change grantor_user password
grantor_user.set_password(new_master_password_hash, Some(data.key), true, None); grantor_user.set_password(new_master_password_hash, Some(data.key), true, None);
grantor_user.save(&mut conn).await?; grantor_user.save(&conn).await?;
// Disable TwoFactor providers since they will otherwise block logins // Disable TwoFactor providers since they will otherwise block logins
TwoFactor::delete_all_by_user(&grantor_user.uuid, &mut conn).await?; TwoFactor::delete_all_by_user(&grantor_user.uuid, &conn).await?;
// Remove grantor from all organisations unless Owner // Remove grantor from all organisations unless Owner
for member in Membership::find_any_state_by_user(&grantor_user.uuid, &mut conn).await { for member in Membership::find_any_state_by_user(&grantor_user.uuid, &conn).await {
if member.atype != MembershipType::Owner as i32 { if member.atype != MembershipType::Owner as i32 {
member.delete(&mut conn).await?; member.delete(&conn).await?;
} }
} }
Ok(()) Ok(())
@ -678,10 +684,10 @@ async fn password_emergency_access(
// endregion // endregion
#[get("/emergency-access/<emer_id>/policies")] #[get("/emergency-access/<emer_id>/policies")]
async fn policies_emergency_access(emer_id: EmergencyAccessId, headers: Headers, mut conn: DbConn) -> JsonResult { async fn policies_emergency_access(emer_id: EmergencyAccessId, headers: Headers, conn: DbConn) -> JsonResult {
let requesting_user = headers.user; let requesting_user = headers.user;
let Some(emergency_access) = let Some(emergency_access) =
EmergencyAccess::find_by_uuid_and_grantee_uuid(&emer_id, &requesting_user.uuid, &mut conn).await EmergencyAccess::find_by_uuid_and_grantee_uuid(&emer_id, &requesting_user.uuid, &conn).await
else { else {
err!("Emergency access not valid.") err!("Emergency access not valid.")
}; };
@ -690,11 +696,11 @@ async fn policies_emergency_access(emer_id: EmergencyAccessId, headers: Headers,
err!("Emergency access not valid.") err!("Emergency access not valid.")
} }
let Some(grantor_user) = User::find_by_uuid(&emergency_access.grantor_uuid, &mut conn).await else { let Some(grantor_user) = User::find_by_uuid(&emergency_access.grantor_uuid, &conn).await else {
err!("Grantor user not found.") err!("Grantor user not found.")
}; };
let policies = OrgPolicy::find_confirmed_by_user(&grantor_user.uuid, &mut conn); let policies = OrgPolicy::find_confirmed_by_user(&grantor_user.uuid, &conn);
let policies_json: Vec<Value> = policies.await.iter().map(OrgPolicy::to_json).collect(); let policies_json: Vec<Value> = policies.await.iter().map(OrgPolicy::to_json).collect();
Ok(Json(json!({ Ok(Json(json!({
@ -728,8 +734,8 @@ pub async fn emergency_request_timeout_job(pool: DbPool) {
return; return;
} }
if let Ok(mut conn) = pool.get().await { if let Ok(conn) = pool.get().await {
let emergency_access_list = EmergencyAccess::find_all_recoveries_initiated(&mut conn).await; let emergency_access_list = EmergencyAccess::find_all_recoveries_initiated(&conn).await;
if emergency_access_list.is_empty() { if emergency_access_list.is_empty() {
debug!("No emergency request timeout to approve"); debug!("No emergency request timeout to approve");
@ -743,18 +749,18 @@ pub async fn emergency_request_timeout_job(pool: DbPool) {
if recovery_allowed_at.le(&now) { if recovery_allowed_at.le(&now) {
// Only update the access status // Only update the access status
// Updating the whole record could cause issues when the emergency_notification_reminder_job is also active // Updating the whole record could cause issues when the emergency_notification_reminder_job is also active
emer.update_access_status_and_save(EmergencyAccessStatus::RecoveryApproved as i32, &now, &mut conn) emer.update_access_status_and_save(EmergencyAccessStatus::RecoveryApproved as i32, &now, &conn)
.await .await
.expect("Unable to update emergency access status"); .expect("Unable to update emergency access status");
if CONFIG.mail_enabled() { if CONFIG.mail_enabled() {
// get grantor user to send Accepted email // get grantor user to send Accepted email
let grantor_user = let grantor_user =
User::find_by_uuid(&emer.grantor_uuid, &mut conn).await.expect("Grantor user not found"); User::find_by_uuid(&emer.grantor_uuid, &conn).await.expect("Grantor user not found");
// get grantee user to send Accepted email // get grantee user to send Accepted email
let grantee_user = let grantee_user =
User::find_by_uuid(&emer.grantee_uuid.clone().expect("Grantee user invalid"), &mut conn) User::find_by_uuid(&emer.grantee_uuid.clone().expect("Grantee user invalid"), &conn)
.await .await
.expect("Grantee user not found"); .expect("Grantee user not found");
@ -783,8 +789,8 @@ pub async fn emergency_notification_reminder_job(pool: DbPool) {
return; return;
} }
if let Ok(mut conn) = pool.get().await { if let Ok(conn) = pool.get().await {
let emergency_access_list = EmergencyAccess::find_all_recoveries_initiated(&mut conn).await; let emergency_access_list = EmergencyAccess::find_all_recoveries_initiated(&conn).await;
if emergency_access_list.is_empty() { if emergency_access_list.is_empty() {
debug!("No emergency request reminder notification to send"); debug!("No emergency request reminder notification to send");
@ -805,18 +811,18 @@ pub async fn emergency_notification_reminder_job(pool: DbPool) {
if final_recovery_reminder_at.le(&now) && next_recovery_reminder_at.le(&now) { if final_recovery_reminder_at.le(&now) && next_recovery_reminder_at.le(&now) {
// Only update the last notification date // Only update the last notification date
// Updating the whole record could cause issues when the emergency_request_timeout_job is also active // Updating the whole record could cause issues when the emergency_request_timeout_job is also active
emer.update_last_notification_date_and_save(&now, &mut conn) emer.update_last_notification_date_and_save(&now, &conn)
.await .await
.expect("Unable to update emergency access notification date"); .expect("Unable to update emergency access notification date");
if CONFIG.mail_enabled() { if CONFIG.mail_enabled() {
// get grantor user to send Accepted email // get grantor user to send Accepted email
let grantor_user = let grantor_user =
User::find_by_uuid(&emer.grantor_uuid, &mut conn).await.expect("Grantor user not found"); User::find_by_uuid(&emer.grantor_uuid, &conn).await.expect("Grantor user not found");
// get grantee user to send Accepted email // get grantee user to send Accepted email
let grantee_user = let grantee_user =
User::find_by_uuid(&emer.grantee_uuid.clone().expect("Grantee user invalid"), &mut conn) User::find_by_uuid(&emer.grantee_uuid.clone().expect("Grantee user invalid"), &conn)
.await .await
.expect("Grantee user not found"); .expect("Grantee user not found");

41
src/api/core/events.rs

@ -31,12 +31,7 @@ struct EventRange {
// Upstream: https://github.com/bitwarden/server/blob/9ebe16587175b1c0e9208f84397bb75d0d595510/src/Api/AdminConsole/Controllers/EventsController.cs#L87 // Upstream: https://github.com/bitwarden/server/blob/9ebe16587175b1c0e9208f84397bb75d0d595510/src/Api/AdminConsole/Controllers/EventsController.cs#L87
#[get("/organizations/<org_id>/events?<data..>")] #[get("/organizations/<org_id>/events?<data..>")]
async fn get_org_events( async fn get_org_events(org_id: OrganizationId, data: EventRange, headers: AdminHeaders, conn: DbConn) -> JsonResult {
org_id: OrganizationId,
data: EventRange,
headers: AdminHeaders,
mut conn: DbConn,
) -> JsonResult {
if org_id != headers.org_id { if org_id != headers.org_id {
err!("Organization not found", "Organization id's do not match"); err!("Organization not found", "Organization id's do not match");
} }
@ -53,7 +48,7 @@ async fn get_org_events(
parse_date(&data.end) parse_date(&data.end)
}; };
Event::find_by_organization_uuid(&org_id, &start_date, &end_date, &mut conn) Event::find_by_organization_uuid(&org_id, &start_date, &end_date, &conn)
.await .await
.iter() .iter()
.map(|e| e.to_json()) .map(|e| e.to_json())
@ -68,14 +63,14 @@ async fn get_org_events(
} }
#[get("/ciphers/<cipher_id>/events?<data..>")] #[get("/ciphers/<cipher_id>/events?<data..>")]
async fn get_cipher_events(cipher_id: CipherId, data: EventRange, headers: Headers, mut conn: DbConn) -> JsonResult { async fn get_cipher_events(cipher_id: CipherId, data: EventRange, headers: Headers, conn: DbConn) -> JsonResult {
// Return an empty vec when we org events are disabled. // Return an empty vec when we org events are disabled.
// This prevents client errors // This prevents client errors
let events_json: Vec<Value> = if !CONFIG.org_events_enabled() { let events_json: Vec<Value> = if !CONFIG.org_events_enabled() {
Vec::with_capacity(0) Vec::with_capacity(0)
} else { } else {
let mut events_json = Vec::with_capacity(0); let mut events_json = Vec::with_capacity(0);
if Membership::user_has_ge_admin_access_to_cipher(&headers.user.uuid, &cipher_id, &mut conn).await { if Membership::user_has_ge_admin_access_to_cipher(&headers.user.uuid, &cipher_id, &conn).await {
let start_date = parse_date(&data.start); let start_date = parse_date(&data.start);
let end_date = if let Some(before_date) = &data.continuation_token { let end_date = if let Some(before_date) = &data.continuation_token {
parse_date(before_date) parse_date(before_date)
@ -83,7 +78,7 @@ async fn get_cipher_events(cipher_id: CipherId, data: EventRange, headers: Heade
parse_date(&data.end) parse_date(&data.end)
}; };
events_json = Event::find_by_cipher_uuid(&cipher_id, &start_date, &end_date, &mut conn) events_json = Event::find_by_cipher_uuid(&cipher_id, &start_date, &end_date, &conn)
.await .await
.iter() .iter()
.map(|e| e.to_json()) .map(|e| e.to_json())
@ -105,7 +100,7 @@ async fn get_user_events(
member_id: MembershipId, member_id: MembershipId,
data: EventRange, data: EventRange,
headers: AdminHeaders, headers: AdminHeaders,
mut conn: DbConn, conn: DbConn,
) -> JsonResult { ) -> JsonResult {
if org_id != headers.org_id { if org_id != headers.org_id {
err!("Organization not found", "Organization id's do not match"); err!("Organization not found", "Organization id's do not match");
@ -122,7 +117,7 @@ async fn get_user_events(
parse_date(&data.end) parse_date(&data.end)
}; };
Event::find_by_org_and_member(&org_id, &member_id, &start_date, &end_date, &mut conn) Event::find_by_org_and_member(&org_id, &member_id, &start_date, &end_date, &conn)
.await .await
.iter() .iter()
.map(|e| e.to_json()) .map(|e| e.to_json())
@ -172,7 +167,7 @@ struct EventCollection {
// https://github.com/bitwarden/server/blob/9ebe16587175b1c0e9208f84397bb75d0d595510/src/Events/Controllers/CollectController.cs // https://github.com/bitwarden/server/blob/9ebe16587175b1c0e9208f84397bb75d0d595510/src/Events/Controllers/CollectController.cs
// https://github.com/bitwarden/server/blob/9ebe16587175b1c0e9208f84397bb75d0d595510/src/Core/AdminConsole/Services/Implementations/EventService.cs // https://github.com/bitwarden/server/blob/9ebe16587175b1c0e9208f84397bb75d0d595510/src/Core/AdminConsole/Services/Implementations/EventService.cs
#[post("/collect", format = "application/json", data = "<data>")] #[post("/collect", format = "application/json", data = "<data>")]
async fn post_events_collect(data: Json<Vec<EventCollection>>, headers: Headers, mut conn: DbConn) -> EmptyResult { async fn post_events_collect(data: Json<Vec<EventCollection>>, headers: Headers, conn: DbConn) -> EmptyResult {
if !CONFIG.org_events_enabled() { if !CONFIG.org_events_enabled() {
return Ok(()); return Ok(());
} }
@ -187,7 +182,7 @@ async fn post_events_collect(data: Json<Vec<EventCollection>>, headers: Headers,
headers.device.atype, headers.device.atype,
Some(event_date), Some(event_date),
&headers.ip.ip, &headers.ip.ip,
&mut conn, &conn,
) )
.await; .await;
} }
@ -201,14 +196,14 @@ async fn post_events_collect(data: Json<Vec<EventCollection>>, headers: Headers,
headers.device.atype, headers.device.atype,
Some(event_date), Some(event_date),
&headers.ip.ip, &headers.ip.ip,
&mut conn, &conn,
) )
.await; .await;
} }
} }
_ => { _ => {
if let Some(cipher_uuid) = &event.cipher_id { if let Some(cipher_uuid) = &event.cipher_id {
if let Some(cipher) = Cipher::find_by_uuid(cipher_uuid, &mut conn).await { if let Some(cipher) = Cipher::find_by_uuid(cipher_uuid, &conn).await {
if let Some(org_id) = cipher.organization_uuid { if let Some(org_id) = cipher.organization_uuid {
_log_event( _log_event(
event.r#type, event.r#type,
@ -218,7 +213,7 @@ async fn post_events_collect(data: Json<Vec<EventCollection>>, headers: Headers,
headers.device.atype, headers.device.atype,
Some(event_date), Some(event_date),
&headers.ip.ip, &headers.ip.ip,
&mut conn, &conn,
) )
.await; .await;
} }
@ -230,7 +225,7 @@ async fn post_events_collect(data: Json<Vec<EventCollection>>, headers: Headers,
Ok(()) Ok(())
} }
pub async fn log_user_event(event_type: i32, user_id: &UserId, device_type: i32, ip: &IpAddr, conn: &mut DbConn) { pub async fn log_user_event(event_type: i32, user_id: &UserId, device_type: i32, ip: &IpAddr, conn: &DbConn) {
if !CONFIG.org_events_enabled() { if !CONFIG.org_events_enabled() {
return; return;
} }
@ -243,7 +238,7 @@ async fn _log_user_event(
device_type: i32, device_type: i32,
event_date: Option<NaiveDateTime>, event_date: Option<NaiveDateTime>,
ip: &IpAddr, ip: &IpAddr,
conn: &mut DbConn, conn: &DbConn,
) { ) {
let memberships = Membership::find_by_user(user_id, conn).await; let memberships = Membership::find_by_user(user_id, conn).await;
let mut events: Vec<Event> = Vec::with_capacity(memberships.len() + 1); // We need an event per org and one without an org let mut events: Vec<Event> = Vec::with_capacity(memberships.len() + 1); // We need an event per org and one without an org
@ -278,7 +273,7 @@ pub async fn log_event(
act_user_id: &UserId, act_user_id: &UserId,
device_type: i32, device_type: i32,
ip: &IpAddr, ip: &IpAddr,
conn: &mut DbConn, conn: &DbConn,
) { ) {
if !CONFIG.org_events_enabled() { if !CONFIG.org_events_enabled() {
return; return;
@ -295,7 +290,7 @@ async fn _log_event(
device_type: i32, device_type: i32,
event_date: Option<NaiveDateTime>, event_date: Option<NaiveDateTime>,
ip: &IpAddr, ip: &IpAddr,
conn: &mut DbConn, conn: &DbConn,
) { ) {
// Create a new empty event // Create a new empty event
let mut event = Event::new(event_type, event_date); let mut event = Event::new(event_type, event_date);
@ -340,8 +335,8 @@ pub async fn event_cleanup_job(pool: DbPool) {
return; return;
} }
if let Ok(mut conn) = pool.get().await { if let Ok(conn) = pool.get().await {
Event::clean_events(&mut conn).await.ok(); Event::clean_events(&conn).await.ok();
} else { } else {
error!("Failed to get DB connection while trying to cleanup the events table") error!("Failed to get DB connection while trying to cleanup the events table")
} }

35
src/api/core/folders.rs

@ -4,7 +4,10 @@ use serde_json::Value;
use crate::{ use crate::{
api::{EmptyResult, JsonResult, Notify, UpdateType}, api::{EmptyResult, JsonResult, Notify, UpdateType},
auth::Headers, auth::Headers,
db::{models::*, DbConn}, db::{
models::{Folder, FolderId},
DbConn,
},
}; };
pub fn routes() -> Vec<rocket::Route> { pub fn routes() -> Vec<rocket::Route> {
@ -12,8 +15,8 @@ pub fn routes() -> Vec<rocket::Route> {
} }
#[get("/folders")] #[get("/folders")]
async fn get_folders(headers: Headers, mut conn: DbConn) -> Json<Value> { async fn get_folders(headers: Headers, conn: DbConn) -> Json<Value> {
let folders = Folder::find_by_user(&headers.user.uuid, &mut conn).await; let folders = Folder::find_by_user(&headers.user.uuid, &conn).await;
let folders_json: Vec<Value> = folders.iter().map(Folder::to_json).collect(); let folders_json: Vec<Value> = folders.iter().map(Folder::to_json).collect();
Json(json!({ Json(json!({
@ -24,8 +27,8 @@ async fn get_folders(headers: Headers, mut conn: DbConn) -> Json<Value> {
} }
#[get("/folders/<folder_id>")] #[get("/folders/<folder_id>")]
async fn get_folder(folder_id: FolderId, headers: Headers, mut conn: DbConn) -> JsonResult { async fn get_folder(folder_id: FolderId, headers: Headers, conn: DbConn) -> JsonResult {
match Folder::find_by_uuid_and_user(&folder_id, &headers.user.uuid, &mut conn).await { match Folder::find_by_uuid_and_user(&folder_id, &headers.user.uuid, &conn).await {
Some(folder) => Ok(Json(folder.to_json())), Some(folder) => Ok(Json(folder.to_json())),
_ => err!("Invalid folder", "Folder does not exist or belongs to another user"), _ => err!("Invalid folder", "Folder does not exist or belongs to another user"),
} }
@ -39,13 +42,13 @@ pub struct FolderData {
} }
#[post("/folders", data = "<data>")] #[post("/folders", data = "<data>")]
async fn post_folders(data: Json<FolderData>, headers: Headers, mut conn: DbConn, nt: Notify<'_>) -> JsonResult { async fn post_folders(data: Json<FolderData>, headers: Headers, conn: DbConn, nt: Notify<'_>) -> JsonResult {
let data: FolderData = data.into_inner(); let data: FolderData = data.into_inner();
let mut folder = Folder::new(headers.user.uuid, data.name); let mut folder = Folder::new(headers.user.uuid, data.name);
folder.save(&mut conn).await?; folder.save(&conn).await?;
nt.send_folder_update(UpdateType::SyncFolderCreate, &folder, &headers.device, &mut conn).await; nt.send_folder_update(UpdateType::SyncFolderCreate, &folder, &headers.device, &conn).await;
Ok(Json(folder.to_json())) Ok(Json(folder.to_json()))
} }
@ -66,19 +69,19 @@ async fn put_folder(
folder_id: FolderId, folder_id: FolderId,
data: Json<FolderData>, data: Json<FolderData>,
headers: Headers, headers: Headers,
mut conn: DbConn, conn: DbConn,
nt: Notify<'_>, nt: Notify<'_>,
) -> JsonResult { ) -> JsonResult {
let data: FolderData = data.into_inner(); let data: FolderData = data.into_inner();
let Some(mut folder) = Folder::find_by_uuid_and_user(&folder_id, &headers.user.uuid, &mut conn).await else { let Some(mut folder) = Folder::find_by_uuid_and_user(&folder_id, &headers.user.uuid, &conn).await else {
err!("Invalid folder", "Folder does not exist or belongs to another user") err!("Invalid folder", "Folder does not exist or belongs to another user")
}; };
folder.name = data.name; folder.name = data.name;
folder.save(&mut conn).await?; folder.save(&conn).await?;
nt.send_folder_update(UpdateType::SyncFolderUpdate, &folder, &headers.device, &mut conn).await; nt.send_folder_update(UpdateType::SyncFolderUpdate, &folder, &headers.device, &conn).await;
Ok(Json(folder.to_json())) Ok(Json(folder.to_json()))
} }
@ -89,14 +92,14 @@ async fn delete_folder_post(folder_id: FolderId, headers: Headers, conn: DbConn,
} }
#[delete("/folders/<folder_id>")] #[delete("/folders/<folder_id>")]
async fn delete_folder(folder_id: FolderId, headers: Headers, mut conn: DbConn, nt: Notify<'_>) -> EmptyResult { async fn delete_folder(folder_id: FolderId, headers: Headers, conn: DbConn, nt: Notify<'_>) -> EmptyResult {
let Some(folder) = Folder::find_by_uuid_and_user(&folder_id, &headers.user.uuid, &mut conn).await else { let Some(folder) = Folder::find_by_uuid_and_user(&folder_id, &headers.user.uuid, &conn).await else {
err!("Invalid folder", "Folder does not exist or belongs to another user") err!("Invalid folder", "Folder does not exist or belongs to another user")
}; };
// Delete the actual folder entry // Delete the actual folder entry
folder.delete(&mut conn).await?; folder.delete(&conn).await?;
nt.send_folder_update(UpdateType::SyncFolderDelete, &folder, &headers.device, &mut conn).await; nt.send_folder_update(UpdateType::SyncFolderDelete, &folder, &headers.device, &conn).await;
Ok(()) Ok(())
} }

18
src/api/core/mod.rs

@ -52,7 +52,10 @@ use rocket::{serde::json::Json, serde::json::Value, Catcher, Route};
use crate::{ use crate::{
api::{EmptyResult, JsonResult, Notify, UpdateType}, api::{EmptyResult, JsonResult, Notify, UpdateType},
auth::Headers, auth::Headers,
db::{models::*, DbConn}, db::{
models::{Membership, MembershipStatus, MembershipType, OrgPolicy, OrgPolicyErr, Organization, User},
DbConn,
},
error::Error, error::Error,
http_client::make_http_request, http_client::make_http_request,
mail, mail,
@ -106,12 +109,7 @@ struct EquivDomainData {
} }
#[post("/settings/domains", data = "<data>")] #[post("/settings/domains", data = "<data>")]
async fn post_eq_domains( async fn post_eq_domains(data: Json<EquivDomainData>, headers: Headers, conn: DbConn, nt: Notify<'_>) -> JsonResult {
data: Json<EquivDomainData>,
headers: Headers,
mut conn: DbConn,
nt: Notify<'_>,
) -> JsonResult {
let data: EquivDomainData = data.into_inner(); let data: EquivDomainData = data.into_inner();
let excluded_globals = data.excluded_global_equivalent_domains.unwrap_or_default(); let excluded_globals = data.excluded_global_equivalent_domains.unwrap_or_default();
@ -123,9 +121,9 @@ async fn post_eq_domains(
user.excluded_globals = to_string(&excluded_globals).unwrap_or_else(|_| "[]".to_string()); user.excluded_globals = to_string(&excluded_globals).unwrap_or_else(|_| "[]".to_string());
user.equivalent_domains = to_string(&equivalent_domains).unwrap_or_else(|_| "[]".to_string()); user.equivalent_domains = to_string(&equivalent_domains).unwrap_or_else(|_| "[]".to_string());
user.save(&mut conn).await?; user.save(&conn).await?;
nt.send_user_update(UpdateType::SyncSettings, &user, &headers.device.push_uuid, &mut conn).await; nt.send_user_update(UpdateType::SyncSettings, &user, &headers.device.push_uuid, &conn).await;
Ok(Json(json!({}))) Ok(Json(json!({})))
} }
@ -265,7 +263,7 @@ async fn accept_org_invite(
user: &User, user: &User,
mut member: Membership, mut member: Membership,
reset_password_key: Option<String>, reset_password_key: Option<String>,
conn: &mut DbConn, conn: &DbConn,
) -> EmptyResult { ) -> EmptyResult {
if member.status != MembershipStatus::Invited as i32 { if member.status != MembershipStatus::Invited as i32 {
err!("User already accepted the invitation"); err!("User already accepted the invitation");

659
src/api/core/organizations.rs

File diff suppressed because it is too large

53
src/api/core/public.rs

@ -10,7 +10,13 @@ use std::collections::HashSet;
use crate::{ use crate::{
api::EmptyResult, api::EmptyResult,
auth, auth,
db::{models::*, DbConn}, db::{
models::{
Group, GroupUser, Invitation, Membership, MembershipStatus, MembershipType, Organization,
OrganizationApiKey, OrganizationId, User,
},
DbConn,
},
mail, CONFIG, mail, CONFIG,
}; };
@ -44,7 +50,7 @@ struct OrgImportData {
} }
#[post("/public/organization/import", data = "<data>")] #[post("/public/organization/import", data = "<data>")]
async fn ldap_import(data: Json<OrgImportData>, token: PublicToken, mut conn: DbConn) -> EmptyResult { async fn ldap_import(data: Json<OrgImportData>, token: PublicToken, conn: DbConn) -> EmptyResult {
// Most of the logic for this function can be found here // Most of the logic for this function can be found here
// https://github.com/bitwarden/server/blob/9ebe16587175b1c0e9208f84397bb75d0d595510/src/Core/AdminConsole/Services/Implementations/OrganizationService.cs#L1203 // https://github.com/bitwarden/server/blob/9ebe16587175b1c0e9208f84397bb75d0d595510/src/Core/AdminConsole/Services/Implementations/OrganizationService.cs#L1203
@ -55,13 +61,12 @@ async fn ldap_import(data: Json<OrgImportData>, token: PublicToken, mut conn: Db
let mut user_created: bool = false; let mut user_created: bool = false;
if user_data.deleted { if user_data.deleted {
// If user is marked for deletion and it exists, revoke it // If user is marked for deletion and it exists, revoke it
if let Some(mut member) = Membership::find_by_email_and_org(&user_data.email, &org_id, &mut conn).await { if let Some(mut member) = Membership::find_by_email_and_org(&user_data.email, &org_id, &conn).await {
// Only revoke a user if it is not the last confirmed owner // Only revoke a user if it is not the last confirmed owner
let revoked = if member.atype == MembershipType::Owner let revoked = if member.atype == MembershipType::Owner
&& member.status == MembershipStatus::Confirmed as i32 && member.status == MembershipStatus::Confirmed as i32
{ {
if Membership::count_confirmed_by_org_and_type(&org_id, MembershipType::Owner, &mut conn).await <= 1 if Membership::count_confirmed_by_org_and_type(&org_id, MembershipType::Owner, &conn).await <= 1 {
{
warn!("Can't revoke the last owner"); warn!("Can't revoke the last owner");
false false
} else { } else {
@ -73,27 +78,27 @@ async fn ldap_import(data: Json<OrgImportData>, token: PublicToken, mut conn: Db
let ext_modified = member.set_external_id(Some(user_data.external_id.clone())); let ext_modified = member.set_external_id(Some(user_data.external_id.clone()));
if revoked || ext_modified { if revoked || ext_modified {
member.save(&mut conn).await?; member.save(&conn).await?;
} }
} }
// If user is part of the organization, restore it // If user is part of the organization, restore it
} else if let Some(mut member) = Membership::find_by_email_and_org(&user_data.email, &org_id, &mut conn).await { } else if let Some(mut member) = Membership::find_by_email_and_org(&user_data.email, &org_id, &conn).await {
let restored = member.restore(); let restored = member.restore();
let ext_modified = member.set_external_id(Some(user_data.external_id.clone())); let ext_modified = member.set_external_id(Some(user_data.external_id.clone()));
if restored || ext_modified { if restored || ext_modified {
member.save(&mut conn).await?; member.save(&conn).await?;
} }
} else { } else {
// If user is not part of the organization // If user is not part of the organization
let user = match User::find_by_mail(&user_data.email, &mut conn).await { let user = match User::find_by_mail(&user_data.email, &conn).await {
Some(user) => user, // exists in vaultwarden Some(user) => user, // exists in vaultwarden
None => { None => {
// User does not exist yet // User does not exist yet
let mut new_user = User::new(user_data.email.clone(), None); let mut new_user = User::new(user_data.email.clone(), None);
new_user.save(&mut conn).await?; new_user.save(&conn).await?;
if !CONFIG.mail_enabled() { if !CONFIG.mail_enabled() {
Invitation::new(&new_user.email).save(&mut conn).await?; Invitation::new(&new_user.email).save(&conn).await?;
} }
user_created = true; user_created = true;
new_user new_user
@ -105,7 +110,7 @@ async fn ldap_import(data: Json<OrgImportData>, token: PublicToken, mut conn: Db
MembershipStatus::Accepted as i32 // Automatically mark user as accepted if no email invites MembershipStatus::Accepted as i32 // Automatically mark user as accepted if no email invites
}; };
let (org_name, org_email) = match Organization::find_by_uuid(&org_id, &mut conn).await { let (org_name, org_email) = match Organization::find_by_uuid(&org_id, &conn).await {
Some(org) => (org.name, org.billing_email), Some(org) => (org.name, org.billing_email),
None => err!("Error looking up organization"), None => err!("Error looking up organization"),
}; };
@ -116,7 +121,7 @@ async fn ldap_import(data: Json<OrgImportData>, token: PublicToken, mut conn: Db
new_member.atype = MembershipType::User as i32; new_member.atype = MembershipType::User as i32;
new_member.status = member_status; new_member.status = member_status;
new_member.save(&mut conn).await?; new_member.save(&conn).await?;
if CONFIG.mail_enabled() { if CONFIG.mail_enabled() {
if let Err(e) = if let Err(e) =
@ -124,9 +129,9 @@ async fn ldap_import(data: Json<OrgImportData>, token: PublicToken, mut conn: Db
{ {
// Upon error delete the user, invite and org member records when needed // Upon error delete the user, invite and org member records when needed
if user_created { if user_created {
user.delete(&mut conn).await?; user.delete(&conn).await?;
} else { } else {
new_member.delete(&mut conn).await?; new_member.delete(&conn).await?;
} }
err!(format!("Error sending invite: {e:?} ")); err!(format!("Error sending invite: {e:?} "));
@ -137,8 +142,7 @@ async fn ldap_import(data: Json<OrgImportData>, token: PublicToken, mut conn: Db
if CONFIG.org_groups_enabled() { if CONFIG.org_groups_enabled() {
for group_data in &data.groups { for group_data in &data.groups {
let group_uuid = match Group::find_by_external_id_and_org(&group_data.external_id, &org_id, &mut conn).await let group_uuid = match Group::find_by_external_id_and_org(&group_data.external_id, &org_id, &conn).await {
{
Some(group) => group.uuid, Some(group) => group.uuid,
None => { None => {
let mut group = Group::new( let mut group = Group::new(
@ -147,17 +151,17 @@ async fn ldap_import(data: Json<OrgImportData>, token: PublicToken, mut conn: Db
false, false,
Some(group_data.external_id.clone()), Some(group_data.external_id.clone()),
); );
group.save(&mut conn).await?; group.save(&conn).await?;
group.uuid group.uuid
} }
}; };
GroupUser::delete_all_by_group(&group_uuid, &mut conn).await?; GroupUser::delete_all_by_group(&group_uuid, &conn).await?;
for ext_id in &group_data.member_external_ids { for ext_id in &group_data.member_external_ids {
if let Some(member) = Membership::find_by_external_id_and_org(ext_id, &org_id, &mut conn).await { if let Some(member) = Membership::find_by_external_id_and_org(ext_id, &org_id, &conn).await {
let mut group_user = GroupUser::new(group_uuid.clone(), member.uuid.clone()); let mut group_user = GroupUser::new(group_uuid.clone(), member.uuid.clone());
group_user.save(&mut conn).await?; group_user.save(&conn).await?;
} }
} }
} }
@ -169,19 +173,18 @@ async fn ldap_import(data: Json<OrgImportData>, token: PublicToken, mut conn: Db
if data.overwrite_existing { if data.overwrite_existing {
// Generate a HashSet to quickly verify if a member is listed or not. // Generate a HashSet to quickly verify if a member is listed or not.
let sync_members: HashSet<String> = data.members.into_iter().map(|m| m.external_id).collect(); let sync_members: HashSet<String> = data.members.into_iter().map(|m| m.external_id).collect();
for member in Membership::find_by_org(&org_id, &mut conn).await { for member in Membership::find_by_org(&org_id, &conn).await {
if let Some(ref user_external_id) = member.external_id { if let Some(ref user_external_id) = member.external_id {
if !sync_members.contains(user_external_id) { if !sync_members.contains(user_external_id) {
if member.atype == MembershipType::Owner && member.status == MembershipStatus::Confirmed as i32 { if member.atype == MembershipType::Owner && member.status == MembershipStatus::Confirmed as i32 {
// Removing owner, check that there is at least one other confirmed owner // Removing owner, check that there is at least one other confirmed owner
if Membership::count_confirmed_by_org_and_type(&org_id, MembershipType::Owner, &mut conn).await if Membership::count_confirmed_by_org_and_type(&org_id, MembershipType::Owner, &conn).await <= 1
<= 1
{ {
warn!("Can't delete the last owner"); warn!("Can't delete the last owner");
continue; continue;
} }
} }
member.delete(&mut conn).await?; member.delete(&conn).await?;
} }
} }
} }

129
src/api/core/sends.rs

@ -14,7 +14,10 @@ use crate::{
api::{ApiResult, EmptyResult, JsonResult, Notify, UpdateType}, api::{ApiResult, EmptyResult, JsonResult, Notify, UpdateType},
auth::{ClientIp, Headers, Host}, auth::{ClientIp, Headers, Host},
config::PathType, config::PathType,
db::{models::*, DbConn, DbPool}, db::{
models::{Device, OrgPolicy, OrgPolicyType, Send, SendFileId, SendId, SendType, UserId},
DbConn, DbPool,
},
util::{save_temp_file, NumberOrString}, util::{save_temp_file, NumberOrString},
CONFIG, CONFIG,
}; };
@ -58,8 +61,8 @@ pub fn routes() -> Vec<rocket::Route> {
pub async fn purge_sends(pool: DbPool) { pub async fn purge_sends(pool: DbPool) {
debug!("Purging sends"); debug!("Purging sends");
if let Ok(mut conn) = pool.get().await { if let Ok(conn) = pool.get().await {
Send::purge(&mut conn).await; Send::purge(&conn).await;
} else { } else {
error!("Failed to get DB connection while purging sends") error!("Failed to get DB connection while purging sends")
} }
@ -96,7 +99,7 @@ pub struct SendData {
/// ///
/// There is also a Vaultwarden-specific `sends_allowed` config setting that /// There is also a Vaultwarden-specific `sends_allowed` config setting that
/// controls this policy globally. /// controls this policy globally.
async fn enforce_disable_send_policy(headers: &Headers, conn: &mut DbConn) -> EmptyResult { async fn enforce_disable_send_policy(headers: &Headers, conn: &DbConn) -> EmptyResult {
let user_id = &headers.user.uuid; let user_id = &headers.user.uuid;
if !CONFIG.sends_allowed() if !CONFIG.sends_allowed()
|| OrgPolicy::is_applicable_to_user(user_id, OrgPolicyType::DisableSend, None, conn).await || OrgPolicy::is_applicable_to_user(user_id, OrgPolicyType::DisableSend, None, conn).await
@ -112,7 +115,7 @@ async fn enforce_disable_send_policy(headers: &Headers, conn: &mut DbConn) -> Em
/// but is allowed to remove this option from an existing Send. /// but is allowed to remove this option from an existing Send.
/// ///
/// Ref: https://bitwarden.com/help/article/policies/#send-options /// Ref: https://bitwarden.com/help/article/policies/#send-options
async fn enforce_disable_hide_email_policy(data: &SendData, headers: &Headers, conn: &mut DbConn) -> EmptyResult { async fn enforce_disable_hide_email_policy(data: &SendData, headers: &Headers, conn: &DbConn) -> EmptyResult {
let user_id = &headers.user.uuid; let user_id = &headers.user.uuid;
let hide_email = data.hide_email.unwrap_or(false); let hide_email = data.hide_email.unwrap_or(false);
if hide_email && OrgPolicy::is_hide_email_disabled(user_id, conn).await { if hide_email && OrgPolicy::is_hide_email_disabled(user_id, conn).await {
@ -164,8 +167,8 @@ fn create_send(data: SendData, user_id: UserId) -> ApiResult<Send> {
} }
#[get("/sends")] #[get("/sends")]
async fn get_sends(headers: Headers, mut conn: DbConn) -> Json<Value> { async fn get_sends(headers: Headers, conn: DbConn) -> Json<Value> {
let sends = Send::find_by_user(&headers.user.uuid, &mut conn); let sends = Send::find_by_user(&headers.user.uuid, &conn);
let sends_json: Vec<Value> = sends.await.iter().map(|s| s.to_json()).collect(); let sends_json: Vec<Value> = sends.await.iter().map(|s| s.to_json()).collect();
Json(json!({ Json(json!({
@ -176,32 +179,32 @@ async fn get_sends(headers: Headers, mut conn: DbConn) -> Json<Value> {
} }
#[get("/sends/<send_id>")] #[get("/sends/<send_id>")]
async fn get_send(send_id: SendId, headers: Headers, mut conn: DbConn) -> JsonResult { async fn get_send(send_id: SendId, headers: Headers, conn: DbConn) -> JsonResult {
match Send::find_by_uuid_and_user(&send_id, &headers.user.uuid, &mut conn).await { match Send::find_by_uuid_and_user(&send_id, &headers.user.uuid, &conn).await {
Some(send) => Ok(Json(send.to_json())), Some(send) => Ok(Json(send.to_json())),
None => err!("Send not found", "Invalid send uuid or does not belong to user"), None => err!("Send not found", "Invalid send uuid or does not belong to user"),
} }
} }
#[post("/sends", data = "<data>")] #[post("/sends", data = "<data>")]
async fn post_send(data: Json<SendData>, headers: Headers, mut conn: DbConn, nt: Notify<'_>) -> JsonResult { async fn post_send(data: Json<SendData>, headers: Headers, conn: DbConn, nt: Notify<'_>) -> JsonResult {
enforce_disable_send_policy(&headers, &mut conn).await?; enforce_disable_send_policy(&headers, &conn).await?;
let data: SendData = data.into_inner(); let data: SendData = data.into_inner();
enforce_disable_hide_email_policy(&data, &headers, &mut conn).await?; enforce_disable_hide_email_policy(&data, &headers, &conn).await?;
if data.r#type == SendType::File as i32 { if data.r#type == SendType::File as i32 {
err!("File sends should use /api/sends/file") err!("File sends should use /api/sends/file")
} }
let mut send = create_send(data, headers.user.uuid)?; let mut send = create_send(data, headers.user.uuid)?;
send.save(&mut conn).await?; send.save(&conn).await?;
nt.send_send_update( nt.send_send_update(
UpdateType::SyncSendCreate, UpdateType::SyncSendCreate,
&send, &send,
&send.update_users_revision(&mut conn).await, &send.update_users_revision(&conn).await,
&headers.device, &headers.device,
&mut conn, &conn,
) )
.await; .await;
@ -225,8 +228,8 @@ struct UploadDataV2<'f> {
// 2025: This endpoint doesn't seem to exists anymore in the latest version // 2025: This endpoint doesn't seem to exists anymore in the latest version
// See: https://github.com/bitwarden/server/blob/9ebe16587175b1c0e9208f84397bb75d0d595510/src/Api/Tools/Controllers/SendsController.cs // See: https://github.com/bitwarden/server/blob/9ebe16587175b1c0e9208f84397bb75d0d595510/src/Api/Tools/Controllers/SendsController.cs
#[post("/sends/file", format = "multipart/form-data", data = "<data>")] #[post("/sends/file", format = "multipart/form-data", data = "<data>")]
async fn post_send_file(data: Form<UploadData<'_>>, headers: Headers, mut conn: DbConn, nt: Notify<'_>) -> JsonResult { async fn post_send_file(data: Form<UploadData<'_>>, headers: Headers, conn: DbConn, nt: Notify<'_>) -> JsonResult {
enforce_disable_send_policy(&headers, &mut conn).await?; enforce_disable_send_policy(&headers, &conn).await?;
let UploadData { let UploadData {
model, model,
@ -241,12 +244,12 @@ async fn post_send_file(data: Form<UploadData<'_>>, headers: Headers, mut conn:
err!("Send size can't be negative") err!("Send size can't be negative")
} }
enforce_disable_hide_email_policy(&model, &headers, &mut conn).await?; enforce_disable_hide_email_policy(&model, &headers, &conn).await?;
let size_limit = match CONFIG.user_send_limit() { let size_limit = match CONFIG.user_send_limit() {
Some(0) => err!("File uploads are disabled"), Some(0) => err!("File uploads are disabled"),
Some(limit_kb) => { Some(limit_kb) => {
let Some(already_used) = Send::size_by_user(&headers.user.uuid, &mut conn).await else { let Some(already_used) = Send::size_by_user(&headers.user.uuid, &conn).await else {
err!("Existing sends overflow") err!("Existing sends overflow")
}; };
let Some(left) = limit_kb.checked_mul(1024).and_then(|l| l.checked_sub(already_used)) else { let Some(left) = limit_kb.checked_mul(1024).and_then(|l| l.checked_sub(already_used)) else {
@ -282,13 +285,13 @@ async fn post_send_file(data: Form<UploadData<'_>>, headers: Headers, mut conn:
send.data = serde_json::to_string(&data_value)?; send.data = serde_json::to_string(&data_value)?;
// Save the changes in the database // Save the changes in the database
send.save(&mut conn).await?; send.save(&conn).await?;
nt.send_send_update( nt.send_send_update(
UpdateType::SyncSendCreate, UpdateType::SyncSendCreate,
&send, &send,
&send.update_users_revision(&mut conn).await, &send.update_users_revision(&conn).await,
&headers.device, &headers.device,
&mut conn, &conn,
) )
.await; .await;
@ -297,8 +300,8 @@ async fn post_send_file(data: Form<UploadData<'_>>, headers: Headers, mut conn:
// Upstream: https://github.com/bitwarden/server/blob/9ebe16587175b1c0e9208f84397bb75d0d595510/src/Api/Tools/Controllers/SendsController.cs#L165 // Upstream: https://github.com/bitwarden/server/blob/9ebe16587175b1c0e9208f84397bb75d0d595510/src/Api/Tools/Controllers/SendsController.cs#L165
#[post("/sends/file/v2", data = "<data>")] #[post("/sends/file/v2", data = "<data>")]
async fn post_send_file_v2(data: Json<SendData>, headers: Headers, mut conn: DbConn) -> JsonResult { async fn post_send_file_v2(data: Json<SendData>, headers: Headers, conn: DbConn) -> JsonResult {
enforce_disable_send_policy(&headers, &mut conn).await?; enforce_disable_send_policy(&headers, &conn).await?;
let data = data.into_inner(); let data = data.into_inner();
@ -306,7 +309,7 @@ async fn post_send_file_v2(data: Json<SendData>, headers: Headers, mut conn: DbC
err!("Send content is not a file"); err!("Send content is not a file");
} }
enforce_disable_hide_email_policy(&data, &headers, &mut conn).await?; enforce_disable_hide_email_policy(&data, &headers, &conn).await?;
let file_length = match &data.file_length { let file_length = match &data.file_length {
Some(m) => m.into_i64()?, Some(m) => m.into_i64()?,
@ -319,7 +322,7 @@ async fn post_send_file_v2(data: Json<SendData>, headers: Headers, mut conn: DbC
let size_limit = match CONFIG.user_send_limit() { let size_limit = match CONFIG.user_send_limit() {
Some(0) => err!("File uploads are disabled"), Some(0) => err!("File uploads are disabled"),
Some(limit_kb) => { Some(limit_kb) => {
let Some(already_used) = Send::size_by_user(&headers.user.uuid, &mut conn).await else { let Some(already_used) = Send::size_by_user(&headers.user.uuid, &conn).await else {
err!("Existing sends overflow") err!("Existing sends overflow")
}; };
let Some(left) = limit_kb.checked_mul(1024).and_then(|l| l.checked_sub(already_used)) else { let Some(left) = limit_kb.checked_mul(1024).and_then(|l| l.checked_sub(already_used)) else {
@ -348,7 +351,7 @@ async fn post_send_file_v2(data: Json<SendData>, headers: Headers, mut conn: DbC
o.insert(String::from("sizeName"), Value::String(crate::util::get_display_size(file_length))); o.insert(String::from("sizeName"), Value::String(crate::util::get_display_size(file_length)));
} }
send.data = serde_json::to_string(&data_value)?; send.data = serde_json::to_string(&data_value)?;
send.save(&mut conn).await?; send.save(&conn).await?;
Ok(Json(json!({ Ok(Json(json!({
"fileUploadType": 0, // 0 == Direct | 1 == Azure "fileUploadType": 0, // 0 == Direct | 1 == Azure
@ -373,14 +376,14 @@ async fn post_send_file_v2_data(
file_id: SendFileId, file_id: SendFileId,
data: Form<UploadDataV2<'_>>, data: Form<UploadDataV2<'_>>,
headers: Headers, headers: Headers,
mut conn: DbConn, conn: DbConn,
nt: Notify<'_>, nt: Notify<'_>,
) -> EmptyResult { ) -> EmptyResult {
enforce_disable_send_policy(&headers, &mut conn).await?; enforce_disable_send_policy(&headers, &conn).await?;
let data = data.into_inner(); let data = data.into_inner();
let Some(send) = Send::find_by_uuid_and_user(&send_id, &headers.user.uuid, &mut conn).await else { let Some(send) = Send::find_by_uuid_and_user(&send_id, &headers.user.uuid, &conn).await else {
err!("Send not found. Unable to save the file.", "Invalid send uuid or does not belong to user.") err!("Send not found. Unable to save the file.", "Invalid send uuid or does not belong to user.")
}; };
@ -428,9 +431,9 @@ async fn post_send_file_v2_data(
nt.send_send_update( nt.send_send_update(
UpdateType::SyncSendCreate, UpdateType::SyncSendCreate,
&send, &send,
&send.update_users_revision(&mut conn).await, &send.update_users_revision(&conn).await,
&headers.device, &headers.device,
&mut conn, &conn,
) )
.await; .await;
@ -447,11 +450,11 @@ pub struct SendAccessData {
async fn post_access( async fn post_access(
access_id: &str, access_id: &str,
data: Json<SendAccessData>, data: Json<SendAccessData>,
mut conn: DbConn, conn: DbConn,
ip: ClientIp, ip: ClientIp,
nt: Notify<'_>, nt: Notify<'_>,
) -> JsonResult { ) -> JsonResult {
let Some(mut send) = Send::find_by_access_id(access_id, &mut conn).await else { let Some(mut send) = Send::find_by_access_id(access_id, &conn).await else {
err_code!(SEND_INACCESSIBLE_MSG, 404) err_code!(SEND_INACCESSIBLE_MSG, 404)
}; };
@ -488,18 +491,18 @@ async fn post_access(
send.access_count += 1; send.access_count += 1;
} }
send.save(&mut conn).await?; send.save(&conn).await?;
nt.send_send_update( nt.send_send_update(
UpdateType::SyncSendUpdate, UpdateType::SyncSendUpdate,
&send, &send,
&send.update_users_revision(&mut conn).await, &send.update_users_revision(&conn).await,
&ANON_PUSH_DEVICE, &ANON_PUSH_DEVICE,
&mut conn, &conn,
) )
.await; .await;
Ok(Json(send.to_json_access(&mut conn).await)) Ok(Json(send.to_json_access(&conn).await))
} }
#[post("/sends/<send_id>/access/file/<file_id>", data = "<data>")] #[post("/sends/<send_id>/access/file/<file_id>", data = "<data>")]
@ -508,10 +511,10 @@ async fn post_access_file(
file_id: SendFileId, file_id: SendFileId,
data: Json<SendAccessData>, data: Json<SendAccessData>,
host: Host, host: Host,
mut conn: DbConn, conn: DbConn,
nt: Notify<'_>, nt: Notify<'_>,
) -> JsonResult { ) -> JsonResult {
let Some(mut send) = Send::find_by_uuid(&send_id, &mut conn).await else { let Some(mut send) = Send::find_by_uuid(&send_id, &conn).await else {
err_code!(SEND_INACCESSIBLE_MSG, 404) err_code!(SEND_INACCESSIBLE_MSG, 404)
}; };
@ -545,14 +548,14 @@ async fn post_access_file(
send.access_count += 1; send.access_count += 1;
send.save(&mut conn).await?; send.save(&conn).await?;
nt.send_send_update( nt.send_send_update(
UpdateType::SyncSendUpdate, UpdateType::SyncSendUpdate,
&send, &send,
&send.update_users_revision(&mut conn).await, &send.update_users_revision(&conn).await,
&ANON_PUSH_DEVICE, &ANON_PUSH_DEVICE,
&mut conn, &conn,
) )
.await; .await;
@ -587,23 +590,17 @@ async fn download_send(send_id: SendId, file_id: SendFileId, t: &str) -> Option<
} }
#[put("/sends/<send_id>", data = "<data>")] #[put("/sends/<send_id>", data = "<data>")]
async fn put_send( async fn put_send(send_id: SendId, data: Json<SendData>, headers: Headers, conn: DbConn, nt: Notify<'_>) -> JsonResult {
send_id: SendId, enforce_disable_send_policy(&headers, &conn).await?;
data: Json<SendData>,
headers: Headers,
mut conn: DbConn,
nt: Notify<'_>,
) -> JsonResult {
enforce_disable_send_policy(&headers, &mut conn).await?;
let data: SendData = data.into_inner(); let data: SendData = data.into_inner();
enforce_disable_hide_email_policy(&data, &headers, &mut conn).await?; enforce_disable_hide_email_policy(&data, &headers, &conn).await?;
let Some(mut send) = Send::find_by_uuid_and_user(&send_id, &headers.user.uuid, &mut conn).await else { let Some(mut send) = Send::find_by_uuid_and_user(&send_id, &headers.user.uuid, &conn).await else {
err!("Send not found", "Send send_id is invalid or does not belong to user") err!("Send not found", "Send send_id is invalid or does not belong to user")
}; };
update_send_from_data(&mut send, data, &headers, &mut conn, &nt, UpdateType::SyncSendUpdate).await?; update_send_from_data(&mut send, data, &headers, &conn, &nt, UpdateType::SyncSendUpdate).await?;
Ok(Json(send.to_json())) Ok(Json(send.to_json()))
} }
@ -612,7 +609,7 @@ pub async fn update_send_from_data(
send: &mut Send, send: &mut Send,
data: SendData, data: SendData,
headers: &Headers, headers: &Headers,
conn: &mut DbConn, conn: &DbConn,
nt: &Notify<'_>, nt: &Notify<'_>,
ut: UpdateType, ut: UpdateType,
) -> EmptyResult { ) -> EmptyResult {
@ -667,18 +664,18 @@ pub async fn update_send_from_data(
} }
#[delete("/sends/<send_id>")] #[delete("/sends/<send_id>")]
async fn delete_send(send_id: SendId, headers: Headers, mut conn: DbConn, nt: Notify<'_>) -> EmptyResult { async fn delete_send(send_id: SendId, headers: Headers, conn: DbConn, nt: Notify<'_>) -> EmptyResult {
let Some(send) = Send::find_by_uuid_and_user(&send_id, &headers.user.uuid, &mut conn).await else { let Some(send) = Send::find_by_uuid_and_user(&send_id, &headers.user.uuid, &conn).await else {
err!("Send not found", "Invalid send uuid, or does not belong to user") err!("Send not found", "Invalid send uuid, or does not belong to user")
}; };
send.delete(&mut conn).await?; send.delete(&conn).await?;
nt.send_send_update( nt.send_send_update(
UpdateType::SyncSendDelete, UpdateType::SyncSendDelete,
&send, &send,
&send.update_users_revision(&mut conn).await, &send.update_users_revision(&conn).await,
&headers.device, &headers.device,
&mut conn, &conn,
) )
.await; .await;
@ -686,21 +683,21 @@ async fn delete_send(send_id: SendId, headers: Headers, mut conn: DbConn, nt: No
} }
#[put("/sends/<send_id>/remove-password")] #[put("/sends/<send_id>/remove-password")]
async fn put_remove_password(send_id: SendId, headers: Headers, mut conn: DbConn, nt: Notify<'_>) -> JsonResult { async fn put_remove_password(send_id: SendId, headers: Headers, conn: DbConn, nt: Notify<'_>) -> JsonResult {
enforce_disable_send_policy(&headers, &mut conn).await?; enforce_disable_send_policy(&headers, &conn).await?;
let Some(mut send) = Send::find_by_uuid_and_user(&send_id, &headers.user.uuid, &mut conn).await else { let Some(mut send) = Send::find_by_uuid_and_user(&send_id, &headers.user.uuid, &conn).await else {
err!("Send not found", "Invalid send uuid, or does not belong to user") err!("Send not found", "Invalid send uuid, or does not belong to user")
}; };
send.set_password(None); send.set_password(None);
send.save(&mut conn).await?; send.save(&conn).await?;
nt.send_send_update( nt.send_send_update(
UpdateType::SyncSendUpdate, UpdateType::SyncSendUpdate,
&send, &send,
&send.update_users_revision(&mut conn).await, &send.update_users_revision(&conn).await,
&headers.device, &headers.device,
&mut conn, &conn,
) )
.await; .await;

38
src/api/core/two_factor/authenticator.rs

@ -20,14 +20,14 @@ pub fn routes() -> Vec<Route> {
} }
#[post("/two-factor/get-authenticator", data = "<data>")] #[post("/two-factor/get-authenticator", data = "<data>")]
async fn generate_authenticator(data: Json<PasswordOrOtpData>, headers: Headers, mut conn: DbConn) -> JsonResult { async fn generate_authenticator(data: Json<PasswordOrOtpData>, headers: Headers, conn: DbConn) -> JsonResult {
let data: PasswordOrOtpData = data.into_inner(); let data: PasswordOrOtpData = data.into_inner();
let user = headers.user; let user = headers.user;
data.validate(&user, false, &mut conn).await?; data.validate(&user, false, &conn).await?;
let type_ = TwoFactorType::Authenticator as i32; let type_ = TwoFactorType::Authenticator as i32;
let twofactor = TwoFactor::find_by_user_and_type(&user.uuid, type_, &mut conn).await; let twofactor = TwoFactor::find_by_user_and_type(&user.uuid, type_, &conn).await;
let (enabled, key) = match twofactor { let (enabled, key) = match twofactor {
Some(tf) => (true, tf.data), Some(tf) => (true, tf.data),
@ -55,7 +55,7 @@ struct EnableAuthenticatorData {
} }
#[post("/two-factor/authenticator", data = "<data>")] #[post("/two-factor/authenticator", data = "<data>")]
async fn activate_authenticator(data: Json<EnableAuthenticatorData>, headers: Headers, mut conn: DbConn) -> JsonResult { async fn activate_authenticator(data: Json<EnableAuthenticatorData>, headers: Headers, conn: DbConn) -> JsonResult {
let data: EnableAuthenticatorData = data.into_inner(); let data: EnableAuthenticatorData = data.into_inner();
let key = data.key; let key = data.key;
let token = data.token.into_string(); let token = data.token.into_string();
@ -66,7 +66,7 @@ async fn activate_authenticator(data: Json<EnableAuthenticatorData>, headers: He
master_password_hash: data.master_password_hash, master_password_hash: data.master_password_hash,
otp: data.otp, otp: data.otp,
} }
.validate(&user, true, &mut conn) .validate(&user, true, &conn)
.await?; .await?;
// Validate key as base32 and 20 bytes length // Validate key as base32 and 20 bytes length
@ -80,11 +80,11 @@ async fn activate_authenticator(data: Json<EnableAuthenticatorData>, headers: He
} }
// Validate the token provided with the key, and save new twofactor // Validate the token provided with the key, and save new twofactor
validate_totp_code(&user.uuid, &token, &key.to_uppercase(), &headers.ip, &mut conn).await?; validate_totp_code(&user.uuid, &token, &key.to_uppercase(), &headers.ip, &conn).await?;
_generate_recover_code(&mut user, &mut conn).await; _generate_recover_code(&mut user, &conn).await;
log_user_event(EventType::UserUpdated2fa as i32, &user.uuid, headers.device.atype, &headers.ip.ip, &mut conn).await; log_user_event(EventType::UserUpdated2fa as i32, &user.uuid, headers.device.atype, &headers.ip.ip, &conn).await;
Ok(Json(json!({ Ok(Json(json!({
"enabled": true, "enabled": true,
@ -103,7 +103,7 @@ pub async fn validate_totp_code_str(
totp_code: &str, totp_code: &str,
secret: &str, secret: &str,
ip: &ClientIp, ip: &ClientIp,
conn: &mut DbConn, conn: &DbConn,
) -> EmptyResult { ) -> EmptyResult {
if !totp_code.chars().all(char::is_numeric) { if !totp_code.chars().all(char::is_numeric) {
err!("TOTP code is not a number"); err!("TOTP code is not a number");
@ -117,7 +117,7 @@ pub async fn validate_totp_code(
totp_code: &str, totp_code: &str,
secret: &str, secret: &str,
ip: &ClientIp, ip: &ClientIp,
conn: &mut DbConn, conn: &DbConn,
) -> EmptyResult { ) -> EmptyResult {
use totp_lite::{totp_custom, Sha1}; use totp_lite::{totp_custom, Sha1};
@ -189,7 +189,7 @@ struct DisableAuthenticatorData {
} }
#[delete("/two-factor/authenticator", data = "<data>")] #[delete("/two-factor/authenticator", data = "<data>")]
async fn disable_authenticator(data: Json<DisableAuthenticatorData>, headers: Headers, mut conn: DbConn) -> JsonResult { async fn disable_authenticator(data: Json<DisableAuthenticatorData>, headers: Headers, conn: DbConn) -> JsonResult {
let user = headers.user; let user = headers.user;
let type_ = data.r#type.into_i32()?; let type_ = data.r#type.into_i32()?;
@ -197,24 +197,18 @@ async fn disable_authenticator(data: Json<DisableAuthenticatorData>, headers: He
err!("Invalid password"); err!("Invalid password");
} }
if let Some(twofactor) = TwoFactor::find_by_user_and_type(&user.uuid, type_, &mut conn).await { if let Some(twofactor) = TwoFactor::find_by_user_and_type(&user.uuid, type_, &conn).await {
if twofactor.data == data.key { if twofactor.data == data.key {
twofactor.delete(&mut conn).await?; twofactor.delete(&conn).await?;
log_user_event( log_user_event(EventType::UserDisabled2fa as i32, &user.uuid, headers.device.atype, &headers.ip.ip, &conn)
EventType::UserDisabled2fa as i32,
&user.uuid,
headers.device.atype,
&headers.ip.ip,
&mut conn,
)
.await; .await;
} else { } else {
err!(format!("TOTP key for user {} does not match recorded value, cannot deactivate", &user.email)); err!(format!("TOTP key for user {} does not match recorded value, cannot deactivate", &user.email));
} }
} }
if TwoFactor::find_by_user(&user.uuid, &mut conn).await.is_empty() { if TwoFactor::find_by_user(&user.uuid, &conn).await.is_empty() {
super::enforce_2fa_policy(&user, &user.uuid, headers.device.atype, &headers.ip.ip, &mut conn).await?; super::enforce_2fa_policy(&user, &user.uuid, headers.device.atype, &headers.ip.ip, &conn).await?;
} }
Ok(Json(json!({ Ok(Json(json!({

24
src/api/core/two_factor/duo.rs

@ -92,13 +92,13 @@ impl DuoStatus {
const DISABLED_MESSAGE_DEFAULT: &str = "<To use the global Duo keys, please leave these fields untouched>"; const DISABLED_MESSAGE_DEFAULT: &str = "<To use the global Duo keys, please leave these fields untouched>";
#[post("/two-factor/get-duo", data = "<data>")] #[post("/two-factor/get-duo", data = "<data>")]
async fn get_duo(data: Json<PasswordOrOtpData>, headers: Headers, mut conn: DbConn) -> JsonResult { async fn get_duo(data: Json<PasswordOrOtpData>, headers: Headers, conn: DbConn) -> JsonResult {
let data: PasswordOrOtpData = data.into_inner(); let data: PasswordOrOtpData = data.into_inner();
let user = headers.user; let user = headers.user;
data.validate(&user, false, &mut conn).await?; data.validate(&user, false, &conn).await?;
let data = get_user_duo_data(&user.uuid, &mut conn).await; let data = get_user_duo_data(&user.uuid, &conn).await;
let (enabled, data) = match data { let (enabled, data) = match data {
DuoStatus::Global(_) => (true, Some(DuoData::secret())), DuoStatus::Global(_) => (true, Some(DuoData::secret())),
@ -158,7 +158,7 @@ fn check_duo_fields_custom(data: &EnableDuoData) -> bool {
} }
#[post("/two-factor/duo", data = "<data>")] #[post("/two-factor/duo", data = "<data>")]
async fn activate_duo(data: Json<EnableDuoData>, headers: Headers, mut conn: DbConn) -> JsonResult { async fn activate_duo(data: Json<EnableDuoData>, headers: Headers, conn: DbConn) -> JsonResult {
let data: EnableDuoData = data.into_inner(); let data: EnableDuoData = data.into_inner();
let mut user = headers.user; let mut user = headers.user;
@ -166,7 +166,7 @@ async fn activate_duo(data: Json<EnableDuoData>, headers: Headers, mut conn: DbC
master_password_hash: data.master_password_hash.clone(), master_password_hash: data.master_password_hash.clone(),
otp: data.otp.clone(), otp: data.otp.clone(),
} }
.validate(&user, true, &mut conn) .validate(&user, true, &conn)
.await?; .await?;
let (data, data_str) = if check_duo_fields_custom(&data) { let (data, data_str) = if check_duo_fields_custom(&data) {
@ -180,11 +180,11 @@ async fn activate_duo(data: Json<EnableDuoData>, headers: Headers, mut conn: DbC
let type_ = TwoFactorType::Duo; let type_ = TwoFactorType::Duo;
let twofactor = TwoFactor::new(user.uuid.clone(), type_, data_str); let twofactor = TwoFactor::new(user.uuid.clone(), type_, data_str);
twofactor.save(&mut conn).await?; twofactor.save(&conn).await?;
_generate_recover_code(&mut user, &mut conn).await; _generate_recover_code(&mut user, &conn).await;
log_user_event(EventType::UserUpdated2fa as i32, &user.uuid, headers.device.atype, &headers.ip.ip, &mut conn).await; log_user_event(EventType::UserUpdated2fa as i32, &user.uuid, headers.device.atype, &headers.ip.ip, &conn).await;
Ok(Json(json!({ Ok(Json(json!({
"enabled": true, "enabled": true,
@ -231,7 +231,7 @@ const AUTH_PREFIX: &str = "AUTH";
const DUO_PREFIX: &str = "TX"; const DUO_PREFIX: &str = "TX";
const APP_PREFIX: &str = "APP"; const APP_PREFIX: &str = "APP";
async fn get_user_duo_data(user_id: &UserId, conn: &mut DbConn) -> DuoStatus { async fn get_user_duo_data(user_id: &UserId, conn: &DbConn) -> DuoStatus {
let type_ = TwoFactorType::Duo as i32; let type_ = TwoFactorType::Duo as i32;
// If the user doesn't have an entry, disabled // If the user doesn't have an entry, disabled
@ -254,7 +254,7 @@ async fn get_user_duo_data(user_id: &UserId, conn: &mut DbConn) -> DuoStatus {
} }
// let (ik, sk, ak, host) = get_duo_keys(); // let (ik, sk, ak, host) = get_duo_keys();
pub(crate) async fn get_duo_keys_email(email: &str, conn: &mut DbConn) -> ApiResult<(String, String, String, String)> { pub(crate) async fn get_duo_keys_email(email: &str, conn: &DbConn) -> ApiResult<(String, String, String, String)> {
let data = match User::find_by_mail(email, conn).await { let data = match User::find_by_mail(email, conn).await {
Some(u) => get_user_duo_data(&u.uuid, conn).await.data(), Some(u) => get_user_duo_data(&u.uuid, conn).await.data(),
_ => DuoData::global(), _ => DuoData::global(),
@ -264,7 +264,7 @@ pub(crate) async fn get_duo_keys_email(email: &str, conn: &mut DbConn) -> ApiRes
Ok((data.ik, data.sk, CONFIG.get_duo_akey().await, data.host)) Ok((data.ik, data.sk, CONFIG.get_duo_akey().await, data.host))
} }
pub async fn generate_duo_signature(email: &str, conn: &mut DbConn) -> ApiResult<(String, String)> { pub async fn generate_duo_signature(email: &str, conn: &DbConn) -> ApiResult<(String, String)> {
let now = Utc::now().timestamp(); let now = Utc::now().timestamp();
let (ik, sk, ak, host) = get_duo_keys_email(email, conn).await?; let (ik, sk, ak, host) = get_duo_keys_email(email, conn).await?;
@ -282,7 +282,7 @@ fn sign_duo_values(key: &str, email: &str, ikey: &str, prefix: &str, expire: i64
format!("{cookie}|{}", crypto::hmac_sign(key, &cookie)) format!("{cookie}|{}", crypto::hmac_sign(key, &cookie))
} }
pub async fn validate_duo_login(email: &str, response: &str, conn: &mut DbConn) -> EmptyResult { pub async fn validate_duo_login(email: &str, response: &str, conn: &DbConn) -> EmptyResult {
let split: Vec<&str> = response.split(':').collect(); let split: Vec<&str> = response.split(':').collect();
if split.len() != 2 { if split.len() != 2 {
err!( err!(

10
src/api/core/two_factor/duo_oidc.rs

@ -317,7 +317,7 @@ struct DuoAuthContext {
// Given a state string, retrieve the associated Duo auth context and // Given a state string, retrieve the associated Duo auth context and
// delete the retrieved state from the database. // delete the retrieved state from the database.
async fn extract_context(state: &str, conn: &mut DbConn) -> Option<DuoAuthContext> { async fn extract_context(state: &str, conn: &DbConn) -> Option<DuoAuthContext> {
let ctx: TwoFactorDuoContext = match TwoFactorDuoContext::find_by_state(state, conn).await { let ctx: TwoFactorDuoContext = match TwoFactorDuoContext::find_by_state(state, conn).await {
Some(c) => c, Some(c) => c,
None => return None, None => return None,
@ -344,8 +344,8 @@ async fn extract_context(state: &str, conn: &mut DbConn) -> Option<DuoAuthContex
// Task to clean up expired Duo authentication contexts that may have accumulated in the database. // Task to clean up expired Duo authentication contexts that may have accumulated in the database.
pub async fn purge_duo_contexts(pool: DbPool) { pub async fn purge_duo_contexts(pool: DbPool) {
debug!("Purging Duo authentication contexts"); debug!("Purging Duo authentication contexts");
if let Ok(mut conn) = pool.get().await { if let Ok(conn) = pool.get().await {
TwoFactorDuoContext::purge_expired_duo_contexts(&mut conn).await; TwoFactorDuoContext::purge_expired_duo_contexts(&conn).await;
} else { } else {
error!("Failed to get DB connection while purging expired Duo authentications") error!("Failed to get DB connection while purging expired Duo authentications")
} }
@ -380,7 +380,7 @@ pub async fn get_duo_auth_url(
email: &str, email: &str,
client_id: &str, client_id: &str,
device_identifier: &DeviceId, device_identifier: &DeviceId,
conn: &mut DbConn, conn: &DbConn,
) -> Result<String, Error> { ) -> Result<String, Error> {
let (ik, sk, _, host) = get_duo_keys_email(email, conn).await?; let (ik, sk, _, host) = get_duo_keys_email(email, conn).await?;
@ -418,7 +418,7 @@ pub async fn validate_duo_login(
two_factor_token: &str, two_factor_token: &str,
client_id: &str, client_id: &str,
device_identifier: &DeviceId, device_identifier: &DeviceId,
conn: &mut DbConn, conn: &DbConn,
) -> EmptyResult { ) -> EmptyResult {
// Result supplied to us by clients in the form "<authz code>|<state>" // Result supplied to us by clients in the form "<authz code>|<state>"
let split: Vec<&str> = two_factor_token.split('|').collect(); let split: Vec<&str> = two_factor_token.split('|').collect();

42
src/api/core/two_factor/email.rs

@ -39,13 +39,13 @@ struct SendEmailLoginData {
/// User is trying to login and wants to use email 2FA. /// User is trying to login and wants to use email 2FA.
/// Does not require Bearer token /// Does not require Bearer token
#[post("/two-factor/send-email-login", data = "<data>")] // JsonResult #[post("/two-factor/send-email-login", data = "<data>")] // JsonResult
async fn send_email_login(data: Json<SendEmailLoginData>, mut conn: DbConn) -> EmptyResult { async fn send_email_login(data: Json<SendEmailLoginData>, conn: DbConn) -> EmptyResult {
let data: SendEmailLoginData = data.into_inner(); let data: SendEmailLoginData = data.into_inner();
use crate::db::models::User; use crate::db::models::User;
// Get the user // Get the user
let Some(user) = User::find_by_device_id(&data.device_identifier, &mut conn).await else { let Some(user) = User::find_by_device_id(&data.device_identifier, &conn).await else {
err!("Cannot find user. Try again.") err!("Cannot find user. Try again.")
}; };
@ -53,13 +53,13 @@ async fn send_email_login(data: Json<SendEmailLoginData>, mut conn: DbConn) -> E
err!("Email 2FA is disabled") err!("Email 2FA is disabled")
} }
send_token(&user.uuid, &mut conn).await?; send_token(&user.uuid, &conn).await?;
Ok(()) Ok(())
} }
/// Generate the token, save the data for later verification and send email to user /// Generate the token, save the data for later verification and send email to user
pub async fn send_token(user_id: &UserId, conn: &mut DbConn) -> EmptyResult { pub async fn send_token(user_id: &UserId, conn: &DbConn) -> EmptyResult {
let type_ = TwoFactorType::Email as i32; let type_ = TwoFactorType::Email as i32;
let mut twofactor = TwoFactor::find_by_user_and_type(user_id, type_, conn).await.map_res("Two factor not found")?; let mut twofactor = TwoFactor::find_by_user_and_type(user_id, type_, conn).await.map_res("Two factor not found")?;
@ -77,14 +77,14 @@ pub async fn send_token(user_id: &UserId, conn: &mut DbConn) -> EmptyResult {
/// When user clicks on Manage email 2FA show the user the related information /// When user clicks on Manage email 2FA show the user the related information
#[post("/two-factor/get-email", data = "<data>")] #[post("/two-factor/get-email", data = "<data>")]
async fn get_email(data: Json<PasswordOrOtpData>, headers: Headers, mut conn: DbConn) -> JsonResult { async fn get_email(data: Json<PasswordOrOtpData>, headers: Headers, conn: DbConn) -> JsonResult {
let data: PasswordOrOtpData = data.into_inner(); let data: PasswordOrOtpData = data.into_inner();
let user = headers.user; let user = headers.user;
data.validate(&user, false, &mut conn).await?; data.validate(&user, false, &conn).await?;
let (enabled, mfa_email) = let (enabled, mfa_email) =
match TwoFactor::find_by_user_and_type(&user.uuid, TwoFactorType::Email as i32, &mut conn).await { match TwoFactor::find_by_user_and_type(&user.uuid, TwoFactorType::Email as i32, &conn).await {
Some(x) => { Some(x) => {
let twofactor_data = EmailTokenData::from_json(&x.data)?; let twofactor_data = EmailTokenData::from_json(&x.data)?;
(true, json!(twofactor_data.email)) (true, json!(twofactor_data.email))
@ -110,7 +110,7 @@ struct SendEmailData {
/// Send a verification email to the specified email address to check whether it exists/belongs to user. /// Send a verification email to the specified email address to check whether it exists/belongs to user.
#[post("/two-factor/send-email", data = "<data>")] #[post("/two-factor/send-email", data = "<data>")]
async fn send_email(data: Json<SendEmailData>, headers: Headers, mut conn: DbConn) -> EmptyResult { async fn send_email(data: Json<SendEmailData>, headers: Headers, conn: DbConn) -> EmptyResult {
let data: SendEmailData = data.into_inner(); let data: SendEmailData = data.into_inner();
let user = headers.user; let user = headers.user;
@ -118,7 +118,7 @@ async fn send_email(data: Json<SendEmailData>, headers: Headers, mut conn: DbCon
master_password_hash: data.master_password_hash, master_password_hash: data.master_password_hash,
otp: data.otp, otp: data.otp,
} }
.validate(&user, false, &mut conn) .validate(&user, false, &conn)
.await?; .await?;
if !CONFIG._enable_email_2fa() { if !CONFIG._enable_email_2fa() {
@ -127,8 +127,8 @@ async fn send_email(data: Json<SendEmailData>, headers: Headers, mut conn: DbCon
let type_ = TwoFactorType::Email as i32; let type_ = TwoFactorType::Email as i32;
if let Some(tf) = TwoFactor::find_by_user_and_type(&user.uuid, type_, &mut conn).await { if let Some(tf) = TwoFactor::find_by_user_and_type(&user.uuid, type_, &conn).await {
tf.delete(&mut conn).await?; tf.delete(&conn).await?;
} }
let generated_token = crypto::generate_email_token(CONFIG.email_token_size()); let generated_token = crypto::generate_email_token(CONFIG.email_token_size());
@ -136,7 +136,7 @@ async fn send_email(data: Json<SendEmailData>, headers: Headers, mut conn: DbCon
// Uses EmailVerificationChallenge as type to show that it's not verified yet. // Uses EmailVerificationChallenge as type to show that it's not verified yet.
let twofactor = TwoFactor::new(user.uuid, TwoFactorType::EmailVerificationChallenge, twofactor_data.to_json()); let twofactor = TwoFactor::new(user.uuid, TwoFactorType::EmailVerificationChallenge, twofactor_data.to_json());
twofactor.save(&mut conn).await?; twofactor.save(&conn).await?;
mail::send_token(&twofactor_data.email, &twofactor_data.last_token.map_res("Token is empty")?).await?; mail::send_token(&twofactor_data.email, &twofactor_data.last_token.map_res("Token is empty")?).await?;
@ -154,7 +154,7 @@ struct EmailData {
/// Verify email belongs to user and can be used for 2FA email codes. /// Verify email belongs to user and can be used for 2FA email codes.
#[put("/two-factor/email", data = "<data>")] #[put("/two-factor/email", data = "<data>")]
async fn email(data: Json<EmailData>, headers: Headers, mut conn: DbConn) -> JsonResult { async fn email(data: Json<EmailData>, headers: Headers, conn: DbConn) -> JsonResult {
let data: EmailData = data.into_inner(); let data: EmailData = data.into_inner();
let mut user = headers.user; let mut user = headers.user;
@ -163,12 +163,12 @@ async fn email(data: Json<EmailData>, headers: Headers, mut conn: DbConn) -> Jso
master_password_hash: data.master_password_hash, master_password_hash: data.master_password_hash,
otp: data.otp, otp: data.otp,
} }
.validate(&user, true, &mut conn) .validate(&user, true, &conn)
.await?; .await?;
let type_ = TwoFactorType::EmailVerificationChallenge as i32; let type_ = TwoFactorType::EmailVerificationChallenge as i32;
let mut twofactor = let mut twofactor =
TwoFactor::find_by_user_and_type(&user.uuid, type_, &mut conn).await.map_res("Two factor not found")?; TwoFactor::find_by_user_and_type(&user.uuid, type_, &conn).await.map_res("Two factor not found")?;
let mut email_data = EmailTokenData::from_json(&twofactor.data)?; let mut email_data = EmailTokenData::from_json(&twofactor.data)?;
@ -183,11 +183,11 @@ async fn email(data: Json<EmailData>, headers: Headers, mut conn: DbConn) -> Jso
email_data.reset_token(); email_data.reset_token();
twofactor.atype = TwoFactorType::Email as i32; twofactor.atype = TwoFactorType::Email as i32;
twofactor.data = email_data.to_json(); twofactor.data = email_data.to_json();
twofactor.save(&mut conn).await?; twofactor.save(&conn).await?;
_generate_recover_code(&mut user, &mut conn).await; _generate_recover_code(&mut user, &conn).await;
log_user_event(EventType::UserUpdated2fa as i32, &user.uuid, headers.device.atype, &headers.ip.ip, &mut conn).await; log_user_event(EventType::UserUpdated2fa as i32, &user.uuid, headers.device.atype, &headers.ip.ip, &conn).await;
Ok(Json(json!({ Ok(Json(json!({
"email": email_data.email, "email": email_data.email,
@ -202,7 +202,7 @@ pub async fn validate_email_code_str(
token: &str, token: &str,
data: &str, data: &str,
ip: &std::net::IpAddr, ip: &std::net::IpAddr,
conn: &mut DbConn, conn: &DbConn,
) -> EmptyResult { ) -> EmptyResult {
let mut email_data = EmailTokenData::from_json(data)?; let mut email_data = EmailTokenData::from_json(data)?;
let mut twofactor = TwoFactor::find_by_user_and_type(user_id, TwoFactorType::Email as i32, conn) let mut twofactor = TwoFactor::find_by_user_and_type(user_id, TwoFactorType::Email as i32, conn)
@ -302,7 +302,7 @@ impl EmailTokenData {
} }
} }
pub async fn activate_email_2fa(user: &User, conn: &mut DbConn) -> EmptyResult { pub async fn activate_email_2fa(user: &User, conn: &DbConn) -> EmptyResult {
if user.verified_at.is_none() { if user.verified_at.is_none() {
err!("Auto-enabling of email 2FA failed because the users email address has not been verified!"); err!("Auto-enabling of email 2FA failed because the users email address has not been verified!");
} }
@ -332,7 +332,7 @@ pub fn obscure_email(email: &str) -> String {
format!("{new_name}@{domain}") format!("{new_name}@{domain}")
} }
pub async fn find_and_activate_email_2fa(user_id: &UserId, conn: &mut DbConn) -> EmptyResult { pub async fn find_and_activate_email_2fa(user_id: &UserId, conn: &DbConn) -> EmptyResult {
if let Some(user) = User::find_by_uuid(user_id, conn).await { if let Some(user) = User::find_by_uuid(user_id, conn).await {
activate_email_2fa(&user, conn).await activate_email_2fa(&user, conn).await
} else { } else {

56
src/api/core/two_factor/mod.rs

@ -11,7 +11,13 @@ use crate::{
}, },
auth::{ClientHeaders, Headers}, auth::{ClientHeaders, Headers},
crypto, crypto,
db::{models::*, DbConn, DbPool}, db::{
models::{
DeviceType, EventType, Membership, MembershipType, OrgPolicyType, Organization, OrganizationId, TwoFactor,
TwoFactorIncomplete, User, UserId,
},
DbConn, DbPool,
},
mail, mail,
util::NumberOrString, util::NumberOrString,
CONFIG, CONFIG,
@ -46,8 +52,8 @@ pub fn routes() -> Vec<Route> {
} }
#[get("/two-factor")] #[get("/two-factor")]
async fn get_twofactor(headers: Headers, mut conn: DbConn) -> Json<Value> { async fn get_twofactor(headers: Headers, conn: DbConn) -> Json<Value> {
let twofactors = TwoFactor::find_by_user(&headers.user.uuid, &mut conn).await; let twofactors = TwoFactor::find_by_user(&headers.user.uuid, &conn).await;
let twofactors_json: Vec<Value> = twofactors.iter().map(TwoFactor::to_json_provider).collect(); let twofactors_json: Vec<Value> = twofactors.iter().map(TwoFactor::to_json_provider).collect();
Json(json!({ Json(json!({
@ -58,11 +64,11 @@ async fn get_twofactor(headers: Headers, mut conn: DbConn) -> Json<Value> {
} }
#[post("/two-factor/get-recover", data = "<data>")] #[post("/two-factor/get-recover", data = "<data>")]
async fn get_recover(data: Json<PasswordOrOtpData>, headers: Headers, mut conn: DbConn) -> JsonResult { async fn get_recover(data: Json<PasswordOrOtpData>, headers: Headers, conn: DbConn) -> JsonResult {
let data: PasswordOrOtpData = data.into_inner(); let data: PasswordOrOtpData = data.into_inner();
let user = headers.user; let user = headers.user;
data.validate(&user, true, &mut conn).await?; data.validate(&user, true, &conn).await?;
Ok(Json(json!({ Ok(Json(json!({
"code": user.totp_recover, "code": user.totp_recover,
@ -79,13 +85,13 @@ struct RecoverTwoFactor {
} }
#[post("/two-factor/recover", data = "<data>")] #[post("/two-factor/recover", data = "<data>")]
async fn recover(data: Json<RecoverTwoFactor>, client_headers: ClientHeaders, mut conn: DbConn) -> JsonResult { async fn recover(data: Json<RecoverTwoFactor>, client_headers: ClientHeaders, conn: DbConn) -> JsonResult {
let data: RecoverTwoFactor = data.into_inner(); let data: RecoverTwoFactor = data.into_inner();
use crate::db::models::User; use crate::db::models::User;
// Get the user // Get the user
let Some(mut user) = User::find_by_mail(&data.email, &mut conn).await else { let Some(mut user) = User::find_by_mail(&data.email, &conn).await else {
err!("Username or password is incorrect. Try again.") err!("Username or password is incorrect. Try again.")
}; };
@ -100,25 +106,25 @@ async fn recover(data: Json<RecoverTwoFactor>, client_headers: ClientHeaders, mu
} }
// Remove all twofactors from the user // Remove all twofactors from the user
TwoFactor::delete_all_by_user(&user.uuid, &mut conn).await?; TwoFactor::delete_all_by_user(&user.uuid, &conn).await?;
enforce_2fa_policy(&user, &user.uuid, client_headers.device_type, &client_headers.ip.ip, &mut conn).await?; enforce_2fa_policy(&user, &user.uuid, client_headers.device_type, &client_headers.ip.ip, &conn).await?;
log_user_event( log_user_event(
EventType::UserRecovered2fa as i32, EventType::UserRecovered2fa as i32,
&user.uuid, &user.uuid,
client_headers.device_type, client_headers.device_type,
&client_headers.ip.ip, &client_headers.ip.ip,
&mut conn, &conn,
) )
.await; .await;
// Remove the recovery code, not needed without twofactors // Remove the recovery code, not needed without twofactors
user.totp_recover = None; user.totp_recover = None;
user.save(&mut conn).await?; user.save(&conn).await?;
Ok(Json(Value::Object(serde_json::Map::new()))) Ok(Json(Value::Object(serde_json::Map::new())))
} }
async fn _generate_recover_code(user: &mut User, conn: &mut DbConn) { async fn _generate_recover_code(user: &mut User, conn: &DbConn) {
if user.totp_recover.is_none() { if user.totp_recover.is_none() {
let totp_recover = crypto::encode_random_bytes::<20>(BASE32); let totp_recover = crypto::encode_random_bytes::<20>(BASE32);
user.totp_recover = Some(totp_recover); user.totp_recover = Some(totp_recover);
@ -135,7 +141,7 @@ struct DisableTwoFactorData {
} }
#[post("/two-factor/disable", data = "<data>")] #[post("/two-factor/disable", data = "<data>")]
async fn disable_twofactor(data: Json<DisableTwoFactorData>, headers: Headers, mut conn: DbConn) -> JsonResult { async fn disable_twofactor(data: Json<DisableTwoFactorData>, headers: Headers, conn: DbConn) -> JsonResult {
let data: DisableTwoFactorData = data.into_inner(); let data: DisableTwoFactorData = data.into_inner();
let user = headers.user; let user = headers.user;
@ -144,19 +150,19 @@ async fn disable_twofactor(data: Json<DisableTwoFactorData>, headers: Headers, m
master_password_hash: data.master_password_hash, master_password_hash: data.master_password_hash,
otp: data.otp, otp: data.otp,
} }
.validate(&user, true, &mut conn) .validate(&user, true, &conn)
.await?; .await?;
let type_ = data.r#type.into_i32()?; let type_ = data.r#type.into_i32()?;
if let Some(twofactor) = TwoFactor::find_by_user_and_type(&user.uuid, type_, &mut conn).await { if let Some(twofactor) = TwoFactor::find_by_user_and_type(&user.uuid, type_, &conn).await {
twofactor.delete(&mut conn).await?; twofactor.delete(&conn).await?;
log_user_event(EventType::UserDisabled2fa as i32, &user.uuid, headers.device.atype, &headers.ip.ip, &mut conn) log_user_event(EventType::UserDisabled2fa as i32, &user.uuid, headers.device.atype, &headers.ip.ip, &conn)
.await; .await;
} }
if TwoFactor::find_by_user(&user.uuid, &mut conn).await.is_empty() { if TwoFactor::find_by_user(&user.uuid, &conn).await.is_empty() {
enforce_2fa_policy(&user, &user.uuid, headers.device.atype, &headers.ip.ip, &mut conn).await?; enforce_2fa_policy(&user, &user.uuid, headers.device.atype, &headers.ip.ip, &conn).await?;
} }
Ok(Json(json!({ Ok(Json(json!({
@ -176,7 +182,7 @@ pub async fn enforce_2fa_policy(
act_user_id: &UserId, act_user_id: &UserId,
device_type: i32, device_type: i32,
ip: &std::net::IpAddr, ip: &std::net::IpAddr,
conn: &mut DbConn, conn: &DbConn,
) -> EmptyResult { ) -> EmptyResult {
for member in for member in
Membership::find_by_user_and_policy(&user.uuid, OrgPolicyType::TwoFactorAuthentication, conn).await.into_iter() Membership::find_by_user_and_policy(&user.uuid, OrgPolicyType::TwoFactorAuthentication, conn).await.into_iter()
@ -212,7 +218,7 @@ pub async fn enforce_2fa_policy_for_org(
act_user_id: &UserId, act_user_id: &UserId,
device_type: i32, device_type: i32,
ip: &std::net::IpAddr, ip: &std::net::IpAddr,
conn: &mut DbConn, conn: &DbConn,
) -> EmptyResult { ) -> EmptyResult {
let org = Organization::find_by_uuid(org_id, conn).await.unwrap(); let org = Organization::find_by_uuid(org_id, conn).await.unwrap();
for member in Membership::find_confirmed_by_org(org_id, conn).await.into_iter() { for member in Membership::find_confirmed_by_org(org_id, conn).await.into_iter() {
@ -249,7 +255,7 @@ pub async fn send_incomplete_2fa_notifications(pool: DbPool) {
return; return;
} }
let mut conn = match pool.get().await { let conn = match pool.get().await {
Ok(conn) => conn, Ok(conn) => conn,
_ => { _ => {
error!("Failed to get DB connection in send_incomplete_2fa_notifications()"); error!("Failed to get DB connection in send_incomplete_2fa_notifications()");
@ -260,9 +266,9 @@ pub async fn send_incomplete_2fa_notifications(pool: DbPool) {
let now = Utc::now().naive_utc(); let now = Utc::now().naive_utc();
let time_limit = TimeDelta::try_minutes(CONFIG.incomplete_2fa_time_limit()).unwrap(); let time_limit = TimeDelta::try_minutes(CONFIG.incomplete_2fa_time_limit()).unwrap();
let time_before = now - time_limit; let time_before = now - time_limit;
let incomplete_logins = TwoFactorIncomplete::find_logins_before(&time_before, &mut conn).await; let incomplete_logins = TwoFactorIncomplete::find_logins_before(&time_before, &conn).await;
for login in incomplete_logins { for login in incomplete_logins {
let user = User::find_by_uuid(&login.user_uuid, &mut conn).await.expect("User not found"); let user = User::find_by_uuid(&login.user_uuid, &conn).await.expect("User not found");
info!( info!(
"User {} did not complete a 2FA login within the configured time limit. IP: {}", "User {} did not complete a 2FA login within the configured time limit. IP: {}",
user.email, login.ip_address user.email, login.ip_address
@ -277,7 +283,7 @@ pub async fn send_incomplete_2fa_notifications(pool: DbPool) {
.await .await
{ {
Ok(_) => { Ok(_) => {
if let Err(e) = login.delete(&mut conn).await { if let Err(e) = login.delete(&conn).await {
error!("Error deleting incomplete 2FA record: {e:#?}"); error!("Error deleting incomplete 2FA record: {e:#?}");
} }
} }

15
src/api/core/two_factor/protected_actions.rs

@ -55,7 +55,7 @@ impl ProtectedActionData {
} }
#[post("/accounts/request-otp")] #[post("/accounts/request-otp")]
async fn request_otp(headers: Headers, mut conn: DbConn) -> EmptyResult { async fn request_otp(headers: Headers, conn: DbConn) -> EmptyResult {
if !CONFIG.mail_enabled() { if !CONFIG.mail_enabled() {
err!("Email is disabled for this server. Either enable email or login using your master password instead of login via device."); err!("Email is disabled for this server. Either enable email or login using your master password instead of login via device.");
} }
@ -63,10 +63,9 @@ async fn request_otp(headers: Headers, mut conn: DbConn) -> EmptyResult {
let user = headers.user; let user = headers.user;
// Only one Protected Action per user is allowed to take place, delete the previous one // Only one Protected Action per user is allowed to take place, delete the previous one
if let Some(pa) = if let Some(pa) = TwoFactor::find_by_user_and_type(&user.uuid, TwoFactorType::ProtectedActions as i32, &conn).await
TwoFactor::find_by_user_and_type(&user.uuid, TwoFactorType::ProtectedActions as i32, &mut conn).await
{ {
pa.delete(&mut conn).await?; pa.delete(&conn).await?;
} }
let generated_token = crypto::generate_email_token(CONFIG.email_token_size()); let generated_token = crypto::generate_email_token(CONFIG.email_token_size());
@ -74,7 +73,7 @@ async fn request_otp(headers: Headers, mut conn: DbConn) -> EmptyResult {
// Uses EmailVerificationChallenge as type to show that it's not verified yet. // Uses EmailVerificationChallenge as type to show that it's not verified yet.
let twofactor = TwoFactor::new(user.uuid, TwoFactorType::ProtectedActions, pa_data.to_json()); let twofactor = TwoFactor::new(user.uuid, TwoFactorType::ProtectedActions, pa_data.to_json());
twofactor.save(&mut conn).await?; twofactor.save(&conn).await?;
mail::send_protected_action_token(&user.email, &pa_data.token).await?; mail::send_protected_action_token(&user.email, &pa_data.token).await?;
@ -89,7 +88,7 @@ struct ProtectedActionVerify {
} }
#[post("/accounts/verify-otp", data = "<data>")] #[post("/accounts/verify-otp", data = "<data>")]
async fn verify_otp(data: Json<ProtectedActionVerify>, headers: Headers, mut conn: DbConn) -> EmptyResult { async fn verify_otp(data: Json<ProtectedActionVerify>, headers: Headers, conn: DbConn) -> EmptyResult {
if !CONFIG.mail_enabled() { if !CONFIG.mail_enabled() {
err!("Email is disabled for this server. Either enable email or login using your master password instead of login via device."); err!("Email is disabled for this server. Either enable email or login using your master password instead of login via device.");
} }
@ -99,14 +98,14 @@ async fn verify_otp(data: Json<ProtectedActionVerify>, headers: Headers, mut con
// Delete the token after one validation attempt // Delete the token after one validation attempt
// This endpoint only gets called for the vault export, and doesn't need a second attempt // This endpoint only gets called for the vault export, and doesn't need a second attempt
validate_protected_action_otp(&data.otp, &user.uuid, true, &mut conn).await validate_protected_action_otp(&data.otp, &user.uuid, true, &conn).await
} }
pub async fn validate_protected_action_otp( pub async fn validate_protected_action_otp(
otp: &str, otp: &str,
user_id: &UserId, user_id: &UserId,
delete_if_valid: bool, delete_if_valid: bool,
conn: &mut DbConn, conn: &DbConn,
) -> EmptyResult { ) -> EmptyResult {
let pa = TwoFactor::find_by_user_and_type(user_id, TwoFactorType::ProtectedActions as i32, conn) let pa = TwoFactor::find_by_user_and_type(user_id, TwoFactorType::ProtectedActions as i32, conn)
.await .await

49
src/api/core/two_factor/webauthn.rs

@ -107,7 +107,7 @@ impl WebauthnRegistration {
} }
#[post("/two-factor/get-webauthn", data = "<data>")] #[post("/two-factor/get-webauthn", data = "<data>")]
async fn get_webauthn(data: Json<PasswordOrOtpData>, headers: Headers, mut conn: DbConn) -> JsonResult { async fn get_webauthn(data: Json<PasswordOrOtpData>, headers: Headers, conn: DbConn) -> JsonResult {
if !CONFIG.domain_set() { if !CONFIG.domain_set() {
err!("`DOMAIN` environment variable is not set. Webauthn disabled") err!("`DOMAIN` environment variable is not set. Webauthn disabled")
} }
@ -115,9 +115,9 @@ async fn get_webauthn(data: Json<PasswordOrOtpData>, headers: Headers, mut conn:
let data: PasswordOrOtpData = data.into_inner(); let data: PasswordOrOtpData = data.into_inner();
let user = headers.user; let user = headers.user;
data.validate(&user, false, &mut conn).await?; data.validate(&user, false, &conn).await?;
let (enabled, registrations) = get_webauthn_registrations(&user.uuid, &mut conn).await?; let (enabled, registrations) = get_webauthn_registrations(&user.uuid, &conn).await?;
let registrations_json: Vec<Value> = registrations.iter().map(WebauthnRegistration::to_json).collect(); let registrations_json: Vec<Value> = registrations.iter().map(WebauthnRegistration::to_json).collect();
Ok(Json(json!({ Ok(Json(json!({
@ -128,13 +128,13 @@ async fn get_webauthn(data: Json<PasswordOrOtpData>, headers: Headers, mut conn:
} }
#[post("/two-factor/get-webauthn-challenge", data = "<data>")] #[post("/two-factor/get-webauthn-challenge", data = "<data>")]
async fn generate_webauthn_challenge(data: Json<PasswordOrOtpData>, headers: Headers, mut conn: DbConn) -> JsonResult { async fn generate_webauthn_challenge(data: Json<PasswordOrOtpData>, headers: Headers, conn: DbConn) -> JsonResult {
let data: PasswordOrOtpData = data.into_inner(); let data: PasswordOrOtpData = data.into_inner();
let user = headers.user; let user = headers.user;
data.validate(&user, false, &mut conn).await?; data.validate(&user, false, &conn).await?;
let registrations = get_webauthn_registrations(&user.uuid, &mut conn) let registrations = get_webauthn_registrations(&user.uuid, &conn)
.await? .await?
.1 .1
.into_iter() .into_iter()
@ -153,7 +153,7 @@ async fn generate_webauthn_challenge(data: Json<PasswordOrOtpData>, headers: Hea
state["rs"]["extensions"].as_object_mut().unwrap().clear(); state["rs"]["extensions"].as_object_mut().unwrap().clear();
let type_ = TwoFactorType::WebauthnRegisterChallenge; let type_ = TwoFactorType::WebauthnRegisterChallenge;
TwoFactor::new(user.uuid.clone(), type_, serde_json::to_string(&state)?).save(&mut conn).await?; TwoFactor::new(user.uuid.clone(), type_, serde_json::to_string(&state)?).save(&conn).await?;
// Because for this flow we abuse the passkeys as 2FA, and use it more like a securitykey // Because for this flow we abuse the passkeys as 2FA, and use it more like a securitykey
// we need to modify some of the default settings defined by `start_passkey_registration()`. // we need to modify some of the default settings defined by `start_passkey_registration()`.
@ -252,7 +252,7 @@ impl From<PublicKeyCredentialCopy> for PublicKeyCredential {
} }
#[post("/two-factor/webauthn", data = "<data>")] #[post("/two-factor/webauthn", data = "<data>")]
async fn activate_webauthn(data: Json<EnableWebauthnData>, headers: Headers, mut conn: DbConn) -> JsonResult { async fn activate_webauthn(data: Json<EnableWebauthnData>, headers: Headers, conn: DbConn) -> JsonResult {
let data: EnableWebauthnData = data.into_inner(); let data: EnableWebauthnData = data.into_inner();
let mut user = headers.user; let mut user = headers.user;
@ -260,15 +260,15 @@ async fn activate_webauthn(data: Json<EnableWebauthnData>, headers: Headers, mut
master_password_hash: data.master_password_hash, master_password_hash: data.master_password_hash,
otp: data.otp, otp: data.otp,
} }
.validate(&user, true, &mut conn) .validate(&user, true, &conn)
.await?; .await?;
// Retrieve and delete the saved challenge state // Retrieve and delete the saved challenge state
let type_ = TwoFactorType::WebauthnRegisterChallenge as i32; let type_ = TwoFactorType::WebauthnRegisterChallenge as i32;
let state = match TwoFactor::find_by_user_and_type(&user.uuid, type_, &mut conn).await { let state = match TwoFactor::find_by_user_and_type(&user.uuid, type_, &conn).await {
Some(tf) => { Some(tf) => {
let state: PasskeyRegistration = serde_json::from_str(&tf.data)?; let state: PasskeyRegistration = serde_json::from_str(&tf.data)?;
tf.delete(&mut conn).await?; tf.delete(&conn).await?;
state state
} }
None => err!("Can't recover challenge"), None => err!("Can't recover challenge"),
@ -277,7 +277,7 @@ async fn activate_webauthn(data: Json<EnableWebauthnData>, headers: Headers, mut
// Verify the credentials with the saved state // Verify the credentials with the saved state
let credential = WEBAUTHN.finish_passkey_registration(&data.device_response.into(), &state)?; let credential = WEBAUTHN.finish_passkey_registration(&data.device_response.into(), &state)?;
let mut registrations: Vec<_> = get_webauthn_registrations(&user.uuid, &mut conn).await?.1; let mut registrations: Vec<_> = get_webauthn_registrations(&user.uuid, &conn).await?.1;
// TODO: Check for repeated ID's // TODO: Check for repeated ID's
registrations.push(WebauthnRegistration { registrations.push(WebauthnRegistration {
id: data.id.into_i32()?, id: data.id.into_i32()?,
@ -289,11 +289,11 @@ async fn activate_webauthn(data: Json<EnableWebauthnData>, headers: Headers, mut
// Save the registrations and return them // Save the registrations and return them
TwoFactor::new(user.uuid.clone(), TwoFactorType::Webauthn, serde_json::to_string(&registrations)?) TwoFactor::new(user.uuid.clone(), TwoFactorType::Webauthn, serde_json::to_string(&registrations)?)
.save(&mut conn) .save(&conn)
.await?; .await?;
_generate_recover_code(&mut user, &mut conn).await; _generate_recover_code(&mut user, &conn).await;
log_user_event(EventType::UserUpdated2fa as i32, &user.uuid, headers.device.atype, &headers.ip.ip, &mut conn).await; log_user_event(EventType::UserUpdated2fa as i32, &user.uuid, headers.device.atype, &headers.ip.ip, &conn).await;
let keys_json: Vec<Value> = registrations.iter().map(WebauthnRegistration::to_json).collect(); let keys_json: Vec<Value> = registrations.iter().map(WebauthnRegistration::to_json).collect();
Ok(Json(json!({ Ok(Json(json!({
@ -316,14 +316,14 @@ struct DeleteU2FData {
} }
#[delete("/two-factor/webauthn", data = "<data>")] #[delete("/two-factor/webauthn", data = "<data>")]
async fn delete_webauthn(data: Json<DeleteU2FData>, headers: Headers, mut conn: DbConn) -> JsonResult { async fn delete_webauthn(data: Json<DeleteU2FData>, headers: Headers, conn: DbConn) -> JsonResult {
let id = data.id.into_i32()?; let id = data.id.into_i32()?;
if !headers.user.check_valid_password(&data.master_password_hash) { if !headers.user.check_valid_password(&data.master_password_hash) {
err!("Invalid password"); err!("Invalid password");
} }
let Some(mut tf) = let Some(mut tf) =
TwoFactor::find_by_user_and_type(&headers.user.uuid, TwoFactorType::Webauthn as i32, &mut conn).await TwoFactor::find_by_user_and_type(&headers.user.uuid, TwoFactorType::Webauthn as i32, &conn).await
else { else {
err!("Webauthn data not found!") err!("Webauthn data not found!")
}; };
@ -336,12 +336,11 @@ async fn delete_webauthn(data: Json<DeleteU2FData>, headers: Headers, mut conn:
let removed_item = data.remove(item_pos); let removed_item = data.remove(item_pos);
tf.data = serde_json::to_string(&data)?; tf.data = serde_json::to_string(&data)?;
tf.save(&mut conn).await?; tf.save(&conn).await?;
drop(tf); drop(tf);
// If entry is migrated from u2f, delete the u2f entry as well // If entry is migrated from u2f, delete the u2f entry as well
if let Some(mut u2f) = if let Some(mut u2f) = TwoFactor::find_by_user_and_type(&headers.user.uuid, TwoFactorType::U2f as i32, &conn).await
TwoFactor::find_by_user_and_type(&headers.user.uuid, TwoFactorType::U2f as i32, &mut conn).await
{ {
let mut data: Vec<U2FRegistration> = match serde_json::from_str(&u2f.data) { let mut data: Vec<U2FRegistration> = match serde_json::from_str(&u2f.data) {
Ok(d) => d, Ok(d) => d,
@ -352,7 +351,7 @@ async fn delete_webauthn(data: Json<DeleteU2FData>, headers: Headers, mut conn:
let new_data_str = serde_json::to_string(&data)?; let new_data_str = serde_json::to_string(&data)?;
u2f.data = new_data_str; u2f.data = new_data_str;
u2f.save(&mut conn).await?; u2f.save(&conn).await?;
} }
let keys_json: Vec<Value> = data.iter().map(WebauthnRegistration::to_json).collect(); let keys_json: Vec<Value> = data.iter().map(WebauthnRegistration::to_json).collect();
@ -366,7 +365,7 @@ async fn delete_webauthn(data: Json<DeleteU2FData>, headers: Headers, mut conn:
pub async fn get_webauthn_registrations( pub async fn get_webauthn_registrations(
user_id: &UserId, user_id: &UserId,
conn: &mut DbConn, conn: &DbConn,
) -> Result<(bool, Vec<WebauthnRegistration>), Error> { ) -> Result<(bool, Vec<WebauthnRegistration>), Error> {
let type_ = TwoFactorType::Webauthn as i32; let type_ = TwoFactorType::Webauthn as i32;
match TwoFactor::find_by_user_and_type(user_id, type_, conn).await { match TwoFactor::find_by_user_and_type(user_id, type_, conn).await {
@ -375,7 +374,7 @@ pub async fn get_webauthn_registrations(
} }
} }
pub async fn generate_webauthn_login(user_id: &UserId, conn: &mut DbConn) -> JsonResult { pub async fn generate_webauthn_login(user_id: &UserId, conn: &DbConn) -> JsonResult {
// Load saved credentials // Load saved credentials
let creds: Vec<Passkey> = let creds: Vec<Passkey> =
get_webauthn_registrations(user_id, conn).await?.1.into_iter().map(|r| r.credential).collect(); get_webauthn_registrations(user_id, conn).await?.1.into_iter().map(|r| r.credential).collect();
@ -415,7 +414,7 @@ pub async fn generate_webauthn_login(user_id: &UserId, conn: &mut DbConn) -> Jso
Ok(Json(serde_json::to_value(response.public_key)?)) Ok(Json(serde_json::to_value(response.public_key)?))
} }
pub async fn validate_webauthn_login(user_id: &UserId, response: &str, conn: &mut DbConn) -> EmptyResult { pub async fn validate_webauthn_login(user_id: &UserId, response: &str, conn: &DbConn) -> EmptyResult {
let type_ = TwoFactorType::WebauthnLoginChallenge as i32; let type_ = TwoFactorType::WebauthnLoginChallenge as i32;
let mut state = match TwoFactor::find_by_user_and_type(user_id, type_, conn).await { let mut state = match TwoFactor::find_by_user_and_type(user_id, type_, conn).await {
Some(tf) => { Some(tf) => {
@ -469,7 +468,7 @@ async fn check_and_update_backup_eligible(
rsp: &PublicKeyCredential, rsp: &PublicKeyCredential,
registrations: &mut Vec<WebauthnRegistration>, registrations: &mut Vec<WebauthnRegistration>,
state: &mut PasskeyAuthentication, state: &mut PasskeyAuthentication,
conn: &mut DbConn, conn: &DbConn,
) -> EmptyResult { ) -> EmptyResult {
// The feature flags from the response // The feature flags from the response
// For details see: https://www.w3.org/TR/webauthn-3/#sctn-authenticator-data // For details see: https://www.w3.org/TR/webauthn-3/#sctn-authenticator-data

18
src/api/core/two_factor/yubikey.rs

@ -83,19 +83,19 @@ async fn verify_yubikey_otp(otp: String) -> EmptyResult {
} }
#[post("/two-factor/get-yubikey", data = "<data>")] #[post("/two-factor/get-yubikey", data = "<data>")]
async fn generate_yubikey(data: Json<PasswordOrOtpData>, headers: Headers, mut conn: DbConn) -> JsonResult { async fn generate_yubikey(data: Json<PasswordOrOtpData>, headers: Headers, conn: DbConn) -> JsonResult {
// Make sure the credentials are set // Make sure the credentials are set
get_yubico_credentials()?; get_yubico_credentials()?;
let data: PasswordOrOtpData = data.into_inner(); let data: PasswordOrOtpData = data.into_inner();
let user = headers.user; let user = headers.user;
data.validate(&user, false, &mut conn).await?; data.validate(&user, false, &conn).await?;
let user_id = &user.uuid; let user_id = &user.uuid;
let yubikey_type = TwoFactorType::YubiKey as i32; let yubikey_type = TwoFactorType::YubiKey as i32;
let r = TwoFactor::find_by_user_and_type(user_id, yubikey_type, &mut conn).await; let r = TwoFactor::find_by_user_and_type(user_id, yubikey_type, &conn).await;
if let Some(r) = r { if let Some(r) = r {
let yubikey_metadata: YubikeyMetadata = serde_json::from_str(&r.data)?; let yubikey_metadata: YubikeyMetadata = serde_json::from_str(&r.data)?;
@ -116,7 +116,7 @@ async fn generate_yubikey(data: Json<PasswordOrOtpData>, headers: Headers, mut c
} }
#[post("/two-factor/yubikey", data = "<data>")] #[post("/two-factor/yubikey", data = "<data>")]
async fn activate_yubikey(data: Json<EnableYubikeyData>, headers: Headers, mut conn: DbConn) -> JsonResult { async fn activate_yubikey(data: Json<EnableYubikeyData>, headers: Headers, conn: DbConn) -> JsonResult {
let data: EnableYubikeyData = data.into_inner(); let data: EnableYubikeyData = data.into_inner();
let mut user = headers.user; let mut user = headers.user;
@ -124,12 +124,12 @@ async fn activate_yubikey(data: Json<EnableYubikeyData>, headers: Headers, mut c
master_password_hash: data.master_password_hash.clone(), master_password_hash: data.master_password_hash.clone(),
otp: data.otp.clone(), otp: data.otp.clone(),
} }
.validate(&user, true, &mut conn) .validate(&user, true, &conn)
.await?; .await?;
// Check if we already have some data // Check if we already have some data
let mut yubikey_data = let mut yubikey_data =
match TwoFactor::find_by_user_and_type(&user.uuid, TwoFactorType::YubiKey as i32, &mut conn).await { match TwoFactor::find_by_user_and_type(&user.uuid, TwoFactorType::YubiKey as i32, &conn).await {
Some(data) => data, Some(data) => data,
None => TwoFactor::new(user.uuid.clone(), TwoFactorType::YubiKey, String::new()), None => TwoFactor::new(user.uuid.clone(), TwoFactorType::YubiKey, String::new()),
}; };
@ -160,11 +160,11 @@ async fn activate_yubikey(data: Json<EnableYubikeyData>, headers: Headers, mut c
}; };
yubikey_data.data = serde_json::to_string(&yubikey_metadata).unwrap(); yubikey_data.data = serde_json::to_string(&yubikey_metadata).unwrap();
yubikey_data.save(&mut conn).await?; yubikey_data.save(&conn).await?;
_generate_recover_code(&mut user, &mut conn).await; _generate_recover_code(&mut user, &conn).await;
log_user_event(EventType::UserUpdated2fa as i32, &user.uuid, headers.device.atype, &headers.ip.ip, &mut conn).await; log_user_event(EventType::UserUpdated2fa as i32, &user.uuid, headers.device.atype, &headers.ip.ip, &conn).await;
let mut result = jsonify_yubikeys(yubikey_metadata.keys); let mut result = jsonify_yubikeys(yubikey_metadata.keys);

59
src/api/identity.rs

@ -22,7 +22,13 @@ use crate::{
}, },
auth, auth,
auth::{generate_organization_api_key_login_claims, AuthMethod, ClientHeaders, ClientIp, ClientVersion}, auth::{generate_organization_api_key_login_claims, AuthMethod, ClientHeaders, ClientIp, ClientVersion},
db::{models::*, DbConn}, db::{
models::{
AuthRequest, AuthRequestId, Device, DeviceId, EventType, Invitation, OrganizationApiKey, OrganizationId,
SsoNonce, SsoUser, TwoFactor, TwoFactorIncomplete, TwoFactorType, User, UserId,
},
DbConn,
},
error::MapResult, error::MapResult,
mail, sso, mail, sso,
sso::{OIDCCode, OIDCState}, sso::{OIDCCode, OIDCState},
@ -48,7 +54,7 @@ async fn login(
data: Form<ConnectData>, data: Form<ConnectData>,
client_header: ClientHeaders, client_header: ClientHeaders,
client_version: Option<ClientVersion>, client_version: Option<ClientVersion>,
mut conn: DbConn, conn: DbConn,
) -> JsonResult { ) -> JsonResult {
let data: ConnectData = data.into_inner(); let data: ConnectData = data.into_inner();
@ -57,7 +63,7 @@ async fn login(
let login_result = match data.grant_type.as_ref() { let login_result = match data.grant_type.as_ref() {
"refresh_token" => { "refresh_token" => {
_check_is_some(&data.refresh_token, "refresh_token cannot be blank")?; _check_is_some(&data.refresh_token, "refresh_token cannot be blank")?;
_refresh_login(data, &mut conn, &client_header.ip).await _refresh_login(data, &conn, &client_header.ip).await
} }
"password" if CONFIG.sso_enabled() && CONFIG.sso_only() => err!("SSO sign-in is required"), "password" if CONFIG.sso_enabled() && CONFIG.sso_only() => err!("SSO sign-in is required"),
"password" => { "password" => {
@ -70,7 +76,7 @@ async fn login(
_check_is_some(&data.device_name, "device_name cannot be blank")?; _check_is_some(&data.device_name, "device_name cannot be blank")?;
_check_is_some(&data.device_type, "device_type cannot be blank")?; _check_is_some(&data.device_type, "device_type cannot be blank")?;
_password_login(data, &mut user_id, &mut conn, &client_header.ip, &client_version).await _password_login(data, &mut user_id, &conn, &client_header.ip, &client_version).await
} }
"client_credentials" => { "client_credentials" => {
_check_is_some(&data.client_id, "client_id cannot be blank")?; _check_is_some(&data.client_id, "client_id cannot be blank")?;
@ -81,7 +87,7 @@ async fn login(
_check_is_some(&data.device_name, "device_name cannot be blank")?; _check_is_some(&data.device_name, "device_name cannot be blank")?;
_check_is_some(&data.device_type, "device_type cannot be blank")?; _check_is_some(&data.device_type, "device_type cannot be blank")?;
_api_key_login(data, &mut user_id, &mut conn, &client_header.ip).await _api_key_login(data, &mut user_id, &conn, &client_header.ip).await
} }
"authorization_code" if CONFIG.sso_enabled() => { "authorization_code" if CONFIG.sso_enabled() => {
_check_is_some(&data.client_id, "client_id cannot be blank")?; _check_is_some(&data.client_id, "client_id cannot be blank")?;
@ -91,7 +97,7 @@ async fn login(
_check_is_some(&data.device_name, "device_name cannot be blank")?; _check_is_some(&data.device_name, "device_name cannot be blank")?;
_check_is_some(&data.device_type, "device_type cannot be blank")?; _check_is_some(&data.device_type, "device_type cannot be blank")?;
_sso_login(data, &mut user_id, &mut conn, &client_header.ip, &client_version).await _sso_login(data, &mut user_id, &conn, &client_header.ip, &client_version).await
} }
"authorization_code" => err!("SSO sign-in is not available"), "authorization_code" => err!("SSO sign-in is not available"),
t => err!("Invalid type", t), t => err!("Invalid type", t),
@ -105,19 +111,13 @@ async fn login(
&user_id, &user_id,
client_header.device_type, client_header.device_type,
&client_header.ip.ip, &client_header.ip.ip,
&mut conn, &conn,
) )
.await; .await;
} }
Err(e) => { Err(e) => {
if let Some(ev) = e.get_event() { if let Some(ev) = e.get_event() {
log_user_event( log_user_event(ev.event as i32, &user_id, client_header.device_type, &client_header.ip.ip, &conn)
ev.event as i32,
&user_id,
client_header.device_type,
&client_header.ip.ip,
&mut conn,
)
.await .await
} }
} }
@ -128,7 +128,7 @@ async fn login(
} }
// Return Status::Unauthorized to trigger logout // Return Status::Unauthorized to trigger logout
async fn _refresh_login(data: ConnectData, conn: &mut DbConn, ip: &ClientIp) -> JsonResult { async fn _refresh_login(data: ConnectData, conn: &DbConn, ip: &ClientIp) -> JsonResult {
// Extract token // Extract token
let refresh_token = match data.refresh_token { let refresh_token = match data.refresh_token {
Some(token) => token, Some(token) => token,
@ -166,7 +166,7 @@ async fn _refresh_login(data: ConnectData, conn: &mut DbConn, ip: &ClientIp) ->
async fn _sso_login( async fn _sso_login(
data: ConnectData, data: ConnectData,
user_id: &mut Option<UserId>, user_id: &mut Option<UserId>,
conn: &mut DbConn, conn: &DbConn,
ip: &ClientIp, ip: &ClientIp,
client_version: &Option<ClientVersion>, client_version: &Option<ClientVersion>,
) -> JsonResult { ) -> JsonResult {
@ -319,7 +319,7 @@ async fn _sso_login(
async fn _password_login( async fn _password_login(
data: ConnectData, data: ConnectData,
user_id: &mut Option<UserId>, user_id: &mut Option<UserId>,
conn: &mut DbConn, conn: &DbConn,
ip: &ClientIp, ip: &ClientIp,
client_version: &Option<ClientVersion>, client_version: &Option<ClientVersion>,
) -> JsonResult { ) -> JsonResult {
@ -444,7 +444,7 @@ async fn authenticated_response(
auth_tokens: auth::AuthTokens, auth_tokens: auth::AuthTokens,
twofactor_token: Option<String>, twofactor_token: Option<String>,
now: &NaiveDateTime, now: &NaiveDateTime,
conn: &mut DbConn, conn: &DbConn,
ip: &ClientIp, ip: &ClientIp,
) -> JsonResult { ) -> JsonResult {
if CONFIG.mail_enabled() && device.is_new() { if CONFIG.mail_enabled() && device.is_new() {
@ -504,12 +504,7 @@ async fn authenticated_response(
Ok(Json(result)) Ok(Json(result))
} }
async fn _api_key_login( async fn _api_key_login(data: ConnectData, user_id: &mut Option<UserId>, conn: &DbConn, ip: &ClientIp) -> JsonResult {
data: ConnectData,
user_id: &mut Option<UserId>,
conn: &mut DbConn,
ip: &ClientIp,
) -> JsonResult {
// Ratelimit the login // Ratelimit the login
crate::ratelimit::check_limit_login(&ip.ip)?; crate::ratelimit::check_limit_login(&ip.ip)?;
@ -524,7 +519,7 @@ async fn _api_key_login(
async fn _user_api_key_login( async fn _user_api_key_login(
data: ConnectData, data: ConnectData,
user_id: &mut Option<UserId>, user_id: &mut Option<UserId>,
conn: &mut DbConn, conn: &DbConn,
ip: &ClientIp, ip: &ClientIp,
) -> JsonResult { ) -> JsonResult {
// Get the user via the client_id // Get the user via the client_id
@ -614,7 +609,7 @@ async fn _user_api_key_login(
Ok(Json(result)) Ok(Json(result))
} }
async fn _organization_api_key_login(data: ConnectData, conn: &mut DbConn, ip: &ClientIp) -> JsonResult { async fn _organization_api_key_login(data: ConnectData, conn: &DbConn, ip: &ClientIp) -> JsonResult {
// Get the org via the client_id // Get the org via the client_id
let client_id = data.client_id.as_ref().unwrap(); let client_id = data.client_id.as_ref().unwrap();
let Some(org_id) = client_id.strip_prefix("organization.") else { let Some(org_id) = client_id.strip_prefix("organization.") else {
@ -643,7 +638,7 @@ async fn _organization_api_key_login(data: ConnectData, conn: &mut DbConn, ip: &
} }
/// Retrieves an existing device or creates a new device from ConnectData and the User /// Retrieves an existing device or creates a new device from ConnectData and the User
async fn get_device(data: &ConnectData, conn: &mut DbConn, user: &User) -> ApiResult<Device> { async fn get_device(data: &ConnectData, conn: &DbConn, user: &User) -> ApiResult<Device> {
// On iOS, device_type sends "iOS", on others it sends a number // On iOS, device_type sends "iOS", on others it sends a number
// When unknown or unable to parse, return 14, which is 'Unknown Browser' // When unknown or unable to parse, return 14, which is 'Unknown Browser'
let device_type = util::try_parse_string(data.device_type.as_ref()).unwrap_or(14); let device_type = util::try_parse_string(data.device_type.as_ref()).unwrap_or(14);
@ -663,7 +658,7 @@ async fn twofactor_auth(
device: &mut Device, device: &mut Device,
ip: &ClientIp, ip: &ClientIp,
client_version: &Option<ClientVersion>, client_version: &Option<ClientVersion>,
conn: &mut DbConn, conn: &DbConn,
) -> ApiResult<Option<String>> { ) -> ApiResult<Option<String>> {
let twofactors = TwoFactor::find_by_user(&user.uuid, conn).await; let twofactors = TwoFactor::find_by_user(&user.uuid, conn).await;
@ -780,7 +775,7 @@ async fn _json_err_twofactor(
user_id: &UserId, user_id: &UserId,
data: &ConnectData, data: &ConnectData,
client_version: &Option<ClientVersion>, client_version: &Option<ClientVersion>,
conn: &mut DbConn, conn: &DbConn,
) -> ApiResult<Value> { ) -> ApiResult<Value> {
let mut result = json!({ let mut result = json!({
"error" : "invalid_grant", "error" : "invalid_grant",
@ -905,13 +900,13 @@ enum RegisterVerificationResponse {
#[post("/accounts/register/send-verification-email", data = "<data>")] #[post("/accounts/register/send-verification-email", data = "<data>")]
async fn register_verification_email( async fn register_verification_email(
data: Json<RegisterVerificationData>, data: Json<RegisterVerificationData>,
mut conn: DbConn, conn: DbConn,
) -> ApiResult<RegisterVerificationResponse> { ) -> ApiResult<RegisterVerificationResponse> {
let data = data.into_inner(); let data = data.into_inner();
// the registration can only continue if signup is allowed or there exists an invitation // the registration can only continue if signup is allowed or there exists an invitation
if !(CONFIG.is_signup_allowed(&data.email) if !(CONFIG.is_signup_allowed(&data.email)
|| (!CONFIG.mail_enabled() && Invitation::find_by_mail(&data.email, &mut conn).await.is_some())) || (!CONFIG.mail_enabled() && Invitation::find_by_mail(&data.email, &conn).await.is_some()))
{ {
err!("Registration not allowed or user already exists") err!("Registration not allowed or user already exists")
} }
@ -922,7 +917,7 @@ async fn register_verification_email(
let token = auth::encode_jwt(&token_claims); let token = auth::encode_jwt(&token_claims);
if should_send_mail { if should_send_mail {
let user = User::find_by_mail(&data.email, &mut conn).await; let user = User::find_by_mail(&data.email, &conn).await;
if user.filter(|u| u.private_key.is_some()).is_some() { if user.filter(|u| u.private_key.is_some()).is_some() {
// There is still a timing side channel here in that the code // There is still a timing side channel here in that the code
// paths that send mail take noticeably longer than ones that // paths that send mail take noticeably longer than ones that

2
src/api/mod.rs

@ -55,7 +55,7 @@ impl PasswordOrOtpData {
/// Tokens used via this struct can be used multiple times during the process /// Tokens used via this struct can be used multiple times during the process
/// First for the validation to continue, after that to enable or validate the following actions /// First for the validation to continue, after that to enable or validate the following actions
/// This is different per caller, so it can be adjusted to delete the token or not /// This is different per caller, so it can be adjusted to delete the token or not
pub async fn validate(&self, user: &User, delete_if_valid: bool, conn: &mut DbConn) -> EmptyResult { pub async fn validate(&self, user: &User, delete_if_valid: bool, conn: &DbConn) -> EmptyResult {
use crate::api::core::two_factor::protected_actions::validate_protected_action_otp; use crate::api::core::two_factor::protected_actions::validate_protected_action_otp;
match (self.master_password_hash.as_deref(), self.otp.as_deref()) { match (self.master_password_hash.as_deref(), self.otp.as_deref()) {

20
src/api/notifications.rs

@ -339,7 +339,7 @@ impl WebSocketUsers {
} }
// NOTE: The last modified date needs to be updated before calling these methods // NOTE: The last modified date needs to be updated before calling these methods
pub async fn send_user_update(&self, ut: UpdateType, user: &User, push_uuid: &Option<PushId>, conn: &mut DbConn) { pub async fn send_user_update(&self, ut: UpdateType, user: &User, push_uuid: &Option<PushId>, conn: &DbConn) {
// Skip any processing if both WebSockets and Push are not active // Skip any processing if both WebSockets and Push are not active
if *NOTIFICATIONS_DISABLED { if *NOTIFICATIONS_DISABLED {
return; return;
@ -359,7 +359,7 @@ impl WebSocketUsers {
} }
} }
pub async fn send_logout(&self, user: &User, acting_device_id: Option<DeviceId>, conn: &mut DbConn) { pub async fn send_logout(&self, user: &User, acting_device_id: Option<DeviceId>, conn: &DbConn) {
// Skip any processing if both WebSockets and Push are not active // Skip any processing if both WebSockets and Push are not active
if *NOTIFICATIONS_DISABLED { if *NOTIFICATIONS_DISABLED {
return; return;
@ -379,7 +379,7 @@ impl WebSocketUsers {
} }
} }
pub async fn send_folder_update(&self, ut: UpdateType, folder: &Folder, device: &Device, conn: &mut DbConn) { pub async fn send_folder_update(&self, ut: UpdateType, folder: &Folder, device: &Device, conn: &DbConn) {
// Skip any processing if both WebSockets and Push are not active // Skip any processing if both WebSockets and Push are not active
if *NOTIFICATIONS_DISABLED { if *NOTIFICATIONS_DISABLED {
return; return;
@ -410,7 +410,7 @@ impl WebSocketUsers {
user_ids: &[UserId], user_ids: &[UserId],
device: &Device, device: &Device,
collection_uuids: Option<Vec<CollectionId>>, collection_uuids: Option<Vec<CollectionId>>,
conn: &mut DbConn, conn: &DbConn,
) { ) {
// Skip any processing if both WebSockets and Push are not active // Skip any processing if both WebSockets and Push are not active
if *NOTIFICATIONS_DISABLED { if *NOTIFICATIONS_DISABLED {
@ -458,7 +458,7 @@ impl WebSocketUsers {
send: &DbSend, send: &DbSend,
user_ids: &[UserId], user_ids: &[UserId],
device: &Device, device: &Device,
conn: &mut DbConn, conn: &DbConn,
) { ) {
// Skip any processing if both WebSockets and Push are not active // Skip any processing if both WebSockets and Push are not active
if *NOTIFICATIONS_DISABLED { if *NOTIFICATIONS_DISABLED {
@ -486,13 +486,7 @@ impl WebSocketUsers {
} }
} }
pub async fn send_auth_request( pub async fn send_auth_request(&self, user_id: &UserId, auth_request_uuid: &str, device: &Device, conn: &DbConn) {
&self,
user_id: &UserId,
auth_request_uuid: &str,
device: &Device,
conn: &mut DbConn,
) {
// Skip any processing if both WebSockets and Push are not active // Skip any processing if both WebSockets and Push are not active
if *NOTIFICATIONS_DISABLED { if *NOTIFICATIONS_DISABLED {
return; return;
@ -516,7 +510,7 @@ impl WebSocketUsers {
user_id: &UserId, user_id: &UserId,
auth_request_id: &AuthRequestId, auth_request_id: &AuthRequestId,
device: &Device, device: &Device,
conn: &mut DbConn, conn: &DbConn,
) { ) {
// Skip any processing if both WebSockets and Push are not active // Skip any processing if both WebSockets and Push are not active
if *NOTIFICATIONS_DISABLED { if *NOTIFICATIONS_DISABLED {

26
src/api/push.rs

@ -7,7 +7,10 @@ use tokio::sync::RwLock;
use crate::{ use crate::{
api::{ApiResult, EmptyResult, UpdateType}, api::{ApiResult, EmptyResult, UpdateType},
db::models::{AuthRequestId, Cipher, Device, DeviceId, Folder, PushId, Send, User, UserId}, db::{
models::{AuthRequestId, Cipher, Device, DeviceId, Folder, PushId, Send, User, UserId},
DbConn,
},
http_client::make_http_request, http_client::make_http_request,
util::{format_date, get_uuid}, util::{format_date, get_uuid},
CONFIG, CONFIG,
@ -79,7 +82,7 @@ async fn get_auth_api_token() -> ApiResult<String> {
Ok(api_token.access_token.clone()) Ok(api_token.access_token.clone())
} }
pub async fn register_push_device(device: &mut Device, conn: &mut crate::db::DbConn) -> EmptyResult { pub async fn register_push_device(device: &mut Device, conn: &DbConn) -> EmptyResult {
if !CONFIG.push_enabled() || !device.is_push_device() { if !CONFIG.push_enabled() || !device.is_push_device() {
return Ok(()); return Ok(());
} }
@ -152,7 +155,7 @@ pub async fn unregister_push_device(push_id: &Option<PushId>) -> EmptyResult {
Ok(()) Ok(())
} }
pub async fn push_cipher_update(ut: UpdateType, cipher: &Cipher, device: &Device, conn: &mut crate::db::DbConn) { pub async fn push_cipher_update(ut: UpdateType, cipher: &Cipher, device: &Device, conn: &DbConn) {
// We shouldn't send a push notification on cipher update if the cipher belongs to an organization, this isn't implemented in the upstream server too. // We shouldn't send a push notification on cipher update if the cipher belongs to an organization, this isn't implemented in the upstream server too.
if cipher.organization_uuid.is_some() { if cipher.organization_uuid.is_some() {
return; return;
@ -183,7 +186,7 @@ pub async fn push_cipher_update(ut: UpdateType, cipher: &Cipher, device: &Device
} }
} }
pub async fn push_logout(user: &User, acting_device_id: Option<DeviceId>, conn: &mut crate::db::DbConn) { pub async fn push_logout(user: &User, acting_device_id: Option<DeviceId>, conn: &DbConn) {
let acting_device_id: Value = acting_device_id.map(|v| v.to_string().into()).unwrap_or_else(|| Value::Null); let acting_device_id: Value = acting_device_id.map(|v| v.to_string().into()).unwrap_or_else(|| Value::Null);
if Device::check_user_has_push_device(&user.uuid, conn).await { if Device::check_user_has_push_device(&user.uuid, conn).await {
@ -203,7 +206,7 @@ pub async fn push_logout(user: &User, acting_device_id: Option<DeviceId>, conn:
} }
} }
pub async fn push_user_update(ut: UpdateType, user: &User, push_uuid: &Option<PushId>, conn: &mut crate::db::DbConn) { pub async fn push_user_update(ut: UpdateType, user: &User, push_uuid: &Option<PushId>, conn: &DbConn) {
if Device::check_user_has_push_device(&user.uuid, conn).await { if Device::check_user_has_push_device(&user.uuid, conn).await {
tokio::task::spawn(send_to_push_relay(json!({ tokio::task::spawn(send_to_push_relay(json!({
"userId": user.uuid, "userId": user.uuid,
@ -221,7 +224,7 @@ pub async fn push_user_update(ut: UpdateType, user: &User, push_uuid: &Option<Pu
} }
} }
pub async fn push_folder_update(ut: UpdateType, folder: &Folder, device: &Device, conn: &mut crate::db::DbConn) { pub async fn push_folder_update(ut: UpdateType, folder: &Folder, device: &Device, conn: &DbConn) {
if Device::check_user_has_push_device(&folder.user_uuid, conn).await { if Device::check_user_has_push_device(&folder.user_uuid, conn).await {
tokio::task::spawn(send_to_push_relay(json!({ tokio::task::spawn(send_to_push_relay(json!({
"userId": folder.user_uuid, "userId": folder.user_uuid,
@ -240,7 +243,7 @@ pub async fn push_folder_update(ut: UpdateType, folder: &Folder, device: &Device
} }
} }
pub async fn push_send_update(ut: UpdateType, send: &Send, device: &Device, conn: &mut crate::db::DbConn) { pub async fn push_send_update(ut: UpdateType, send: &Send, device: &Device, conn: &DbConn) {
if let Some(s) = &send.user_uuid { if let Some(s) = &send.user_uuid {
if Device::check_user_has_push_device(s, conn).await { if Device::check_user_has_push_device(s, conn).await {
tokio::task::spawn(send_to_push_relay(json!({ tokio::task::spawn(send_to_push_relay(json!({
@ -296,7 +299,7 @@ async fn send_to_push_relay(notification_data: Value) {
}; };
} }
pub async fn push_auth_request(user_id: &UserId, auth_request_id: &str, device: &Device, conn: &mut crate::db::DbConn) { pub async fn push_auth_request(user_id: &UserId, auth_request_id: &str, device: &Device, conn: &DbConn) {
if Device::check_user_has_push_device(user_id, conn).await { if Device::check_user_has_push_device(user_id, conn).await {
tokio::task::spawn(send_to_push_relay(json!({ tokio::task::spawn(send_to_push_relay(json!({
"userId": user_id, "userId": user_id,
@ -314,12 +317,7 @@ pub async fn push_auth_request(user_id: &UserId, auth_request_id: &str, device:
} }
} }
pub async fn push_auth_response( pub async fn push_auth_response(user_id: &UserId, auth_request_id: &AuthRequestId, device: &Device, conn: &DbConn) {
user_id: &UserId,
auth_request_id: &AuthRequestId,
device: &Device,
conn: &mut crate::db::DbConn,
) {
if Device::check_user_has_push_device(user_id, conn).await { if Device::check_user_has_push_device(user_id, conn).await {
tokio::task::spawn(send_to_push_relay(json!({ tokio::task::spawn(send_to_push_relay(json!({
"userId": user_id, "userId": user_id,

20
src/auth.rs

@ -604,16 +604,16 @@ impl<'r> FromRequest<'r> for Headers {
let device_id = claims.device; let device_id = claims.device;
let user_id = claims.sub; let user_id = claims.sub;
let mut conn = match DbConn::from_request(request).await { let conn = match DbConn::from_request(request).await {
Outcome::Success(conn) => conn, Outcome::Success(conn) => conn,
_ => err_handler!("Error getting DB"), _ => err_handler!("Error getting DB"),
}; };
let Some(device) = Device::find_by_uuid_and_user(&device_id, &user_id, &mut conn).await else { let Some(device) = Device::find_by_uuid_and_user(&device_id, &user_id, &conn).await else {
err_handler!("Invalid device id") err_handler!("Invalid device id")
}; };
let Some(user) = User::find_by_uuid(&user_id, &mut conn).await else { let Some(user) = User::find_by_uuid(&user_id, &conn).await else {
err_handler!("Device has no user associated") err_handler!("Device has no user associated")
}; };
@ -633,7 +633,7 @@ impl<'r> FromRequest<'r> for Headers {
// This prevents checking this stamp exception for new requests. // This prevents checking this stamp exception for new requests.
let mut user = user; let mut user = user;
user.reset_stamp_exception(); user.reset_stamp_exception();
if let Err(e) = user.save(&mut conn).await { if let Err(e) = user.save(&conn).await {
error!("Error updating user: {e:#?}"); error!("Error updating user: {e:#?}");
} }
err_handler!("Stamp exception is expired") err_handler!("Stamp exception is expired")
@ -706,13 +706,13 @@ impl<'r> FromRequest<'r> for OrgHeaders {
match url_org_id { match url_org_id {
Some(org_id) if uuid::Uuid::parse_str(&org_id).is_ok() => { Some(org_id) if uuid::Uuid::parse_str(&org_id).is_ok() => {
let mut conn = match DbConn::from_request(request).await { let conn = match DbConn::from_request(request).await {
Outcome::Success(conn) => conn, Outcome::Success(conn) => conn,
_ => err_handler!("Error getting DB"), _ => err_handler!("Error getting DB"),
}; };
let user = headers.user; let user = headers.user;
let Some(membership) = Membership::find_by_user_and_org(&user.uuid, &org_id, &mut conn).await else { let Some(membership) = Membership::find_by_user_and_org(&user.uuid, &org_id, &conn).await else {
err_handler!("The current user isn't member of the organization"); err_handler!("The current user isn't member of the organization");
}; };
@ -815,12 +815,12 @@ impl<'r> FromRequest<'r> for ManagerHeaders {
if headers.is_confirmed_and_manager() { if headers.is_confirmed_and_manager() {
match get_col_id(request) { match get_col_id(request) {
Some(col_id) => { Some(col_id) => {
let mut conn = match DbConn::from_request(request).await { let conn = match DbConn::from_request(request).await {
Outcome::Success(conn) => conn, Outcome::Success(conn) => conn,
_ => err_handler!("Error getting DB"), _ => err_handler!("Error getting DB"),
}; };
if !Collection::can_access_collection(&headers.membership, &col_id, &mut conn).await { if !Collection::can_access_collection(&headers.membership, &col_id, &conn).await {
err_handler!("The current user isn't a manager for this collection") err_handler!("The current user isn't a manager for this collection")
} }
} }
@ -896,7 +896,7 @@ impl ManagerHeaders {
pub async fn from_loose( pub async fn from_loose(
h: ManagerHeadersLoose, h: ManagerHeadersLoose,
collections: &Vec<CollectionId>, collections: &Vec<CollectionId>,
conn: &mut DbConn, conn: &DbConn,
) -> Result<ManagerHeaders, Error> { ) -> Result<ManagerHeaders, Error> {
for col_id in collections { for col_id in collections {
if uuid::Uuid::parse_str(col_id.as_ref()).is_err() { if uuid::Uuid::parse_str(col_id.as_ref()).is_err() {
@ -1200,7 +1200,7 @@ pub async fn refresh_tokens(
ip: &ClientIp, ip: &ClientIp,
refresh_token: &str, refresh_token: &str,
client_id: Option<String>, client_id: Option<String>,
conn: &mut DbConn, conn: &DbConn,
) -> ApiResult<(Device, AuthTokens)> { ) -> ApiResult<(Device, AuthTokens)> {
let refresh_claims = match decode_refresh(refresh_token) { let refresh_claims = match decode_refresh(refresh_token) {
Err(err) => { Err(err) => {

12
src/config.rs

@ -12,7 +12,6 @@ use once_cell::sync::Lazy;
use reqwest::Url; use reqwest::Url;
use crate::{ use crate::{
db::DbConnType,
error::Error, error::Error,
util::{get_env, get_env_bool, get_web_vault_version, is_valid_email, parse_experimental_client_feature_flags}, util::{get_env, get_env_bool, get_web_vault_version, is_valid_email, parse_experimental_client_feature_flags},
}; };
@ -815,12 +814,19 @@ make_config! {
fn validate_config(cfg: &ConfigItems) -> Result<(), Error> { fn validate_config(cfg: &ConfigItems) -> Result<(), Error> {
// Validate connection URL is valid and DB feature is enabled // Validate connection URL is valid and DB feature is enabled
#[cfg(sqlite)]
{
use crate::db::DbConnType;
let url = &cfg.database_url; let url = &cfg.database_url;
if DbConnType::from_url(url)? == DbConnType::sqlite && url.contains('/') { if DbConnType::from_url(url)? == DbConnType::Sqlite && url.contains('/') {
let path = std::path::Path::new(&url); let path = std::path::Path::new(&url);
if let Some(parent) = path.parent() { if let Some(parent) = path.parent() {
if !parent.is_dir() { if !parent.is_dir() {
err!(format!("SQLite database directory `{}` does not exist or is not a directory", parent.display())); err!(format!(
"SQLite database directory `{}` does not exist or is not a directory",
parent.display()
));
}
} }
} }
} }

379
src/db/mod.rs

@ -1,8 +1,14 @@
use std::{sync::Arc, time::Duration}; mod query_logger;
use std::{
sync::{Arc, OnceLock},
time::Duration,
};
use diesel::{ use diesel::{
connection::SimpleConnection, connection::SimpleConnection,
r2d2::{ConnectionManager, CustomizeConnection, Pool, PooledConnection}, r2d2::{ConnectionManager, CustomizeConnection, Pool, PooledConnection},
Connection, RunQueryDsl,
}; };
use rocket::{ use rocket::{
@ -21,20 +27,7 @@ use crate::{
CONFIG, CONFIG,
}; };
#[cfg(sqlite)]
#[path = "schemas/sqlite/schema.rs"]
pub mod __sqlite_schema;
#[cfg(mysql)]
#[path = "schemas/mysql/schema.rs"]
pub mod __mysql_schema;
#[cfg(postgresql)]
#[path = "schemas/postgresql/schema.rs"]
pub mod __postgresql_schema;
// These changes are based on Rocket 0.5-rc wrapper of Diesel: https://github.com/SergioBenitez/Rocket/blob/v0.5-rc/contrib/sync_db_pools // These changes are based on Rocket 0.5-rc wrapper of Diesel: https://github.com/SergioBenitez/Rocket/blob/v0.5-rc/contrib/sync_db_pools
// A wrapper around spawn_blocking that propagates panics to the calling code. // A wrapper around spawn_blocking that propagates panics to the calling code.
pub async fn run_blocking<F, R>(job: F) -> R pub async fn run_blocking<F, R>(job: F) -> R
where where
@ -51,47 +44,54 @@ where
} }
// This is used to generate the main DbConn and DbPool enums, which contain one variant for each database supported // This is used to generate the main DbConn and DbPool enums, which contain one variant for each database supported
macro_rules! generate_connections { #[derive(diesel::MultiConnection)]
( $( $name:ident: $ty:ty ),+ ) => { pub enum DbConnInner {
#[allow(non_camel_case_types, dead_code)] #[cfg(mysql)]
Mysql(diesel::mysql::MysqlConnection),
#[cfg(postgresql)]
Postgresql(diesel::pg::PgConnection),
#[cfg(sqlite)]
Sqlite(diesel::sqlite::SqliteConnection),
}
#[derive(Eq, PartialEq)] #[derive(Eq, PartialEq)]
pub enum DbConnType { $( $name, )+ } pub enum DbConnType {
#[cfg(mysql)]
Mysql,
#[cfg(postgresql)]
Postgresql,
#[cfg(sqlite)]
Sqlite,
}
pub static ACTIVE_DB_TYPE: OnceLock<DbConnType> = OnceLock::new();
pub struct DbConn { pub struct DbConn {
conn: Arc<Mutex<Option<DbConnInner>>>, conn: Arc<Mutex<Option<PooledConnection<ConnectionManager<DbConnInner>>>>>,
permit: Option<OwnedSemaphorePermit>, permit: Option<OwnedSemaphorePermit>,
} }
#[allow(non_camel_case_types)]
pub enum DbConnInner { $( #[cfg($name)] $name(PooledConnection<ConnectionManager< $ty >>), )+ }
#[derive(Debug)] #[derive(Debug)]
pub struct DbConnOptions { pub struct DbConnOptions {
pub init_stmts: String, pub init_stmts: String,
} }
$( // Based on <https://stackoverflow.com/a/57717533>. impl CustomizeConnection<DbConnInner, diesel::r2d2::Error> for DbConnOptions {
#[cfg($name)] fn on_acquire(&self, conn: &mut DbConnInner) -> Result<(), diesel::r2d2::Error> {
impl CustomizeConnection<$ty, diesel::r2d2::Error> for DbConnOptions {
fn on_acquire(&self, conn: &mut $ty) -> Result<(), diesel::r2d2::Error> {
if !self.init_stmts.is_empty() { if !self.init_stmts.is_empty() {
conn.batch_execute(&self.init_stmts).map_err(diesel::r2d2::Error::QueryError)?; conn.batch_execute(&self.init_stmts).map_err(diesel::r2d2::Error::QueryError)?;
} }
Ok(()) Ok(())
} }
})+ }
#[derive(Clone)] #[derive(Clone)]
pub struct DbPool { pub struct DbPool {
// This is an 'Option' so that we can drop the pool in a 'spawn_blocking'. // This is an 'Option' so that we can drop the pool in a 'spawn_blocking'.
pool: Option<DbPoolInner>, pool: Option<Pool<ConnectionManager<DbConnInner>>>,
semaphore: Arc<Semaphore> semaphore: Arc<Semaphore>,
} }
#[allow(non_camel_case_types)]
#[derive(Clone)]
pub enum DbPoolInner { $( #[cfg($name)] $name(Pool<ConnectionManager< $ty >>), )+ }
impl Drop for DbConn { impl Drop for DbConn {
fn drop(&mut self) { fn drop(&mut self) {
let conn = Arc::clone(&self.conn); let conn = Arc::clone(&self.conn);
@ -116,7 +116,11 @@ macro_rules! generate_connections {
impl Drop for DbPool { impl Drop for DbPool {
fn drop(&mut self) { fn drop(&mut self) {
let pool = self.pool.take(); let pool = self.pool.take();
tokio::task::spawn_blocking(move || drop(pool)); // Only use spawn_blocking if the Tokio runtime is still available
// Otherwise the pool will be dropped on the current thread
if let Ok(handle) = tokio::runtime::Handle::try_current() {
handle.spawn_blocking(move || drop(pool));
}
} }
} }
@ -126,32 +130,53 @@ macro_rules! generate_connections {
let url = CONFIG.database_url(); let url = CONFIG.database_url();
let conn_type = DbConnType::from_url(&url)?; let conn_type = DbConnType::from_url(&url)?;
match conn_type { $( // Only set the default instrumentation if the log level is specifically set to either warn, info or debug
DbConnType::$name => { if log_enabled!(target: "vaultwarden::db::query_logger", log::Level::Warn)
#[cfg($name)] || log_enabled!(target: "vaultwarden::db::query_logger", log::Level::Info)
|| log_enabled!(target: "vaultwarden::db::query_logger", log::Level::Debug)
{ {
pastey::paste!{ [< $name _migrations >]::run_migrations()?; } drop(diesel::connection::set_default_instrumentation(query_logger::simple_logger));
let manager = ConnectionManager::new(&url); }
match conn_type {
#[cfg(mysql)]
DbConnType::Mysql => {
mysql_migrations::run_migrations(&url)?;
}
#[cfg(postgresql)]
DbConnType::Postgresql => {
postgresql_migrations::run_migrations(&url)?;
}
#[cfg(sqlite)]
DbConnType::Sqlite => {
sqlite_migrations::run_migrations(&url)?;
}
}
let max_conns = CONFIG.database_max_conns();
let manager = ConnectionManager::<DbConnInner>::new(&url);
let pool = Pool::builder() let pool = Pool::builder()
.max_size(CONFIG.database_max_conns()) .max_size(max_conns)
.min_idle(Some(CONFIG.database_min_conns())) .min_idle(Some(CONFIG.database_min_conns()))
.idle_timeout(Some(Duration::from_secs(CONFIG.database_idle_timeout()))) .idle_timeout(Some(Duration::from_secs(CONFIG.database_idle_timeout())))
.connection_timeout(Duration::from_secs(CONFIG.database_timeout())) .connection_timeout(Duration::from_secs(CONFIG.database_timeout()))
.connection_customizer(Box::new(DbConnOptions { .connection_customizer(Box::new(DbConnOptions {
init_stmts: conn_type.get_init_stmts() init_stmts: conn_type.get_init_stmts(),
})) }))
.build(manager) .build(manager)
.map_res("Failed to create pool")?; .map_res("Failed to create pool")?;
// Set a global to determine the database more easily throughout the rest of the code
if ACTIVE_DB_TYPE.set(conn_type).is_err() {
error!("Tried to set the active database connection type more than once.")
}
Ok(DbPool { Ok(DbPool {
pool: Some(DbPoolInner::$name(pool)), pool: Some(pool),
semaphore: Arc::new(Semaphore::new(CONFIG.database_max_conns() as usize)), semaphore: Arc::new(Semaphore::new(max_conns as usize)),
}) })
} }
#[cfg(not($name))]
unreachable!("Trying to use a DB backend when it's feature is disabled")
},
)+ }
}
// Get a connection from the pool // Get a connection from the pool
pub async fn get(&self) -> Result<DbConn, Error> { pub async fn get(&self) -> Result<DbConn, Error> {
let duration = Duration::from_secs(CONFIG.database_timeout()); let duration = Duration::from_secs(CONFIG.database_timeout());
@ -162,51 +187,31 @@ macro_rules! generate_connections {
} }
}; };
match self.pool.as_ref().expect("DbPool.pool should always be Some()") { $( let p = self.pool.as_ref().expect("DbPool.pool should always be Some()");
#[cfg($name)]
DbPoolInner::$name(p) => {
let pool = p.clone(); let pool = p.clone();
let c = run_blocking(move || pool.get_timeout(duration)).await.map_res("Error retrieving connection from pool")?; let c =
run_blocking(move || pool.get_timeout(duration)).await.map_res("Error retrieving connection from pool")?;
Ok(DbConn { Ok(DbConn {
conn: Arc::new(Mutex::new(Some(DbConnInner::$name(c)))), conn: Arc::new(Mutex::new(Some(c))),
permit: Some(permit) permit: Some(permit),
}) })
},
)+ }
}
}
};
}
#[cfg(not(query_logger))]
generate_connections! {
sqlite: diesel::sqlite::SqliteConnection,
mysql: diesel::mysql::MysqlConnection,
postgresql: diesel::pg::PgConnection
} }
#[cfg(query_logger)]
generate_connections! {
sqlite: diesel_logger::LoggingConnection<diesel::sqlite::SqliteConnection>,
mysql: diesel_logger::LoggingConnection<diesel::mysql::MysqlConnection>,
postgresql: diesel_logger::LoggingConnection<diesel::pg::PgConnection>
} }
impl DbConnType { impl DbConnType {
pub fn from_url(url: &str) -> Result<DbConnType, Error> { pub fn from_url(url: &str) -> Result<Self, Error> {
// Mysql // Mysql
if url.starts_with("mysql:") { if url.len() > 6 && &url[..6] == "mysql:" {
#[cfg(mysql)] #[cfg(mysql)]
return Ok(DbConnType::mysql); return Ok(DbConnType::Mysql);
#[cfg(not(mysql))] #[cfg(not(mysql))]
err!("`DATABASE_URL` is a MySQL URL, but the 'mysql' feature is not enabled") err!("`DATABASE_URL` is a MySQL URL, but the 'mysql' feature is not enabled")
// Postgres // Postgresql
} else if url.starts_with("postgresql:") || url.starts_with("postgres:") { } else if url.len() > 11 && (&url[..11] == "postgresql:" || &url[..9] == "postgres:") {
#[cfg(postgresql)] #[cfg(postgresql)]
return Ok(DbConnType::postgresql); return Ok(DbConnType::Postgresql);
#[cfg(not(postgresql))] #[cfg(not(postgresql))]
err!("`DATABASE_URL` is a PostgreSQL URL, but the 'postgresql' feature is not enabled") err!("`DATABASE_URL` is a PostgreSQL URL, but the 'postgresql' feature is not enabled")
@ -214,7 +219,7 @@ impl DbConnType {
//Sqlite //Sqlite
} else { } else {
#[cfg(sqlite)] #[cfg(sqlite)]
return Ok(DbConnType::sqlite); return Ok(DbConnType::Sqlite);
#[cfg(not(sqlite))] #[cfg(not(sqlite))]
err!("`DATABASE_URL` looks like a SQLite URL, but 'sqlite' feature is not enabled") err!("`DATABASE_URL` looks like a SQLite URL, but 'sqlite' feature is not enabled")
@ -232,175 +237,102 @@ impl DbConnType {
pub fn default_init_stmts(&self) -> String { pub fn default_init_stmts(&self) -> String {
match self { match self {
Self::sqlite => "PRAGMA busy_timeout = 5000; PRAGMA synchronous = NORMAL;".to_string(), #[cfg(mysql)]
Self::mysql => String::new(), Self::Mysql => String::new(),
Self::postgresql => String::new(), #[cfg(postgresql)]
Self::Postgresql => String::new(),
#[cfg(sqlite)]
Self::Sqlite => "PRAGMA busy_timeout = 5000; PRAGMA synchronous = NORMAL;".to_string(),
} }
} }
} }
#[macro_export] #[macro_export]
macro_rules! db_run { macro_rules! db_run_base {
// Same for all dbs ( $conn:ident ) => {
( $conn:ident: $body:block ) => { let conn = std::sync::Arc::clone(&$conn.conn);
db_run! { $conn: sqlite, mysql, postgresql $body }
};
( @raw $conn:ident: $body:block ) => {
db_run! { @raw $conn: sqlite, mysql, postgresql $body }
};
// Different code for each db
( $conn:ident: $( $($db:ident),+ $body:block )+ ) => {{
#[allow(unused)] use diesel::prelude::*;
#[allow(unused)] use $crate::db::FromDb;
let conn = $conn.conn.clone();
let mut conn = conn.lock_owned().await; let mut conn = conn.lock_owned().await;
match conn.as_mut().expect("internal invariant broken: self.connection is Some") { let $conn = conn.as_mut().expect("internal invariant broken: self.conn is Some");
$($( };
#[cfg($db)]
$crate::db::DbConnInner::$db($conn) => {
pastey::paste! {
#[allow(unused)] use $crate::db::[<__ $db _schema>]::{self as schema, *};
#[allow(unused)] use [<__ $db _model>]::*;
} }
tokio::task::block_in_place(move || { $body }) // Run blocking can't be used due to the 'static limitation, use block_in_place instead #[macro_export]
}, macro_rules! db_run {
)+)+ ( $conn:ident: $body:block ) => {{
} db_run_base!($conn);
// Run blocking can't be used due to the 'static limitation, use block_in_place instead
tokio::task::block_in_place(move || $body )
}}; }};
( @raw $conn:ident: $( $($db:ident),+ $body:block )+ ) => {{ ( $conn:ident: $( $($db:ident),+ $body:block )+ ) => {{
#[allow(unused)] use diesel::prelude::*; db_run_base!($conn);
#[allow(unused)] use $crate::db::FromDb; match std::ops::DerefMut::deref_mut($conn) {
let conn = $conn.conn.clone();
let mut conn = conn.lock_owned().await;
match conn.as_mut().expect("internal invariant broken: self.connection is Some") {
$($( $($(
#[cfg($db)] #[cfg($db)]
$crate::db::DbConnInner::$db($conn) => { pastey::paste!(&mut $crate::db::DbConnInner::[<$db:camel>](ref mut $conn)) => {
pastey::paste! { // Run blocking can't be used due to the 'static limitation, use block_in_place instead
#[allow(unused)] use $crate::db::[<__ $db _schema>]::{self as schema, *}; tokio::task::block_in_place(move || $body )
// @ RAW: #[allow(unused)] use [<__ $db _model>]::*;
}
tokio::task::block_in_place(move || { $body }) // Run blocking can't be used due to the 'static limitation, use block_in_place instead
}, },
)+)+ )+)+}
}
}}; }};
} }
pub trait FromDb { pub mod schema;
type Output;
#[allow(clippy::wrong_self_convention)]
fn from_db(self) -> Self::Output;
}
impl<T: FromDb> FromDb for Vec<T> {
type Output = Vec<T::Output>;
#[inline(always)]
fn from_db(self) -> Self::Output {
self.into_iter().map(FromDb::from_db).collect()
}
}
impl<T: FromDb> FromDb for Option<T> {
type Output = Option<T::Output>;
#[inline(always)]
fn from_db(self) -> Self::Output {
self.map(FromDb::from_db)
}
}
// For each struct eg. Cipher, we create a CipherDb inside a module named __$db_model (where $db is sqlite, mysql or postgresql),
// to implement the Diesel traits. We also provide methods to convert between them and the basic structs. Later, that module will be auto imported when using db_run!
#[macro_export]
macro_rules! db_object {
( $(
$( #[$attr:meta] )*
pub struct $name:ident {
$( $( #[$field_attr:meta] )* $vis:vis $field:ident : $typ:ty ),+
$(,)?
}
)+ ) => {
// Create the normal struct, without attributes
$( pub struct $name { $( /*$( #[$field_attr] )**/ $vis $field : $typ, )+ } )+
#[cfg(sqlite)]
pub mod __sqlite_model { $( db_object! { @db sqlite | $( #[$attr] )* | $name | $( $( #[$field_attr] )* $field : $typ ),+ } )+ }
#[cfg(mysql)]
pub mod __mysql_model { $( db_object! { @db mysql | $( #[$attr] )* | $name | $( $( #[$field_attr] )* $field : $typ ),+ } )+ }
#[cfg(postgresql)]
pub mod __postgresql_model { $( db_object! { @db postgresql | $( #[$attr] )* | $name | $( $( #[$field_attr] )* $field : $typ ),+ } )+ }
};
( @db $db:ident | $( #[$attr:meta] )* | $name:ident | $( $( #[$field_attr:meta] )* $vis:vis $field:ident : $typ:ty),+) => {
pastey::paste! {
#[allow(unused)] use super::*;
#[allow(unused)] use diesel::prelude::*;
#[allow(unused)] use $crate::db::[<__ $db _schema>]::*;
$( #[$attr] )*
pub struct [<$name Db>] { $(
$( #[$field_attr] )* $vis $field : $typ,
)+ }
impl [<$name Db>] {
#[allow(clippy::wrong_self_convention)]
#[inline(always)] pub fn to_db(x: &super::$name) -> Self { Self { $( $field: x.$field.clone(), )+ } }
}
impl $crate::db::FromDb for [<$name Db>] {
type Output = super::$name;
#[allow(clippy::wrong_self_convention)]
#[inline(always)] fn from_db(self) -> Self::Output { super::$name { $( $field: self.$field, )+ } }
}
}
};
}
// Reexport the models, needs to be after the macros are defined so it can access them // Reexport the models, needs to be after the macros are defined so it can access them
pub mod models; pub mod models;
/// Creates a back-up of the sqlite database /// Creates a back-up of the sqlite database
/// MySQL/MariaDB and PostgreSQL are not supported. /// MySQL/MariaDB and PostgreSQL are not supported.
pub async fn backup_database(conn: &mut DbConn) -> Result<String, Error> { #[cfg(sqlite)]
db_run! {@raw conn: pub fn backup_sqlite() -> Result<String, Error> {
postgresql, mysql { use diesel::Connection;
let _ = conn; use std::{fs::File, io::Write};
err!("PostgreSQL and MySQL/MariaDB do not support this backup feature");
}
sqlite {
let db_url = CONFIG.database_url(); let db_url = CONFIG.database_url();
if DbConnType::from_url(&CONFIG.database_url()).map(|t| t == DbConnType::Sqlite).unwrap_or(false) {
// Since we do not allow any schema for sqlite database_url's like `file:` or `sqlite:` to be set, we can assume here it isn't
// This way we can set a readonly flag on the opening mode without issues.
let mut conn = diesel::sqlite::SqliteConnection::establish(&format!("sqlite://{db_url}?mode=ro"))?;
let db_path = std::path::Path::new(&db_url).parent().unwrap(); let db_path = std::path::Path::new(&db_url).parent().unwrap();
let backup_file = db_path let backup_file = db_path
.join(format!("db_{}.sqlite3", chrono::Utc::now().format("%Y%m%d_%H%M%S"))) .join(format!("db_{}.sqlite3", chrono::Utc::now().format("%Y%m%d_%H%M%S")))
.to_string_lossy() .to_string_lossy()
.into_owned(); .into_owned();
diesel::sql_query(format!("VACUUM INTO '{backup_file}'")).execute(conn)?;
match File::create(backup_file.clone()) {
Ok(mut f) => {
let serialized_db = conn.serialize_database_to_buffer();
f.write_all(serialized_db.as_slice()).expect("Error writing SQLite backup");
Ok(backup_file) Ok(backup_file)
} }
Err(e) => {
err_silent!(format!("Unable to save SQLite backup: {e:?}"))
}
}
} else {
err_silent!("The database type is not SQLite. Backups only works for SQLite databases")
}
} }
#[cfg(not(sqlite))]
pub fn backup_sqlite() -> Result<String, Error> {
err_silent!("The database type is not SQLite. Backups only works for SQLite databases")
} }
/// Get the SQL Server version /// Get the SQL Server version
pub async fn get_sql_server_version(conn: &mut DbConn) -> String { pub async fn get_sql_server_version(conn: &DbConn) -> String {
db_run! {@raw conn: db_run! { conn:
postgresql,mysql { postgresql,mysql {
define_sql_function!{ diesel::select(diesel::dsl::sql::<diesel::sql_types::Text>("version();"))
fn version() -> diesel::sql_types::Text; .get_result::<String>(conn)
} .unwrap_or_else(|_| "Unknown".to_string())
diesel::select(version()).get_result::<String>(conn).unwrap_or_else(|_| "Unknown".to_string())
} }
sqlite { sqlite {
define_sql_function!{ diesel::select(diesel::dsl::sql::<diesel::sql_types::Text>("sqlite_version();"))
fn sqlite_version() -> diesel::sql_types::Text; .get_result::<String>(conn)
} .unwrap_or_else(|_| "Unknown".to_string())
diesel::select(sqlite_version()).get_result::<String>(conn).unwrap_or_else(|_| "Unknown".to_string())
} }
} }
} }
@ -428,16 +360,14 @@ impl<'r> FromRequest<'r> for DbConn {
// https://docs.rs/diesel_migrations/*/diesel_migrations/macro.embed_migrations.html // https://docs.rs/diesel_migrations/*/diesel_migrations/macro.embed_migrations.html
#[cfg(sqlite)] #[cfg(sqlite)]
mod sqlite_migrations { mod sqlite_migrations {
use diesel::{Connection, RunQueryDsl};
use diesel_migrations::{EmbeddedMigrations, MigrationHarness}; use diesel_migrations::{EmbeddedMigrations, MigrationHarness};
pub const MIGRATIONS: EmbeddedMigrations = embed_migrations!("migrations/sqlite"); pub const MIGRATIONS: EmbeddedMigrations = embed_migrations!("migrations/sqlite");
pub fn run_migrations() -> Result<(), super::Error> { pub fn run_migrations(url: &str) -> Result<(), super::Error> {
use diesel::{Connection, RunQueryDsl};
let url = crate::CONFIG.database_url();
// Establish a connection to the sqlite database (this will create a new one, if it does // Establish a connection to the sqlite database (this will create a new one, if it does
// not exist, and exit if there is an error). // not exist, and exit if there is an error).
let mut connection = diesel::sqlite::SqliteConnection::establish(&url)?; let mut connection = diesel::sqlite::SqliteConnection::establish(url)?;
// Run the migrations after successfully establishing a connection // Run the migrations after successfully establishing a connection
// Disable Foreign Key Checks during migration // Disable Foreign Key Checks during migration
@ -458,15 +388,15 @@ mod sqlite_migrations {
#[cfg(mysql)] #[cfg(mysql)]
mod mysql_migrations { mod mysql_migrations {
use diesel::{Connection, RunQueryDsl};
use diesel_migrations::{EmbeddedMigrations, MigrationHarness}; use diesel_migrations::{EmbeddedMigrations, MigrationHarness};
pub const MIGRATIONS: EmbeddedMigrations = embed_migrations!("migrations/mysql"); pub const MIGRATIONS: EmbeddedMigrations = embed_migrations!("migrations/mysql");
pub fn run_migrations() -> Result<(), super::Error> { pub fn run_migrations(url: &str) -> Result<(), super::Error> {
use diesel::{Connection, RunQueryDsl};
// Make sure the database is up to date (create if it doesn't exist, or run the migrations) // Make sure the database is up to date (create if it doesn't exist, or run the migrations)
let mut connection = diesel::mysql::MysqlConnection::establish(&crate::CONFIG.database_url())?; let mut connection = diesel::mysql::MysqlConnection::establish(url)?;
// Disable Foreign Key Checks during migration
// Disable Foreign Key Checks during migration
// Scoped to a connection/session. // Scoped to a connection/session.
diesel::sql_query("SET FOREIGN_KEY_CHECKS = 0") diesel::sql_query("SET FOREIGN_KEY_CHECKS = 0")
.execute(&mut connection) .execute(&mut connection)
@ -479,13 +409,14 @@ mod mysql_migrations {
#[cfg(postgresql)] #[cfg(postgresql)]
mod postgresql_migrations { mod postgresql_migrations {
use diesel::Connection;
use diesel_migrations::{EmbeddedMigrations, MigrationHarness}; use diesel_migrations::{EmbeddedMigrations, MigrationHarness};
pub const MIGRATIONS: EmbeddedMigrations = embed_migrations!("migrations/postgresql"); pub const MIGRATIONS: EmbeddedMigrations = embed_migrations!("migrations/postgresql");
pub fn run_migrations() -> Result<(), super::Error> { pub fn run_migrations(url: &str) -> Result<(), super::Error> {
use diesel::Connection;
// Make sure the database is up to date (create if it doesn't exist, or run the migrations) // Make sure the database is up to date (create if it doesn't exist, or run the migrations)
let mut connection = diesel::pg::PgConnection::establish(&crate::CONFIG.database_url())?; let mut connection = diesel::pg::PgConnection::establish(url)?;
connection.run_pending_migrations(MIGRATIONS).expect("Error running migrations"); connection.run_pending_migrations(MIGRATIONS).expect("Error running migrations");
Ok(()) Ok(())
} }

50
src/db/models/attachment.rs

@ -1,14 +1,14 @@
use std::time::Duration;
use bigdecimal::{BigDecimal, ToPrimitive}; use bigdecimal::{BigDecimal, ToPrimitive};
use derive_more::{AsRef, Deref, Display}; use derive_more::{AsRef, Deref, Display};
use diesel::prelude::*;
use serde_json::Value; use serde_json::Value;
use std::time::Duration;
use super::{CipherId, OrganizationId, UserId}; use super::{CipherId, OrganizationId, UserId};
use crate::db::schema::{attachments, ciphers};
use crate::{config::PathType, CONFIG}; use crate::{config::PathType, CONFIG};
use macros::IdFromParam; use macros::IdFromParam;
db_object! {
#[derive(Identifiable, Queryable, Insertable, AsChangeset)] #[derive(Identifiable, Queryable, Insertable, AsChangeset)]
#[diesel(table_name = attachments)] #[diesel(table_name = attachments)]
#[diesel(treat_none_as_null = true)] #[diesel(treat_none_as_null = true)]
@ -20,7 +20,6 @@ db_object! {
pub file_size: i64, pub file_size: i64,
pub akey: Option<String>, pub akey: Option<String>,
} }
}
/// Local methods /// Local methods
impl Attachment { impl Attachment {
@ -76,11 +75,11 @@ use crate::error::MapResult;
/// Database methods /// Database methods
impl Attachment { impl Attachment {
pub async fn save(&self, conn: &mut DbConn) -> EmptyResult { pub async fn save(&self, conn: &DbConn) -> EmptyResult {
db_run! { conn: db_run! { conn:
sqlite, mysql { sqlite, mysql {
match diesel::replace_into(attachments::table) match diesel::replace_into(attachments::table)
.values(AttachmentDb::to_db(self)) .values(self)
.execute(conn) .execute(conn)
{ {
Ok(_) => Ok(()), Ok(_) => Ok(()),
@ -88,7 +87,7 @@ impl Attachment {
Err(diesel::result::Error::DatabaseError(diesel::result::DatabaseErrorKind::ForeignKeyViolation, _)) => { Err(diesel::result::Error::DatabaseError(diesel::result::DatabaseErrorKind::ForeignKeyViolation, _)) => {
diesel::update(attachments::table) diesel::update(attachments::table)
.filter(attachments::id.eq(&self.id)) .filter(attachments::id.eq(&self.id))
.set(AttachmentDb::to_db(self)) .set(self)
.execute(conn) .execute(conn)
.map_res("Error saving attachment") .map_res("Error saving attachment")
} }
@ -96,22 +95,22 @@ impl Attachment {
}.map_res("Error saving attachment") }.map_res("Error saving attachment")
} }
postgresql { postgresql {
let value = AttachmentDb::to_db(self);
diesel::insert_into(attachments::table) diesel::insert_into(attachments::table)
.values(&value) .values(self)
.on_conflict(attachments::id) .on_conflict(attachments::id)
.do_update() .do_update()
.set(&value) .set(self)
.execute(conn) .execute(conn)
.map_res("Error saving attachment") .map_res("Error saving attachment")
} }
} }
} }
pub async fn delete(&self, conn: &mut DbConn) -> EmptyResult { pub async fn delete(&self, conn: &DbConn) -> EmptyResult {
db_run! { conn: { db_run! { conn: {
crate::util::retry( crate::util::retry(||
|| diesel::delete(attachments::table.filter(attachments::id.eq(&self.id))).execute(conn), diesel::delete(attachments::table.filter(attachments::id.eq(&self.id)))
.execute(conn),
10, 10,
) )
.map(|_| ()) .map(|_| ())
@ -132,34 +131,32 @@ impl Attachment {
Ok(()) Ok(())
} }
pub async fn delete_all_by_cipher(cipher_uuid: &CipherId, conn: &mut DbConn) -> EmptyResult { pub async fn delete_all_by_cipher(cipher_uuid: &CipherId, conn: &DbConn) -> EmptyResult {
for attachment in Attachment::find_by_cipher(cipher_uuid, conn).await { for attachment in Attachment::find_by_cipher(cipher_uuid, conn).await {
attachment.delete(conn).await?; attachment.delete(conn).await?;
} }
Ok(()) Ok(())
} }
pub async fn find_by_id(id: &AttachmentId, conn: &mut DbConn) -> Option<Self> { pub async fn find_by_id(id: &AttachmentId, conn: &DbConn) -> Option<Self> {
db_run! { conn: { db_run! { conn: {
attachments::table attachments::table
.filter(attachments::id.eq(id.to_lowercase())) .filter(attachments::id.eq(id.to_lowercase()))
.first::<AttachmentDb>(conn) .first::<Self>(conn)
.ok() .ok()
.from_db()
}} }}
} }
pub async fn find_by_cipher(cipher_uuid: &CipherId, conn: &mut DbConn) -> Vec<Self> { pub async fn find_by_cipher(cipher_uuid: &CipherId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: { db_run! { conn: {
attachments::table attachments::table
.filter(attachments::cipher_uuid.eq(cipher_uuid)) .filter(attachments::cipher_uuid.eq(cipher_uuid))
.load::<AttachmentDb>(conn) .load::<Self>(conn)
.expect("Error loading attachments") .expect("Error loading attachments")
.from_db()
}} }}
} }
pub async fn size_by_user(user_uuid: &UserId, conn: &mut DbConn) -> i64 { pub async fn size_by_user(user_uuid: &UserId, conn: &DbConn) -> i64 {
db_run! { conn: { db_run! { conn: {
let result: Option<BigDecimal> = attachments::table let result: Option<BigDecimal> = attachments::table
.left_join(ciphers::table.on(ciphers::uuid.eq(attachments::cipher_uuid))) .left_join(ciphers::table.on(ciphers::uuid.eq(attachments::cipher_uuid)))
@ -176,7 +173,7 @@ impl Attachment {
}} }}
} }
pub async fn count_by_user(user_uuid: &UserId, conn: &mut DbConn) -> i64 { pub async fn count_by_user(user_uuid: &UserId, conn: &DbConn) -> i64 {
db_run! { conn: { db_run! { conn: {
attachments::table attachments::table
.left_join(ciphers::table.on(ciphers::uuid.eq(attachments::cipher_uuid))) .left_join(ciphers::table.on(ciphers::uuid.eq(attachments::cipher_uuid)))
@ -187,7 +184,7 @@ impl Attachment {
}} }}
} }
pub async fn size_by_org(org_uuid: &OrganizationId, conn: &mut DbConn) -> i64 { pub async fn size_by_org(org_uuid: &OrganizationId, conn: &DbConn) -> i64 {
db_run! { conn: { db_run! { conn: {
let result: Option<BigDecimal> = attachments::table let result: Option<BigDecimal> = attachments::table
.left_join(ciphers::table.on(ciphers::uuid.eq(attachments::cipher_uuid))) .left_join(ciphers::table.on(ciphers::uuid.eq(attachments::cipher_uuid)))
@ -204,7 +201,7 @@ impl Attachment {
}} }}
} }
pub async fn count_by_org(org_uuid: &OrganizationId, conn: &mut DbConn) -> i64 { pub async fn count_by_org(org_uuid: &OrganizationId, conn: &DbConn) -> i64 {
db_run! { conn: { db_run! { conn: {
attachments::table attachments::table
.left_join(ciphers::table.on(ciphers::uuid.eq(attachments::cipher_uuid))) .left_join(ciphers::table.on(ciphers::uuid.eq(attachments::cipher_uuid)))
@ -221,7 +218,7 @@ impl Attachment {
pub async fn find_all_by_user_and_orgs( pub async fn find_all_by_user_and_orgs(
user_uuid: &UserId, user_uuid: &UserId,
org_uuids: &Vec<OrganizationId>, org_uuids: &Vec<OrganizationId>,
conn: &mut DbConn, conn: &DbConn,
) -> Vec<Self> { ) -> Vec<Self> {
db_run! { conn: { db_run! { conn: {
attachments::table attachments::table
@ -229,9 +226,8 @@ impl Attachment {
.filter(ciphers::user_uuid.eq(user_uuid)) .filter(ciphers::user_uuid.eq(user_uuid))
.or_filter(ciphers::organization_uuid.eq_any(org_uuids)) .or_filter(ciphers::organization_uuid.eq_any(org_uuids))
.select(attachments::all_columns) .select(attachments::all_columns)
.load::<AttachmentDb>(conn) .load::<Self>(conn)
.expect("Error loading attachments") .expect("Error loading attachments")
.from_db()
}} }}
} }
} }

44
src/db/models/auth_request.rs

@ -1,11 +1,12 @@
use super::{DeviceId, OrganizationId, UserId}; use super::{DeviceId, OrganizationId, UserId};
use crate::db::schema::auth_requests;
use crate::{crypto::ct_eq, util::format_date}; use crate::{crypto::ct_eq, util::format_date};
use chrono::{NaiveDateTime, Utc}; use chrono::{NaiveDateTime, Utc};
use derive_more::{AsRef, Deref, Display, From}; use derive_more::{AsRef, Deref, Display, From};
use diesel::prelude::*;
use macros::UuidFromParam; use macros::UuidFromParam;
use serde_json::Value; use serde_json::Value;
db_object! {
#[derive(Identifiable, Queryable, Insertable, AsChangeset, Deserialize, Serialize)] #[derive(Identifiable, Queryable, Insertable, AsChangeset, Deserialize, Serialize)]
#[diesel(table_name = auth_requests)] #[diesel(table_name = auth_requests)]
#[diesel(treat_none_as_null = true)] #[diesel(treat_none_as_null = true)]
@ -33,7 +34,6 @@ db_object! {
pub authentication_date: Option<NaiveDateTime>, pub authentication_date: Option<NaiveDateTime>,
} }
}
impl AuthRequest { impl AuthRequest {
pub fn new( pub fn new(
@ -80,11 +80,11 @@ use crate::api::EmptyResult;
use crate::error::MapResult; use crate::error::MapResult;
impl AuthRequest { impl AuthRequest {
pub async fn save(&mut self, conn: &mut DbConn) -> EmptyResult { pub async fn save(&mut self, conn: &DbConn) -> EmptyResult {
db_run! { conn: db_run! { conn:
sqlite, mysql { sqlite, mysql {
match diesel::replace_into(auth_requests::table) match diesel::replace_into(auth_requests::table)
.values(AuthRequestDb::to_db(self)) .values(&*self)
.execute(conn) .execute(conn)
{ {
Ok(_) => Ok(()), Ok(_) => Ok(()),
@ -92,7 +92,7 @@ impl AuthRequest {
Err(diesel::result::Error::DatabaseError(diesel::result::DatabaseErrorKind::ForeignKeyViolation, _)) => { Err(diesel::result::Error::DatabaseError(diesel::result::DatabaseErrorKind::ForeignKeyViolation, _)) => {
diesel::update(auth_requests::table) diesel::update(auth_requests::table)
.filter(auth_requests::uuid.eq(&self.uuid)) .filter(auth_requests::uuid.eq(&self.uuid))
.set(AuthRequestDb::to_db(self)) .set(&*self)
.execute(conn) .execute(conn)
.map_res("Error auth_request") .map_res("Error auth_request")
} }
@ -100,51 +100,49 @@ impl AuthRequest {
}.map_res("Error auth_request") }.map_res("Error auth_request")
} }
postgresql { postgresql {
let value = AuthRequestDb::to_db(self);
diesel::insert_into(auth_requests::table) diesel::insert_into(auth_requests::table)
.values(&value) .values(&*self)
.on_conflict(auth_requests::uuid) .on_conflict(auth_requests::uuid)
.do_update() .do_update()
.set(&value) .set(&*self)
.execute(conn) .execute(conn)
.map_res("Error saving auth_request") .map_res("Error saving auth_request")
} }
} }
} }
pub async fn find_by_uuid(uuid: &AuthRequestId, conn: &mut DbConn) -> Option<Self> { pub async fn find_by_uuid(uuid: &AuthRequestId, conn: &DbConn) -> Option<Self> {
db_run! { conn: { db_run! { conn: {
auth_requests::table auth_requests::table
.filter(auth_requests::uuid.eq(uuid)) .filter(auth_requests::uuid.eq(uuid))
.first::<AuthRequestDb>(conn) .first::<Self>(conn)
.ok() .ok()
.from_db()
}} }}
} }
pub async fn find_by_uuid_and_user(uuid: &AuthRequestId, user_uuid: &UserId, conn: &mut DbConn) -> Option<Self> { pub async fn find_by_uuid_and_user(uuid: &AuthRequestId, user_uuid: &UserId, conn: &DbConn) -> Option<Self> {
db_run! { conn: { db_run! { conn: {
auth_requests::table auth_requests::table
.filter(auth_requests::uuid.eq(uuid)) .filter(auth_requests::uuid.eq(uuid))
.filter(auth_requests::user_uuid.eq(user_uuid)) .filter(auth_requests::user_uuid.eq(user_uuid))
.first::<AuthRequestDb>(conn) .first::<Self>(conn)
.ok() .ok()
.from_db()
}} }}
} }
pub async fn find_by_user(user_uuid: &UserId, conn: &mut DbConn) -> Vec<Self> { pub async fn find_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: { db_run! { conn: {
auth_requests::table auth_requests::table
.filter(auth_requests::user_uuid.eq(user_uuid)) .filter(auth_requests::user_uuid.eq(user_uuid))
.load::<AuthRequestDb>(conn).expect("Error loading auth_requests").from_db() .load::<Self>(conn)
.expect("Error loading auth_requests")
}} }}
} }
pub async fn find_by_user_and_requested_device( pub async fn find_by_user_and_requested_device(
user_uuid: &UserId, user_uuid: &UserId,
device_uuid: &DeviceId, device_uuid: &DeviceId,
conn: &mut DbConn, conn: &DbConn,
) -> Option<Self> { ) -> Option<Self> {
db_run! { conn: { db_run! { conn: {
auth_requests::table auth_requests::table
@ -152,19 +150,21 @@ impl AuthRequest {
.filter(auth_requests::request_device_identifier.eq(device_uuid)) .filter(auth_requests::request_device_identifier.eq(device_uuid))
.filter(auth_requests::approved.is_null()) .filter(auth_requests::approved.is_null())
.order_by(auth_requests::creation_date.desc()) .order_by(auth_requests::creation_date.desc())
.first::<AuthRequestDb>(conn).ok().from_db() .first::<Self>(conn)
.ok()
}} }}
} }
pub async fn find_created_before(dt: &NaiveDateTime, conn: &mut DbConn) -> Vec<Self> { pub async fn find_created_before(dt: &NaiveDateTime, conn: &DbConn) -> Vec<Self> {
db_run! { conn: { db_run! { conn: {
auth_requests::table auth_requests::table
.filter(auth_requests::creation_date.lt(dt)) .filter(auth_requests::creation_date.lt(dt))
.load::<AuthRequestDb>(conn).expect("Error loading auth_requests").from_db() .load::<Self>(conn)
.expect("Error loading auth_requests")
}} }}
} }
pub async fn delete(&self, conn: &mut DbConn) -> EmptyResult { pub async fn delete(&self, conn: &DbConn) -> EmptyResult {
db_run! { conn: { db_run! { conn: {
diesel::delete(auth_requests::table.filter(auth_requests::uuid.eq(&self.uuid))) diesel::delete(auth_requests::table.filter(auth_requests::uuid.eq(&self.uuid)))
.execute(conn) .execute(conn)
@ -176,7 +176,7 @@ impl AuthRequest {
ct_eq(&self.access_code, access_code) ct_eq(&self.access_code, access_code)
} }
pub async fn purge_expired_auth_requests(conn: &mut DbConn) { pub async fn purge_expired_auth_requests(conn: &DbConn) {
let expiry_time = Utc::now().naive_utc() - chrono::TimeDelta::try_minutes(5).unwrap(); //after 5 minutes, clients reject the request let expiry_time = Utc::now().naive_utc() - chrono::TimeDelta::try_minutes(5).unwrap(); //after 5 minutes, clients reject the request
for auth_request in Self::find_created_before(&expiry_time, conn).await { for auth_request in Self::find_created_before(&expiry_time, conn).await {
auth_request.delete(conn).await.ok(); auth_request.delete(conn).await.ok();

133
src/db/models/cipher.rs

@ -1,7 +1,12 @@
use crate::db::schema::{
ciphers, ciphers_collections, collections, collections_groups, folders, folders_ciphers, groups, groups_users,
users_collections, users_organizations,
};
use crate::util::LowerCase; use crate::util::LowerCase;
use crate::CONFIG; use crate::CONFIG;
use chrono::{NaiveDateTime, TimeDelta, Utc}; use chrono::{NaiveDateTime, TimeDelta, Utc};
use derive_more::{AsRef, Deref, Display, From}; use derive_more::{AsRef, Deref, Display, From};
use diesel::prelude::*;
use serde_json::Value; use serde_json::Value;
use super::{ use super::{
@ -13,7 +18,6 @@ use macros::UuidFromParam;
use std::borrow::Cow; use std::borrow::Cow;
db_object! {
#[derive(Identifiable, Queryable, Insertable, AsChangeset)] #[derive(Identifiable, Queryable, Insertable, AsChangeset)]
#[diesel(table_name = ciphers)] #[diesel(table_name = ciphers)]
#[diesel(treat_none_as_null = true)] #[diesel(treat_none_as_null = true)]
@ -46,7 +50,6 @@ db_object! {
pub deleted_at: Option<NaiveDateTime>, pub deleted_at: Option<NaiveDateTime>,
pub reprompt: Option<i32>, pub reprompt: Option<i32>,
} }
}
pub enum RepromptType { pub enum RepromptType {
None = 0, None = 0,
@ -140,7 +143,7 @@ impl Cipher {
user_uuid: &UserId, user_uuid: &UserId,
cipher_sync_data: Option<&CipherSyncData>, cipher_sync_data: Option<&CipherSyncData>,
sync_type: CipherSyncType, sync_type: CipherSyncType,
conn: &mut DbConn, conn: &DbConn,
) -> Result<Value, crate::Error> { ) -> Result<Value, crate::Error> {
use crate::util::{format_date, validate_and_format_date}; use crate::util::{format_date, validate_and_format_date};
@ -402,7 +405,7 @@ impl Cipher {
Ok(json_object) Ok(json_object)
} }
pub async fn update_users_revision(&self, conn: &mut DbConn) -> Vec<UserId> { pub async fn update_users_revision(&self, conn: &DbConn) -> Vec<UserId> {
let mut user_uuids = Vec::new(); let mut user_uuids = Vec::new();
match self.user_uuid { match self.user_uuid {
Some(ref user_uuid) => { Some(ref user_uuid) => {
@ -430,14 +433,14 @@ impl Cipher {
user_uuids user_uuids
} }
pub async fn save(&mut self, conn: &mut DbConn) -> EmptyResult { pub async fn save(&mut self, conn: &DbConn) -> EmptyResult {
self.update_users_revision(conn).await; self.update_users_revision(conn).await;
self.updated_at = Utc::now().naive_utc(); self.updated_at = Utc::now().naive_utc();
db_run! { conn: db_run! { conn:
sqlite, mysql { sqlite, mysql {
match diesel::replace_into(ciphers::table) match diesel::replace_into(ciphers::table)
.values(CipherDb::to_db(self)) .values(&*self)
.execute(conn) .execute(conn)
{ {
Ok(_) => Ok(()), Ok(_) => Ok(()),
@ -445,7 +448,7 @@ impl Cipher {
Err(diesel::result::Error::DatabaseError(diesel::result::DatabaseErrorKind::ForeignKeyViolation, _)) => { Err(diesel::result::Error::DatabaseError(diesel::result::DatabaseErrorKind::ForeignKeyViolation, _)) => {
diesel::update(ciphers::table) diesel::update(ciphers::table)
.filter(ciphers::uuid.eq(&self.uuid)) .filter(ciphers::uuid.eq(&self.uuid))
.set(CipherDb::to_db(self)) .set(&*self)
.execute(conn) .execute(conn)
.map_res("Error saving cipher") .map_res("Error saving cipher")
} }
@ -453,19 +456,18 @@ impl Cipher {
}.map_res("Error saving cipher") }.map_res("Error saving cipher")
} }
postgresql { postgresql {
let value = CipherDb::to_db(self);
diesel::insert_into(ciphers::table) diesel::insert_into(ciphers::table)
.values(&value) .values(&*self)
.on_conflict(ciphers::uuid) .on_conflict(ciphers::uuid)
.do_update() .do_update()
.set(&value) .set(&*self)
.execute(conn) .execute(conn)
.map_res("Error saving cipher") .map_res("Error saving cipher")
} }
} }
} }
pub async fn delete(&self, conn: &mut DbConn) -> EmptyResult { pub async fn delete(&self, conn: &DbConn) -> EmptyResult {
self.update_users_revision(conn).await; self.update_users_revision(conn).await;
FolderCipher::delete_all_by_cipher(&self.uuid, conn).await?; FolderCipher::delete_all_by_cipher(&self.uuid, conn).await?;
@ -480,7 +482,7 @@ impl Cipher {
}} }}
} }
pub async fn delete_all_by_organization(org_uuid: &OrganizationId, conn: &mut DbConn) -> EmptyResult { pub async fn delete_all_by_organization(org_uuid: &OrganizationId, conn: &DbConn) -> EmptyResult {
// TODO: Optimize this by executing a DELETE directly on the database, instead of first fetching. // TODO: Optimize this by executing a DELETE directly on the database, instead of first fetching.
for cipher in Self::find_by_org(org_uuid, conn).await { for cipher in Self::find_by_org(org_uuid, conn).await {
cipher.delete(conn).await?; cipher.delete(conn).await?;
@ -488,7 +490,7 @@ impl Cipher {
Ok(()) Ok(())
} }
pub async fn delete_all_by_user(user_uuid: &UserId, conn: &mut DbConn) -> EmptyResult { pub async fn delete_all_by_user(user_uuid: &UserId, conn: &DbConn) -> EmptyResult {
for cipher in Self::find_owned_by_user(user_uuid, conn).await { for cipher in Self::find_owned_by_user(user_uuid, conn).await {
cipher.delete(conn).await?; cipher.delete(conn).await?;
} }
@ -496,7 +498,7 @@ impl Cipher {
} }
/// Purge all ciphers that are old enough to be auto-deleted. /// Purge all ciphers that are old enough to be auto-deleted.
pub async fn purge_trash(conn: &mut DbConn) { pub async fn purge_trash(conn: &DbConn) {
if let Some(auto_delete_days) = CONFIG.trash_auto_delete_days() { if let Some(auto_delete_days) = CONFIG.trash_auto_delete_days() {
let now = Utc::now().naive_utc(); let now = Utc::now().naive_utc();
let dt = now - TimeDelta::try_days(auto_delete_days).unwrap(); let dt = now - TimeDelta::try_days(auto_delete_days).unwrap();
@ -510,7 +512,7 @@ impl Cipher {
&self, &self,
folder_uuid: Option<FolderId>, folder_uuid: Option<FolderId>,
user_uuid: &UserId, user_uuid: &UserId,
conn: &mut DbConn, conn: &DbConn,
) -> EmptyResult { ) -> EmptyResult {
User::update_uuid_revision(user_uuid, conn).await; User::update_uuid_revision(user_uuid, conn).await;
@ -550,7 +552,7 @@ impl Cipher {
&self, &self,
user_uuid: &UserId, user_uuid: &UserId,
cipher_sync_data: Option<&CipherSyncData>, cipher_sync_data: Option<&CipherSyncData>,
conn: &mut DbConn, conn: &DbConn,
) -> bool { ) -> bool {
if let Some(ref org_uuid) = self.organization_uuid { if let Some(ref org_uuid) = self.organization_uuid {
if let Some(cipher_sync_data) = cipher_sync_data { if let Some(cipher_sync_data) = cipher_sync_data {
@ -569,7 +571,7 @@ impl Cipher {
&self, &self,
user_uuid: &UserId, user_uuid: &UserId,
cipher_sync_data: Option<&CipherSyncData>, cipher_sync_data: Option<&CipherSyncData>,
conn: &mut DbConn, conn: &DbConn,
) -> bool { ) -> bool {
if !CONFIG.org_groups_enabled() { if !CONFIG.org_groups_enabled() {
return false; return false;
@ -593,7 +595,7 @@ impl Cipher {
&self, &self,
user_uuid: &UserId, user_uuid: &UserId,
cipher_sync_data: Option<&CipherSyncData>, cipher_sync_data: Option<&CipherSyncData>,
conn: &mut DbConn, conn: &DbConn,
) -> Option<(bool, bool, bool)> { ) -> Option<(bool, bool, bool)> {
// Check whether this cipher is directly owned by the user, or is in // Check whether this cipher is directly owned by the user, or is in
// a collection that the user has full access to. If so, there are no // a collection that the user has full access to. If so, there are no
@ -659,11 +661,7 @@ impl Cipher {
Some((read_only, hide_passwords, manage)) Some((read_only, hide_passwords, manage))
} }
async fn get_user_collections_access_flags( async fn get_user_collections_access_flags(&self, user_uuid: &UserId, conn: &DbConn) -> Vec<(bool, bool, bool)> {
&self,
user_uuid: &UserId,
conn: &mut DbConn,
) -> Vec<(bool, bool, bool)> {
db_run! { conn: { db_run! { conn: {
// Check whether this cipher is in any collections accessible to the // Check whether this cipher is in any collections accessible to the
// user. If so, retrieve the access flags for each collection. // user. If so, retrieve the access flags for each collection.
@ -680,11 +678,7 @@ impl Cipher {
}} }}
} }
async fn get_group_collections_access_flags( async fn get_group_collections_access_flags(&self, user_uuid: &UserId, conn: &DbConn) -> Vec<(bool, bool, bool)> {
&self,
user_uuid: &UserId,
conn: &mut DbConn,
) -> Vec<(bool, bool, bool)> {
if !CONFIG.org_groups_enabled() { if !CONFIG.org_groups_enabled() {
return Vec::new(); return Vec::new();
} }
@ -710,31 +704,31 @@ impl Cipher {
}} }}
} }
pub async fn is_write_accessible_to_user(&self, user_uuid: &UserId, conn: &mut DbConn) -> bool { pub async fn is_write_accessible_to_user(&self, user_uuid: &UserId, conn: &DbConn) -> bool {
match self.get_access_restrictions(user_uuid, None, conn).await { match self.get_access_restrictions(user_uuid, None, conn).await {
Some((read_only, _hide_passwords, manage)) => !read_only || manage, Some((read_only, _hide_passwords, manage)) => !read_only || manage,
None => false, None => false,
} }
} }
pub async fn is_accessible_to_user(&self, user_uuid: &UserId, conn: &mut DbConn) -> bool { pub async fn is_accessible_to_user(&self, user_uuid: &UserId, conn: &DbConn) -> bool {
self.get_access_restrictions(user_uuid, None, conn).await.is_some() self.get_access_restrictions(user_uuid, None, conn).await.is_some()
} }
// Returns whether this cipher is a favorite of the specified user. // Returns whether this cipher is a favorite of the specified user.
pub async fn is_favorite(&self, user_uuid: &UserId, conn: &mut DbConn) -> bool { pub async fn is_favorite(&self, user_uuid: &UserId, conn: &DbConn) -> bool {
Favorite::is_favorite(&self.uuid, user_uuid, conn).await Favorite::is_favorite(&self.uuid, user_uuid, conn).await
} }
// Sets whether this cipher is a favorite of the specified user. // Sets whether this cipher is a favorite of the specified user.
pub async fn set_favorite(&self, favorite: Option<bool>, user_uuid: &UserId, conn: &mut DbConn) -> EmptyResult { pub async fn set_favorite(&self, favorite: Option<bool>, user_uuid: &UserId, conn: &DbConn) -> EmptyResult {
match favorite { match favorite {
None => Ok(()), // No change requested. None => Ok(()), // No change requested.
Some(status) => Favorite::set_favorite(status, &self.uuid, user_uuid, conn).await, Some(status) => Favorite::set_favorite(status, &self.uuid, user_uuid, conn).await,
} }
} }
pub async fn get_folder_uuid(&self, user_uuid: &UserId, conn: &mut DbConn) -> Option<FolderId> { pub async fn get_folder_uuid(&self, user_uuid: &UserId, conn: &DbConn) -> Option<FolderId> {
db_run! { conn: { db_run! { conn: {
folders_ciphers::table folders_ciphers::table
.inner_join(folders::table) .inner_join(folders::table)
@ -746,28 +740,26 @@ impl Cipher {
}} }}
} }
pub async fn find_by_uuid(uuid: &CipherId, conn: &mut DbConn) -> Option<Self> { pub async fn find_by_uuid(uuid: &CipherId, conn: &DbConn) -> Option<Self> {
db_run! { conn: { db_run! { conn: {
ciphers::table ciphers::table
.filter(ciphers::uuid.eq(uuid)) .filter(ciphers::uuid.eq(uuid))
.first::<CipherDb>(conn) .first::<Self>(conn)
.ok() .ok()
.from_db()
}} }}
} }
pub async fn find_by_uuid_and_org( pub async fn find_by_uuid_and_org(
cipher_uuid: &CipherId, cipher_uuid: &CipherId,
org_uuid: &OrganizationId, org_uuid: &OrganizationId,
conn: &mut DbConn, conn: &DbConn,
) -> Option<Self> { ) -> Option<Self> {
db_run! { conn: { db_run! { conn: {
ciphers::table ciphers::table
.filter(ciphers::uuid.eq(cipher_uuid)) .filter(ciphers::uuid.eq(cipher_uuid))
.filter(ciphers::organization_uuid.eq(org_uuid)) .filter(ciphers::organization_uuid.eq(org_uuid))
.first::<CipherDb>(conn) .first::<Self>(conn)
.ok() .ok()
.from_db()
}} }}
} }
@ -787,7 +779,7 @@ impl Cipher {
user_uuid: &UserId, user_uuid: &UserId,
visible_only: bool, visible_only: bool,
cipher_uuids: &Vec<CipherId>, cipher_uuids: &Vec<CipherId>,
conn: &mut DbConn, conn: &DbConn,
) -> Vec<Self> { ) -> Vec<Self> {
if CONFIG.org_groups_enabled() { if CONFIG.org_groups_enabled() {
db_run! { conn: { db_run! { conn: {
@ -839,7 +831,8 @@ impl Cipher {
query query
.select(ciphers::all_columns) .select(ciphers::all_columns)
.distinct() .distinct()
.load::<CipherDb>(conn).expect("Error loading ciphers").from_db() .load::<Self>(conn)
.expect("Error loading ciphers")
}} }}
} else { } else {
db_run! { conn: { db_run! { conn: {
@ -878,45 +871,43 @@ impl Cipher {
query query
.select(ciphers::all_columns) .select(ciphers::all_columns)
.distinct() .distinct()
.load::<CipherDb>(conn).expect("Error loading ciphers").from_db() .load::<Self>(conn)
.expect("Error loading ciphers")
}} }}
} }
} }
// Find all ciphers visible to the specified user. // Find all ciphers visible to the specified user.
pub async fn find_by_user_visible(user_uuid: &UserId, conn: &mut DbConn) -> Vec<Self> { pub async fn find_by_user_visible(user_uuid: &UserId, conn: &DbConn) -> Vec<Self> {
Self::find_by_user(user_uuid, true, &vec![], conn).await Self::find_by_user(user_uuid, true, &vec![], conn).await
} }
pub async fn find_by_user_and_ciphers( pub async fn find_by_user_and_ciphers(
user_uuid: &UserId, user_uuid: &UserId,
cipher_uuids: &Vec<CipherId>, cipher_uuids: &Vec<CipherId>,
conn: &mut DbConn, conn: &DbConn,
) -> Vec<Self> { ) -> Vec<Self> {
Self::find_by_user(user_uuid, true, cipher_uuids, conn).await Self::find_by_user(user_uuid, true, cipher_uuids, conn).await
} }
pub async fn find_by_user_and_cipher( pub async fn find_by_user_and_cipher(user_uuid: &UserId, cipher_uuid: &CipherId, conn: &DbConn) -> Option<Self> {
user_uuid: &UserId,
cipher_uuid: &CipherId,
conn: &mut DbConn,
) -> Option<Self> {
Self::find_by_user(user_uuid, true, &vec![cipher_uuid.clone()], conn).await.pop() Self::find_by_user(user_uuid, true, &vec![cipher_uuid.clone()], conn).await.pop()
} }
// Find all ciphers directly owned by the specified user. // Find all ciphers directly owned by the specified user.
pub async fn find_owned_by_user(user_uuid: &UserId, conn: &mut DbConn) -> Vec<Self> { pub async fn find_owned_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: { db_run! { conn: {
ciphers::table ciphers::table
.filter( .filter(
ciphers::user_uuid.eq(user_uuid) ciphers::user_uuid.eq(user_uuid)
.and(ciphers::organization_uuid.is_null()) .and(ciphers::organization_uuid.is_null())
) )
.load::<CipherDb>(conn).expect("Error loading ciphers").from_db() .load::<Self>(conn)
.expect("Error loading ciphers")
}} }}
} }
pub async fn count_owned_by_user(user_uuid: &UserId, conn: &mut DbConn) -> i64 { pub async fn count_owned_by_user(user_uuid: &UserId, conn: &DbConn) -> i64 {
db_run! { conn: { db_run! { conn: {
ciphers::table ciphers::table
.filter(ciphers::user_uuid.eq(user_uuid)) .filter(ciphers::user_uuid.eq(user_uuid))
@ -927,15 +918,16 @@ impl Cipher {
}} }}
} }
pub async fn find_by_org(org_uuid: &OrganizationId, conn: &mut DbConn) -> Vec<Self> { pub async fn find_by_org(org_uuid: &OrganizationId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: { db_run! { conn: {
ciphers::table ciphers::table
.filter(ciphers::organization_uuid.eq(org_uuid)) .filter(ciphers::organization_uuid.eq(org_uuid))
.load::<CipherDb>(conn).expect("Error loading ciphers").from_db() .load::<Self>(conn)
.expect("Error loading ciphers")
}} }}
} }
pub async fn count_by_org(org_uuid: &OrganizationId, conn: &mut DbConn) -> i64 { pub async fn count_by_org(org_uuid: &OrganizationId, conn: &DbConn) -> i64 {
db_run! { conn: { db_run! { conn: {
ciphers::table ciphers::table
.filter(ciphers::organization_uuid.eq(org_uuid)) .filter(ciphers::organization_uuid.eq(org_uuid))
@ -946,25 +938,27 @@ impl Cipher {
}} }}
} }
pub async fn find_by_folder(folder_uuid: &FolderId, conn: &mut DbConn) -> Vec<Self> { pub async fn find_by_folder(folder_uuid: &FolderId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: { db_run! { conn: {
folders_ciphers::table.inner_join(ciphers::table) folders_ciphers::table.inner_join(ciphers::table)
.filter(folders_ciphers::folder_uuid.eq(folder_uuid)) .filter(folders_ciphers::folder_uuid.eq(folder_uuid))
.select(ciphers::all_columns) .select(ciphers::all_columns)
.load::<CipherDb>(conn).expect("Error loading ciphers").from_db() .load::<Self>(conn)
.expect("Error loading ciphers")
}} }}
} }
/// Find all ciphers that were deleted before the specified datetime. /// Find all ciphers that were deleted before the specified datetime.
pub async fn find_deleted_before(dt: &NaiveDateTime, conn: &mut DbConn) -> Vec<Self> { pub async fn find_deleted_before(dt: &NaiveDateTime, conn: &DbConn) -> Vec<Self> {
db_run! { conn: { db_run! { conn: {
ciphers::table ciphers::table
.filter(ciphers::deleted_at.lt(dt)) .filter(ciphers::deleted_at.lt(dt))
.load::<CipherDb>(conn).expect("Error loading ciphers").from_db() .load::<Self>(conn)
.expect("Error loading ciphers")
}} }}
} }
pub async fn get_collections(&self, user_uuid: UserId, conn: &mut DbConn) -> Vec<CollectionId> { pub async fn get_collections(&self, user_uuid: UserId, conn: &DbConn) -> Vec<CollectionId> {
if CONFIG.org_groups_enabled() { if CONFIG.org_groups_enabled() {
db_run! { conn: { db_run! { conn: {
ciphers_collections::table ciphers_collections::table
@ -996,7 +990,8 @@ impl Cipher {
.and(collections_groups::read_only.eq(false))) .and(collections_groups::read_only.eq(false)))
) )
.select(ciphers_collections::collection_uuid) .select(ciphers_collections::collection_uuid)
.load::<CollectionId>(conn).unwrap_or_default() .load::<CollectionId>(conn)
.unwrap_or_default()
}} }}
} else { } else {
db_run! { conn: { db_run! { conn: {
@ -1018,12 +1013,13 @@ impl Cipher {
.and(users_collections::read_only.eq(false))) .and(users_collections::read_only.eq(false)))
) )
.select(ciphers_collections::collection_uuid) .select(ciphers_collections::collection_uuid)
.load::<CollectionId>(conn).unwrap_or_default() .load::<CollectionId>(conn)
.unwrap_or_default()
}} }}
} }
} }
pub async fn get_admin_collections(&self, user_uuid: UserId, conn: &mut DbConn) -> Vec<CollectionId> { pub async fn get_admin_collections(&self, user_uuid: UserId, conn: &DbConn) -> Vec<CollectionId> {
if CONFIG.org_groups_enabled() { if CONFIG.org_groups_enabled() {
db_run! { conn: { db_run! { conn: {
ciphers_collections::table ciphers_collections::table
@ -1056,7 +1052,8 @@ impl Cipher {
.or(users_organizations::atype.le(MembershipType::Admin as i32)) // User is admin or owner .or(users_organizations::atype.le(MembershipType::Admin as i32)) // User is admin or owner
) )
.select(ciphers_collections::collection_uuid) .select(ciphers_collections::collection_uuid)
.load::<CollectionId>(conn).unwrap_or_default() .load::<CollectionId>(conn)
.unwrap_or_default()
}} }}
} else { } else {
db_run! { conn: { db_run! { conn: {
@ -1079,7 +1076,8 @@ impl Cipher {
.or(users_organizations::atype.le(MembershipType::Admin as i32)) // User is admin or owner .or(users_organizations::atype.le(MembershipType::Admin as i32)) // User is admin or owner
) )
.select(ciphers_collections::collection_uuid) .select(ciphers_collections::collection_uuid)
.load::<CollectionId>(conn).unwrap_or_default() .load::<CollectionId>(conn)
.unwrap_or_default()
}} }}
} }
} }
@ -1088,7 +1086,7 @@ impl Cipher {
/// This is used during a full sync so we only need one query for all collections accessible. /// This is used during a full sync so we only need one query for all collections accessible.
pub async fn get_collections_with_cipher_by_user( pub async fn get_collections_with_cipher_by_user(
user_uuid: UserId, user_uuid: UserId,
conn: &mut DbConn, conn: &DbConn,
) -> Vec<(CipherId, CollectionId)> { ) -> Vec<(CipherId, CollectionId)> {
db_run! { conn: { db_run! { conn: {
ciphers_collections::table ciphers_collections::table
@ -1123,7 +1121,8 @@ impl Cipher {
.or_filter(collections_groups::collections_uuid.is_not_null()) //Access via group .or_filter(collections_groups::collections_uuid.is_not_null()) //Access via group
.select(ciphers_collections::all_columns) .select(ciphers_collections::all_columns)
.distinct() .distinct()
.load::<(CipherId, CollectionId)>(conn).unwrap_or_default() .load::<(CipherId, CollectionId)>(conn)
.unwrap_or_default()
}} }}
} }
} }

128
src/db/models/collection.rs

@ -5,10 +5,13 @@ use super::{
CipherId, CollectionGroup, GroupUser, Membership, MembershipId, MembershipStatus, MembershipType, OrganizationId, CipherId, CollectionGroup, GroupUser, Membership, MembershipId, MembershipStatus, MembershipType, OrganizationId,
User, UserId, User, UserId,
}; };
use crate::db::schema::{
ciphers_collections, collections, collections_groups, groups, groups_users, users_collections, users_organizations,
};
use crate::CONFIG; use crate::CONFIG;
use diesel::prelude::*;
use macros::UuidFromParam; use macros::UuidFromParam;
db_object! {
#[derive(Identifiable, Queryable, Insertable, AsChangeset)] #[derive(Identifiable, Queryable, Insertable, AsChangeset)]
#[diesel(table_name = collections)] #[diesel(table_name = collections)]
#[diesel(treat_none_as_null = true)] #[diesel(treat_none_as_null = true)]
@ -38,7 +41,6 @@ db_object! {
pub cipher_uuid: CipherId, pub cipher_uuid: CipherId,
pub collection_uuid: CollectionId, pub collection_uuid: CollectionId,
} }
}
/// Local methods /// Local methods
impl Collection { impl Collection {
@ -83,7 +85,7 @@ impl Collection {
&self, &self,
user_uuid: &UserId, user_uuid: &UserId,
cipher_sync_data: Option<&crate::api::core::CipherSyncData>, cipher_sync_data: Option<&crate::api::core::CipherSyncData>,
conn: &mut DbConn, conn: &DbConn,
) -> Value { ) -> Value {
let (read_only, hide_passwords, manage) = if let Some(cipher_sync_data) = cipher_sync_data { let (read_only, hide_passwords, manage) = if let Some(cipher_sync_data) = cipher_sync_data {
match cipher_sync_data.members.get(&self.org_uuid) { match cipher_sync_data.members.get(&self.org_uuid) {
@ -135,7 +137,7 @@ impl Collection {
json_object json_object
} }
pub async fn can_access_collection(member: &Membership, col_id: &CollectionId, conn: &mut DbConn) -> bool { pub async fn can_access_collection(member: &Membership, col_id: &CollectionId, conn: &DbConn) -> bool {
member.has_status(MembershipStatus::Confirmed) member.has_status(MembershipStatus::Confirmed)
&& (member.has_full_access() && (member.has_full_access()
|| CollectionUser::has_access_to_collection_by_user(col_id, &member.user_uuid, conn).await || CollectionUser::has_access_to_collection_by_user(col_id, &member.user_uuid, conn).await
@ -152,13 +154,13 @@ use crate::error::MapResult;
/// Database methods /// Database methods
impl Collection { impl Collection {
pub async fn save(&self, conn: &mut DbConn) -> EmptyResult { pub async fn save(&self, conn: &DbConn) -> EmptyResult {
self.update_users_revision(conn).await; self.update_users_revision(conn).await;
db_run! { conn: db_run! { conn:
sqlite, mysql { sqlite, mysql {
match diesel::replace_into(collections::table) match diesel::replace_into(collections::table)
.values(CollectionDb::to_db(self)) .values(self)
.execute(conn) .execute(conn)
{ {
Ok(_) => Ok(()), Ok(_) => Ok(()),
@ -166,7 +168,7 @@ impl Collection {
Err(diesel::result::Error::DatabaseError(diesel::result::DatabaseErrorKind::ForeignKeyViolation, _)) => { Err(diesel::result::Error::DatabaseError(diesel::result::DatabaseErrorKind::ForeignKeyViolation, _)) => {
diesel::update(collections::table) diesel::update(collections::table)
.filter(collections::uuid.eq(&self.uuid)) .filter(collections::uuid.eq(&self.uuid))
.set(CollectionDb::to_db(self)) .set(self)
.execute(conn) .execute(conn)
.map_res("Error saving collection") .map_res("Error saving collection")
} }
@ -174,19 +176,18 @@ impl Collection {
}.map_res("Error saving collection") }.map_res("Error saving collection")
} }
postgresql { postgresql {
let value = CollectionDb::to_db(self);
diesel::insert_into(collections::table) diesel::insert_into(collections::table)
.values(&value) .values(self)
.on_conflict(collections::uuid) .on_conflict(collections::uuid)
.do_update() .do_update()
.set(&value) .set(self)
.execute(conn) .execute(conn)
.map_res("Error saving collection") .map_res("Error saving collection")
} }
} }
} }
pub async fn delete(self, conn: &mut DbConn) -> EmptyResult { pub async fn delete(self, conn: &DbConn) -> EmptyResult {
self.update_users_revision(conn).await; self.update_users_revision(conn).await;
CollectionCipher::delete_all_by_collection(&self.uuid, conn).await?; CollectionCipher::delete_all_by_collection(&self.uuid, conn).await?;
CollectionUser::delete_all_by_collection(&self.uuid, conn).await?; CollectionUser::delete_all_by_collection(&self.uuid, conn).await?;
@ -199,30 +200,29 @@ impl Collection {
}} }}
} }
pub async fn delete_all_by_organization(org_uuid: &OrganizationId, conn: &mut DbConn) -> EmptyResult { pub async fn delete_all_by_organization(org_uuid: &OrganizationId, conn: &DbConn) -> EmptyResult {
for collection in Self::find_by_organization(org_uuid, conn).await { for collection in Self::find_by_organization(org_uuid, conn).await {
collection.delete(conn).await?; collection.delete(conn).await?;
} }
Ok(()) Ok(())
} }
pub async fn update_users_revision(&self, conn: &mut DbConn) { pub async fn update_users_revision(&self, conn: &DbConn) {
for member in Membership::find_by_collection_and_org(&self.uuid, &self.org_uuid, conn).await.iter() { for member in Membership::find_by_collection_and_org(&self.uuid, &self.org_uuid, conn).await.iter() {
User::update_uuid_revision(&member.user_uuid, conn).await; User::update_uuid_revision(&member.user_uuid, conn).await;
} }
} }
pub async fn find_by_uuid(uuid: &CollectionId, conn: &mut DbConn) -> Option<Self> { pub async fn find_by_uuid(uuid: &CollectionId, conn: &DbConn) -> Option<Self> {
db_run! { conn: { db_run! { conn: {
collections::table collections::table
.filter(collections::uuid.eq(uuid)) .filter(collections::uuid.eq(uuid))
.first::<CollectionDb>(conn) .first::<Self>(conn)
.ok() .ok()
.from_db()
}} }}
} }
pub async fn find_by_user_uuid(user_uuid: UserId, conn: &mut DbConn) -> Vec<Self> { pub async fn find_by_user_uuid(user_uuid: UserId, conn: &DbConn) -> Vec<Self> {
if CONFIG.org_groups_enabled() { if CONFIG.org_groups_enabled() {
db_run! { conn: { db_run! { conn: {
collections::table collections::table
@ -263,7 +263,8 @@ impl Collection {
) )
.select(collections::all_columns) .select(collections::all_columns)
.distinct() .distinct()
.load::<CollectionDb>(conn).expect("Error loading collections").from_db() .load::<Self>(conn)
.expect("Error loading collections")
}} }}
} else { } else {
db_run! { conn: { db_run! { conn: {
@ -288,7 +289,8 @@ impl Collection {
) )
.select(collections::all_columns) .select(collections::all_columns)
.distinct() .distinct()
.load::<CollectionDb>(conn).expect("Error loading collections").from_db() .load::<Self>(conn)
.expect("Error loading collections")
}} }}
} }
} }
@ -296,7 +298,7 @@ impl Collection {
pub async fn find_by_organization_and_user_uuid( pub async fn find_by_organization_and_user_uuid(
org_uuid: &OrganizationId, org_uuid: &OrganizationId,
user_uuid: &UserId, user_uuid: &UserId,
conn: &mut DbConn, conn: &DbConn,
) -> Vec<Self> { ) -> Vec<Self> {
Self::find_by_user_uuid(user_uuid.to_owned(), conn) Self::find_by_user_uuid(user_uuid.to_owned(), conn)
.await .await
@ -305,17 +307,16 @@ impl Collection {
.collect() .collect()
} }
pub async fn find_by_organization(org_uuid: &OrganizationId, conn: &mut DbConn) -> Vec<Self> { pub async fn find_by_organization(org_uuid: &OrganizationId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: { db_run! { conn: {
collections::table collections::table
.filter(collections::org_uuid.eq(org_uuid)) .filter(collections::org_uuid.eq(org_uuid))
.load::<CollectionDb>(conn) .load::<Self>(conn)
.expect("Error loading collections") .expect("Error loading collections")
.from_db()
}} }}
} }
pub async fn count_by_org(org_uuid: &OrganizationId, conn: &mut DbConn) -> i64 { pub async fn count_by_org(org_uuid: &OrganizationId, conn: &DbConn) -> i64 {
db_run! { conn: { db_run! { conn: {
collections::table collections::table
.filter(collections::org_uuid.eq(org_uuid)) .filter(collections::org_uuid.eq(org_uuid))
@ -326,23 +327,18 @@ impl Collection {
}} }}
} }
pub async fn find_by_uuid_and_org( pub async fn find_by_uuid_and_org(uuid: &CollectionId, org_uuid: &OrganizationId, conn: &DbConn) -> Option<Self> {
uuid: &CollectionId,
org_uuid: &OrganizationId,
conn: &mut DbConn,
) -> Option<Self> {
db_run! { conn: { db_run! { conn: {
collections::table collections::table
.filter(collections::uuid.eq(uuid)) .filter(collections::uuid.eq(uuid))
.filter(collections::org_uuid.eq(org_uuid)) .filter(collections::org_uuid.eq(org_uuid))
.select(collections::all_columns) .select(collections::all_columns)
.first::<CollectionDb>(conn) .first::<Self>(conn)
.ok() .ok()
.from_db()
}} }}
} }
pub async fn find_by_uuid_and_user(uuid: &CollectionId, user_uuid: UserId, conn: &mut DbConn) -> Option<Self> { pub async fn find_by_uuid_and_user(uuid: &CollectionId, user_uuid: UserId, conn: &DbConn) -> Option<Self> {
if CONFIG.org_groups_enabled() { if CONFIG.org_groups_enabled() {
db_run! { conn: { db_run! { conn: {
collections::table collections::table
@ -380,8 +376,8 @@ impl Collection {
) )
) )
).select(collections::all_columns) ).select(collections::all_columns)
.first::<CollectionDb>(conn).ok() .first::<Self>(conn)
.from_db() .ok()
}} }}
} else { } else {
db_run! { conn: { db_run! { conn: {
@ -403,13 +399,13 @@ impl Collection {
users_organizations::atype.le(MembershipType::Admin as i32) // Org admin or owner users_organizations::atype.le(MembershipType::Admin as i32) // Org admin or owner
)) ))
).select(collections::all_columns) ).select(collections::all_columns)
.first::<CollectionDb>(conn).ok() .first::<Self>(conn)
.from_db() .ok()
}} }}
} }
} }
pub async fn is_writable_by_user(&self, user_uuid: &UserId, conn: &mut DbConn) -> bool { pub async fn is_writable_by_user(&self, user_uuid: &UserId, conn: &DbConn) -> bool {
let user_uuid = user_uuid.to_string(); let user_uuid = user_uuid.to_string();
if CONFIG.org_groups_enabled() { if CONFIG.org_groups_enabled() {
db_run! { conn: { db_run! { conn: {
@ -471,7 +467,7 @@ impl Collection {
} }
} }
pub async fn hide_passwords_for_user(&self, user_uuid: &UserId, conn: &mut DbConn) -> bool { pub async fn hide_passwords_for_user(&self, user_uuid: &UserId, conn: &DbConn) -> bool {
let user_uuid = user_uuid.to_string(); let user_uuid = user_uuid.to_string();
db_run! { conn: { db_run! { conn: {
collections::table collections::table
@ -517,7 +513,7 @@ impl Collection {
}} }}
} }
pub async fn is_manageable_by_user(&self, user_uuid: &UserId, conn: &mut DbConn) -> bool { pub async fn is_manageable_by_user(&self, user_uuid: &UserId, conn: &DbConn) -> bool {
let user_uuid = user_uuid.to_string(); let user_uuid = user_uuid.to_string();
db_run! { conn: { db_run! { conn: {
collections::table collections::table
@ -569,7 +565,7 @@ impl CollectionUser {
pub async fn find_by_organization_and_user_uuid( pub async fn find_by_organization_and_user_uuid(
org_uuid: &OrganizationId, org_uuid: &OrganizationId,
user_uuid: &UserId, user_uuid: &UserId,
conn: &mut DbConn, conn: &DbConn,
) -> Vec<Self> { ) -> Vec<Self> {
db_run! { conn: { db_run! { conn: {
users_collections::table users_collections::table
@ -577,15 +573,14 @@ impl CollectionUser {
.inner_join(collections::table.on(collections::uuid.eq(users_collections::collection_uuid))) .inner_join(collections::table.on(collections::uuid.eq(users_collections::collection_uuid)))
.filter(collections::org_uuid.eq(org_uuid)) .filter(collections::org_uuid.eq(org_uuid))
.select(users_collections::all_columns) .select(users_collections::all_columns)
.load::<CollectionUserDb>(conn) .load::<Self>(conn)
.expect("Error loading users_collections") .expect("Error loading users_collections")
.from_db()
}} }}
} }
pub async fn find_by_organization_swap_user_uuid_with_member_uuid( pub async fn find_by_organization_swap_user_uuid_with_member_uuid(
org_uuid: &OrganizationId, org_uuid: &OrganizationId,
conn: &mut DbConn, conn: &DbConn,
) -> Vec<CollectionMembership> { ) -> Vec<CollectionMembership> {
let col_users = db_run! { conn: { let col_users = db_run! { conn: {
users_collections::table users_collections::table
@ -594,9 +589,8 @@ impl CollectionUser {
.inner_join(users_organizations::table.on(users_organizations::user_uuid.eq(users_collections::user_uuid))) .inner_join(users_organizations::table.on(users_organizations::user_uuid.eq(users_collections::user_uuid)))
.filter(users_organizations::org_uuid.eq(org_uuid)) .filter(users_organizations::org_uuid.eq(org_uuid))
.select((users_organizations::uuid, users_collections::collection_uuid, users_collections::read_only, users_collections::hide_passwords, users_collections::manage)) .select((users_organizations::uuid, users_collections::collection_uuid, users_collections::read_only, users_collections::hide_passwords, users_collections::manage))
.load::<CollectionUserDb>(conn) .load::<Self>(conn)
.expect("Error loading users_collections") .expect("Error loading users_collections")
.from_db()
}}; }};
col_users.into_iter().map(|c| c.into()).collect() col_users.into_iter().map(|c| c.into()).collect()
} }
@ -607,7 +601,7 @@ impl CollectionUser {
read_only: bool, read_only: bool,
hide_passwords: bool, hide_passwords: bool,
manage: bool, manage: bool,
conn: &mut DbConn, conn: &DbConn,
) -> EmptyResult { ) -> EmptyResult {
User::update_uuid_revision(user_uuid, conn).await; User::update_uuid_revision(user_uuid, conn).await;
@ -664,7 +658,7 @@ impl CollectionUser {
} }
} }
pub async fn delete(self, conn: &mut DbConn) -> EmptyResult { pub async fn delete(self, conn: &DbConn) -> EmptyResult {
User::update_uuid_revision(&self.user_uuid, conn).await; User::update_uuid_revision(&self.user_uuid, conn).await;
db_run! { conn: { db_run! { conn: {
@ -678,21 +672,20 @@ impl CollectionUser {
}} }}
} }
pub async fn find_by_collection(collection_uuid: &CollectionId, conn: &mut DbConn) -> Vec<Self> { pub async fn find_by_collection(collection_uuid: &CollectionId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: { db_run! { conn: {
users_collections::table users_collections::table
.filter(users_collections::collection_uuid.eq(collection_uuid)) .filter(users_collections::collection_uuid.eq(collection_uuid))
.select(users_collections::all_columns) .select(users_collections::all_columns)
.load::<CollectionUserDb>(conn) .load::<Self>(conn)
.expect("Error loading users_collections") .expect("Error loading users_collections")
.from_db()
}} }}
} }
pub async fn find_by_org_and_coll_swap_user_uuid_with_member_uuid( pub async fn find_by_org_and_coll_swap_user_uuid_with_member_uuid(
org_uuid: &OrganizationId, org_uuid: &OrganizationId,
collection_uuid: &CollectionId, collection_uuid: &CollectionId,
conn: &mut DbConn, conn: &DbConn,
) -> Vec<CollectionMembership> { ) -> Vec<CollectionMembership> {
let col_users = db_run! { conn: { let col_users = db_run! { conn: {
users_collections::table users_collections::table
@ -700,9 +693,8 @@ impl CollectionUser {
.filter(users_organizations::org_uuid.eq(org_uuid)) .filter(users_organizations::org_uuid.eq(org_uuid))
.inner_join(users_organizations::table.on(users_organizations::user_uuid.eq(users_collections::user_uuid))) .inner_join(users_organizations::table.on(users_organizations::user_uuid.eq(users_collections::user_uuid)))
.select((users_organizations::uuid, users_collections::collection_uuid, users_collections::read_only, users_collections::hide_passwords, users_collections::manage)) .select((users_organizations::uuid, users_collections::collection_uuid, users_collections::read_only, users_collections::hide_passwords, users_collections::manage))
.load::<CollectionUserDb>(conn) .load::<Self>(conn)
.expect("Error loading users_collections") .expect("Error loading users_collections")
.from_db()
}}; }};
col_users.into_iter().map(|c| c.into()).collect() col_users.into_iter().map(|c| c.into()).collect()
} }
@ -710,31 +702,29 @@ impl CollectionUser {
pub async fn find_by_collection_and_user( pub async fn find_by_collection_and_user(
collection_uuid: &CollectionId, collection_uuid: &CollectionId,
user_uuid: &UserId, user_uuid: &UserId,
conn: &mut DbConn, conn: &DbConn,
) -> Option<Self> { ) -> Option<Self> {
db_run! { conn: { db_run! { conn: {
users_collections::table users_collections::table
.filter(users_collections::collection_uuid.eq(collection_uuid)) .filter(users_collections::collection_uuid.eq(collection_uuid))
.filter(users_collections::user_uuid.eq(user_uuid)) .filter(users_collections::user_uuid.eq(user_uuid))
.select(users_collections::all_columns) .select(users_collections::all_columns)
.first::<CollectionUserDb>(conn) .first::<Self>(conn)
.ok() .ok()
.from_db()
}} }}
} }
pub async fn find_by_user(user_uuid: &UserId, conn: &mut DbConn) -> Vec<Self> { pub async fn find_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: { db_run! { conn: {
users_collections::table users_collections::table
.filter(users_collections::user_uuid.eq(user_uuid)) .filter(users_collections::user_uuid.eq(user_uuid))
.select(users_collections::all_columns) .select(users_collections::all_columns)
.load::<CollectionUserDb>(conn) .load::<Self>(conn)
.expect("Error loading users_collections") .expect("Error loading users_collections")
.from_db()
}} }}
} }
pub async fn delete_all_by_collection(collection_uuid: &CollectionId, conn: &mut DbConn) -> EmptyResult { pub async fn delete_all_by_collection(collection_uuid: &CollectionId, conn: &DbConn) -> EmptyResult {
for collection in CollectionUser::find_by_collection(collection_uuid, conn).await.iter() { for collection in CollectionUser::find_by_collection(collection_uuid, conn).await.iter() {
User::update_uuid_revision(&collection.user_uuid, conn).await; User::update_uuid_revision(&collection.user_uuid, conn).await;
} }
@ -749,7 +739,7 @@ impl CollectionUser {
pub async fn delete_all_by_user_and_org( pub async fn delete_all_by_user_and_org(
user_uuid: &UserId, user_uuid: &UserId,
org_uuid: &OrganizationId, org_uuid: &OrganizationId,
conn: &mut DbConn, conn: &DbConn,
) -> EmptyResult { ) -> EmptyResult {
let collectionusers = Self::find_by_organization_and_user_uuid(org_uuid, user_uuid, conn).await; let collectionusers = Self::find_by_organization_and_user_uuid(org_uuid, user_uuid, conn).await;
@ -766,18 +756,14 @@ impl CollectionUser {
}} }}
} }
pub async fn has_access_to_collection_by_user( pub async fn has_access_to_collection_by_user(col_id: &CollectionId, user_uuid: &UserId, conn: &DbConn) -> bool {
col_id: &CollectionId,
user_uuid: &UserId,
conn: &mut DbConn,
) -> bool {
Self::find_by_collection_and_user(col_id, user_uuid, conn).await.is_some() Self::find_by_collection_and_user(col_id, user_uuid, conn).await.is_some()
} }
} }
/// Database methods /// Database methods
impl CollectionCipher { impl CollectionCipher {
pub async fn save(cipher_uuid: &CipherId, collection_uuid: &CollectionId, conn: &mut DbConn) -> EmptyResult { pub async fn save(cipher_uuid: &CipherId, collection_uuid: &CollectionId, conn: &DbConn) -> EmptyResult {
Self::update_users_revision(collection_uuid, conn).await; Self::update_users_revision(collection_uuid, conn).await;
db_run! { conn: db_run! { conn:
@ -807,7 +793,7 @@ impl CollectionCipher {
} }
} }
pub async fn delete(cipher_uuid: &CipherId, collection_uuid: &CollectionId, conn: &mut DbConn) -> EmptyResult { pub async fn delete(cipher_uuid: &CipherId, collection_uuid: &CollectionId, conn: &DbConn) -> EmptyResult {
Self::update_users_revision(collection_uuid, conn).await; Self::update_users_revision(collection_uuid, conn).await;
db_run! { conn: { db_run! { conn: {
@ -821,7 +807,7 @@ impl CollectionCipher {
}} }}
} }
pub async fn delete_all_by_cipher(cipher_uuid: &CipherId, conn: &mut DbConn) -> EmptyResult { pub async fn delete_all_by_cipher(cipher_uuid: &CipherId, conn: &DbConn) -> EmptyResult {
db_run! { conn: { db_run! { conn: {
diesel::delete(ciphers_collections::table.filter(ciphers_collections::cipher_uuid.eq(cipher_uuid))) diesel::delete(ciphers_collections::table.filter(ciphers_collections::cipher_uuid.eq(cipher_uuid)))
.execute(conn) .execute(conn)
@ -829,7 +815,7 @@ impl CollectionCipher {
}} }}
} }
pub async fn delete_all_by_collection(collection_uuid: &CollectionId, conn: &mut DbConn) -> EmptyResult { pub async fn delete_all_by_collection(collection_uuid: &CollectionId, conn: &DbConn) -> EmptyResult {
db_run! { conn: { db_run! { conn: {
diesel::delete(ciphers_collections::table.filter(ciphers_collections::collection_uuid.eq(collection_uuid))) diesel::delete(ciphers_collections::table.filter(ciphers_collections::collection_uuid.eq(collection_uuid)))
.execute(conn) .execute(conn)
@ -837,7 +823,7 @@ impl CollectionCipher {
}} }}
} }
pub async fn update_users_revision(collection_uuid: &CollectionId, conn: &mut DbConn) { pub async fn update_users_revision(collection_uuid: &CollectionId, conn: &DbConn) {
if let Some(collection) = Collection::find_by_uuid(collection_uuid, conn).await { if let Some(collection) = Collection::find_by_uuid(collection_uuid, conn).await {
collection.update_users_revision(conn).await; collection.update_users_revision(conn).await;
} }

70
src/db/models/device.rs

@ -5,13 +5,14 @@ use derive_more::{Display, From};
use serde_json::Value; use serde_json::Value;
use super::{AuthRequest, UserId}; use super::{AuthRequest, UserId};
use crate::db::schema::devices;
use crate::{ use crate::{
crypto, crypto,
util::{format_date, get_uuid}, util::{format_date, get_uuid},
}; };
use diesel::prelude::*;
use macros::{IdFromParam, UuidFromParam}; use macros::{IdFromParam, UuidFromParam};
db_object! {
#[derive(Identifiable, Queryable, Insertable, AsChangeset)] #[derive(Identifiable, Queryable, Insertable, AsChangeset)]
#[diesel(table_name = devices)] #[diesel(table_name = devices)]
#[diesel(treat_none_as_null = true)] #[diesel(treat_none_as_null = true)]
@ -31,7 +32,6 @@ db_object! {
pub refresh_token: String, pub refresh_token: String,
pub twofactor_remember: Option<String>, pub twofactor_remember: Option<String>,
} }
}
/// Local methods /// Local methods
impl Device { impl Device {
@ -115,13 +115,7 @@ use crate::error::MapResult;
/// Database methods /// Database methods
impl Device { impl Device {
pub async fn new( pub async fn new(uuid: DeviceId, user_uuid: UserId, name: String, atype: i32, conn: &DbConn) -> ApiResult<Device> {
uuid: DeviceId,
user_uuid: UserId,
name: String,
atype: i32,
conn: &mut DbConn,
) -> ApiResult<Device> {
let now = Utc::now().naive_utc(); let now = Utc::now().naive_utc();
let device = Self { let device = Self {
@ -142,18 +136,24 @@ impl Device {
device.inner_save(conn).await.map(|()| device) device.inner_save(conn).await.map(|()| device)
} }
async fn inner_save(&self, conn: &mut DbConn) -> EmptyResult { async fn inner_save(&self, conn: &DbConn) -> EmptyResult {
db_run! { conn: db_run! { conn:
sqlite, mysql { sqlite, mysql {
crate::util::retry( crate::util::retry(||
|| diesel::replace_into(devices::table).values(DeviceDb::to_db(self)).execute(conn), diesel::replace_into(devices::table)
.values(self)
.execute(conn),
10, 10,
).map_res("Error saving device") ).map_res("Error saving device")
} }
postgresql { postgresql {
let value = DeviceDb::to_db(self); crate::util::retry(||
crate::util::retry( diesel::insert_into(devices::table)
|| diesel::insert_into(devices::table).values(&value).on_conflict((devices::uuid, devices::user_uuid)).do_update().set(&value).execute(conn), .values(self)
.on_conflict((devices::uuid, devices::user_uuid))
.do_update()
.set(self)
.execute(conn),
10, 10,
).map_res("Error saving device") ).map_res("Error saving device")
} }
@ -161,12 +161,12 @@ impl Device {
} }
// Should only be called after user has passed authentication // Should only be called after user has passed authentication
pub async fn save(&mut self, conn: &mut DbConn) -> EmptyResult { pub async fn save(&mut self, conn: &DbConn) -> EmptyResult {
self.updated_at = Utc::now().naive_utc(); self.updated_at = Utc::now().naive_utc();
self.inner_save(conn).await self.inner_save(conn).await
} }
pub async fn delete_all_by_user(user_uuid: &UserId, conn: &mut DbConn) -> EmptyResult { pub async fn delete_all_by_user(user_uuid: &UserId, conn: &DbConn) -> EmptyResult {
db_run! { conn: { db_run! { conn: {
diesel::delete(devices::table.filter(devices::user_uuid.eq(user_uuid))) diesel::delete(devices::table.filter(devices::user_uuid.eq(user_uuid)))
.execute(conn) .execute(conn)
@ -174,18 +174,17 @@ impl Device {
}} }}
} }
pub async fn find_by_uuid_and_user(uuid: &DeviceId, user_uuid: &UserId, conn: &mut DbConn) -> Option<Self> { pub async fn find_by_uuid_and_user(uuid: &DeviceId, user_uuid: &UserId, conn: &DbConn) -> Option<Self> {
db_run! { conn: { db_run! { conn: {
devices::table devices::table
.filter(devices::uuid.eq(uuid)) .filter(devices::uuid.eq(uuid))
.filter(devices::user_uuid.eq(user_uuid)) .filter(devices::user_uuid.eq(user_uuid))
.first::<DeviceDb>(conn) .first::<Self>(conn)
.ok() .ok()
.from_db()
}} }}
} }
pub async fn find_with_auth_request_by_user(user_uuid: &UserId, conn: &mut DbConn) -> Vec<DeviceWithAuthRequest> { pub async fn find_with_auth_request_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<DeviceWithAuthRequest> {
let devices = Self::find_by_user(user_uuid, conn).await; let devices = Self::find_by_user(user_uuid, conn).await;
let mut result = Vec::new(); let mut result = Vec::new();
for device in devices { for device in devices {
@ -195,27 +194,25 @@ impl Device {
result result
} }
pub async fn find_by_user(user_uuid: &UserId, conn: &mut DbConn) -> Vec<Self> { pub async fn find_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: { db_run! { conn: {
devices::table devices::table
.filter(devices::user_uuid.eq(user_uuid)) .filter(devices::user_uuid.eq(user_uuid))
.load::<DeviceDb>(conn) .load::<Self>(conn)
.expect("Error loading devices") .expect("Error loading devices")
.from_db()
}} }}
} }
pub async fn find_by_uuid(uuid: &DeviceId, conn: &mut DbConn) -> Option<Self> { pub async fn find_by_uuid(uuid: &DeviceId, conn: &DbConn) -> Option<Self> {
db_run! { conn: { db_run! { conn: {
devices::table devices::table
.filter(devices::uuid.eq(uuid)) .filter(devices::uuid.eq(uuid))
.first::<DeviceDb>(conn) .first::<Self>(conn)
.ok() .ok()
.from_db()
}} }}
} }
pub async fn clear_push_token_by_uuid(uuid: &DeviceId, conn: &mut DbConn) -> EmptyResult { pub async fn clear_push_token_by_uuid(uuid: &DeviceId, conn: &DbConn) -> EmptyResult {
db_run! { conn: { db_run! { conn: {
diesel::update(devices::table) diesel::update(devices::table)
.filter(devices::uuid.eq(uuid)) .filter(devices::uuid.eq(uuid))
@ -224,39 +221,36 @@ impl Device {
.map_res("Error removing push token") .map_res("Error removing push token")
}} }}
} }
pub async fn find_by_refresh_token(refresh_token: &str, conn: &mut DbConn) -> Option<Self> { pub async fn find_by_refresh_token(refresh_token: &str, conn: &DbConn) -> Option<Self> {
db_run! { conn: { db_run! { conn: {
devices::table devices::table
.filter(devices::refresh_token.eq(refresh_token)) .filter(devices::refresh_token.eq(refresh_token))
.first::<DeviceDb>(conn) .first::<Self>(conn)
.ok() .ok()
.from_db()
}} }}
} }
pub async fn find_latest_active_by_user(user_uuid: &UserId, conn: &mut DbConn) -> Option<Self> { pub async fn find_latest_active_by_user(user_uuid: &UserId, conn: &DbConn) -> Option<Self> {
db_run! { conn: { db_run! { conn: {
devices::table devices::table
.filter(devices::user_uuid.eq(user_uuid)) .filter(devices::user_uuid.eq(user_uuid))
.order(devices::updated_at.desc()) .order(devices::updated_at.desc())
.first::<DeviceDb>(conn) .first::<Self>(conn)
.ok() .ok()
.from_db()
}} }}
} }
pub async fn find_push_devices_by_user(user_uuid: &UserId, conn: &mut DbConn) -> Vec<Self> { pub async fn find_push_devices_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: { db_run! { conn: {
devices::table devices::table
.filter(devices::user_uuid.eq(user_uuid)) .filter(devices::user_uuid.eq(user_uuid))
.filter(devices::push_token.is_not_null()) .filter(devices::push_token.is_not_null())
.load::<DeviceDb>(conn) .load::<Self>(conn)
.expect("Error loading push devices") .expect("Error loading push devices")
.from_db()
}} }}
} }
pub async fn check_user_has_push_device(user_uuid: &UserId, conn: &mut DbConn) -> bool { pub async fn check_user_has_push_device(user_uuid: &UserId, conn: &DbConn) -> bool {
db_run! { conn: { db_run! { conn: {
devices::table devices::table
.filter(devices::user_uuid.eq(user_uuid)) .filter(devices::user_uuid.eq(user_uuid))

91
src/db/models/emergency_access.rs

@ -3,10 +3,11 @@ use derive_more::{AsRef, Deref, Display, From};
use serde_json::Value; use serde_json::Value;
use super::{User, UserId}; use super::{User, UserId};
use crate::db::schema::emergency_access;
use crate::{api::EmptyResult, db::DbConn, error::MapResult}; use crate::{api::EmptyResult, db::DbConn, error::MapResult};
use diesel::prelude::*;
use macros::UuidFromParam; use macros::UuidFromParam;
db_object! {
#[derive(Identifiable, Queryable, Insertable, AsChangeset)] #[derive(Identifiable, Queryable, Insertable, AsChangeset)]
#[diesel(table_name = emergency_access)] #[diesel(table_name = emergency_access)]
#[diesel(treat_none_as_null = true)] #[diesel(treat_none_as_null = true)]
@ -25,10 +26,8 @@ db_object! {
pub updated_at: NaiveDateTime, pub updated_at: NaiveDateTime,
pub created_at: NaiveDateTime, pub created_at: NaiveDateTime,
} }
}
// Local methods // Local methods
impl EmergencyAccess { impl EmergencyAccess {
pub fn new(grantor_uuid: UserId, email: String, status: i32, atype: i32, wait_time_days: i32) -> Self { pub fn new(grantor_uuid: UserId, email: String, status: i32, atype: i32, wait_time_days: i32) -> Self {
let now = Utc::now().naive_utc(); let now = Utc::now().naive_utc();
@ -67,7 +66,7 @@ impl EmergencyAccess {
}) })
} }
pub async fn to_json_grantor_details(&self, conn: &mut DbConn) -> Value { pub async fn to_json_grantor_details(&self, conn: &DbConn) -> Value {
let grantor_user = User::find_by_uuid(&self.grantor_uuid, conn).await.expect("Grantor user not found."); let grantor_user = User::find_by_uuid(&self.grantor_uuid, conn).await.expect("Grantor user not found.");
json!({ json!({
@ -83,7 +82,7 @@ impl EmergencyAccess {
}) })
} }
pub async fn to_json_grantee_details(&self, conn: &mut DbConn) -> Option<Value> { pub async fn to_json_grantee_details(&self, conn: &DbConn) -> Option<Value> {
let grantee_user = if let Some(grantee_uuid) = &self.grantee_uuid { let grantee_user = if let Some(grantee_uuid) = &self.grantee_uuid {
User::find_by_uuid(grantee_uuid, conn).await.expect("Grantee user not found.") User::find_by_uuid(grantee_uuid, conn).await.expect("Grantee user not found.")
} else if let Some(email) = self.email.as_deref() { } else if let Some(email) = self.email.as_deref() {
@ -140,14 +139,14 @@ pub enum EmergencyAccessStatus {
// region Database methods // region Database methods
impl EmergencyAccess { impl EmergencyAccess {
pub async fn save(&mut self, conn: &mut DbConn) -> EmptyResult { pub async fn save(&mut self, conn: &DbConn) -> EmptyResult {
User::update_uuid_revision(&self.grantor_uuid, conn).await; User::update_uuid_revision(&self.grantor_uuid, conn).await;
self.updated_at = Utc::now().naive_utc(); self.updated_at = Utc::now().naive_utc();
db_run! { conn: db_run! { conn:
sqlite, mysql { sqlite, mysql {
match diesel::replace_into(emergency_access::table) match diesel::replace_into(emergency_access::table)
.values(EmergencyAccessDb::to_db(self)) .values(&*self)
.execute(conn) .execute(conn)
{ {
Ok(_) => Ok(()), Ok(_) => Ok(()),
@ -155,7 +154,7 @@ impl EmergencyAccess {
Err(diesel::result::Error::DatabaseError(diesel::result::DatabaseErrorKind::ForeignKeyViolation, _)) => { Err(diesel::result::Error::DatabaseError(diesel::result::DatabaseErrorKind::ForeignKeyViolation, _)) => {
diesel::update(emergency_access::table) diesel::update(emergency_access::table)
.filter(emergency_access::uuid.eq(&self.uuid)) .filter(emergency_access::uuid.eq(&self.uuid))
.set(EmergencyAccessDb::to_db(self)) .set(&*self)
.execute(conn) .execute(conn)
.map_res("Error updating emergency access") .map_res("Error updating emergency access")
} }
@ -163,12 +162,11 @@ impl EmergencyAccess {
}.map_res("Error saving emergency access") }.map_res("Error saving emergency access")
} }
postgresql { postgresql {
let value = EmergencyAccessDb::to_db(self);
diesel::insert_into(emergency_access::table) diesel::insert_into(emergency_access::table)
.values(&value) .values(&*self)
.on_conflict(emergency_access::uuid) .on_conflict(emergency_access::uuid)
.do_update() .do_update()
.set(&value) .set(&*self)
.execute(conn) .execute(conn)
.map_res("Error saving emergency access") .map_res("Error saving emergency access")
} }
@ -179,7 +177,7 @@ impl EmergencyAccess {
&mut self, &mut self,
status: i32, status: i32,
date: &NaiveDateTime, date: &NaiveDateTime,
conn: &mut DbConn, conn: &DbConn,
) -> EmptyResult { ) -> EmptyResult {
// Update the grantee so that it will refresh it's status. // Update the grantee so that it will refresh it's status.
User::update_uuid_revision(self.grantee_uuid.as_ref().expect("Error getting grantee"), conn).await; User::update_uuid_revision(self.grantee_uuid.as_ref().expect("Error getting grantee"), conn).await;
@ -196,11 +194,7 @@ impl EmergencyAccess {
}} }}
} }
pub async fn update_last_notification_date_and_save( pub async fn update_last_notification_date_and_save(&mut self, date: &NaiveDateTime, conn: &DbConn) -> EmptyResult {
&mut self,
date: &NaiveDateTime,
conn: &mut DbConn,
) -> EmptyResult {
self.last_notification_at = Some(date.to_owned()); self.last_notification_at = Some(date.to_owned());
date.clone_into(&mut self.updated_at); date.clone_into(&mut self.updated_at);
@ -214,7 +208,7 @@ impl EmergencyAccess {
}} }}
} }
pub async fn delete_all_by_user(user_uuid: &UserId, conn: &mut DbConn) -> EmptyResult { pub async fn delete_all_by_user(user_uuid: &UserId, conn: &DbConn) -> EmptyResult {
for ea in Self::find_all_by_grantor_uuid(user_uuid, conn).await { for ea in Self::find_all_by_grantor_uuid(user_uuid, conn).await {
ea.delete(conn).await?; ea.delete(conn).await?;
} }
@ -224,14 +218,14 @@ impl EmergencyAccess {
Ok(()) Ok(())
} }
pub async fn delete_all_by_grantee_email(grantee_email: &str, conn: &mut DbConn) -> EmptyResult { pub async fn delete_all_by_grantee_email(grantee_email: &str, conn: &DbConn) -> EmptyResult {
for ea in Self::find_all_invited_by_grantee_email(grantee_email, conn).await { for ea in Self::find_all_invited_by_grantee_email(grantee_email, conn).await {
ea.delete(conn).await?; ea.delete(conn).await?;
} }
Ok(()) Ok(())
} }
pub async fn delete(self, conn: &mut DbConn) -> EmptyResult { pub async fn delete(self, conn: &DbConn) -> EmptyResult {
User::update_uuid_revision(&self.grantor_uuid, conn).await; User::update_uuid_revision(&self.grantor_uuid, conn).await;
db_run! { conn: { db_run! { conn: {
@ -245,109 +239,108 @@ impl EmergencyAccess {
grantor_uuid: &UserId, grantor_uuid: &UserId,
grantee_uuid: &UserId, grantee_uuid: &UserId,
email: &str, email: &str,
conn: &mut DbConn, conn: &DbConn,
) -> Option<Self> { ) -> Option<Self> {
db_run! { conn: { db_run! { conn: {
emergency_access::table emergency_access::table
.filter(emergency_access::grantor_uuid.eq(grantor_uuid)) .filter(emergency_access::grantor_uuid.eq(grantor_uuid))
.filter(emergency_access::grantee_uuid.eq(grantee_uuid).or(emergency_access::email.eq(email))) .filter(emergency_access::grantee_uuid.eq(grantee_uuid).or(emergency_access::email.eq(email)))
.first::<EmergencyAccessDb>(conn) .first::<Self>(conn)
.ok().from_db() .ok()
}} }}
} }
pub async fn find_all_recoveries_initiated(conn: &mut DbConn) -> Vec<Self> { pub async fn find_all_recoveries_initiated(conn: &DbConn) -> Vec<Self> {
db_run! { conn: { db_run! { conn: {
emergency_access::table emergency_access::table
.filter(emergency_access::status.eq(EmergencyAccessStatus::RecoveryInitiated as i32)) .filter(emergency_access::status.eq(EmergencyAccessStatus::RecoveryInitiated as i32))
.filter(emergency_access::recovery_initiated_at.is_not_null()) .filter(emergency_access::recovery_initiated_at.is_not_null())
.load::<EmergencyAccessDb>(conn).expect("Error loading emergency_access").from_db() .load::<Self>(conn)
.expect("Error loading emergency_access")
}} }}
} }
pub async fn find_by_uuid_and_grantor_uuid( pub async fn find_by_uuid_and_grantor_uuid(
uuid: &EmergencyAccessId, uuid: &EmergencyAccessId,
grantor_uuid: &UserId, grantor_uuid: &UserId,
conn: &mut DbConn, conn: &DbConn,
) -> Option<Self> { ) -> Option<Self> {
db_run! { conn: { db_run! { conn: {
emergency_access::table emergency_access::table
.filter(emergency_access::uuid.eq(uuid)) .filter(emergency_access::uuid.eq(uuid))
.filter(emergency_access::grantor_uuid.eq(grantor_uuid)) .filter(emergency_access::grantor_uuid.eq(grantor_uuid))
.first::<EmergencyAccessDb>(conn) .first::<Self>(conn)
.ok().from_db() .ok()
}} }}
} }
pub async fn find_by_uuid_and_grantee_uuid( pub async fn find_by_uuid_and_grantee_uuid(
uuid: &EmergencyAccessId, uuid: &EmergencyAccessId,
grantee_uuid: &UserId, grantee_uuid: &UserId,
conn: &mut DbConn, conn: &DbConn,
) -> Option<Self> { ) -> Option<Self> {
db_run! { conn: { db_run! { conn: {
emergency_access::table emergency_access::table
.filter(emergency_access::uuid.eq(uuid)) .filter(emergency_access::uuid.eq(uuid))
.filter(emergency_access::grantee_uuid.eq(grantee_uuid)) .filter(emergency_access::grantee_uuid.eq(grantee_uuid))
.first::<EmergencyAccessDb>(conn) .first::<Self>(conn)
.ok().from_db() .ok()
}} }}
} }
pub async fn find_by_uuid_and_grantee_email( pub async fn find_by_uuid_and_grantee_email(
uuid: &EmergencyAccessId, uuid: &EmergencyAccessId,
grantee_email: &str, grantee_email: &str,
conn: &mut DbConn, conn: &DbConn,
) -> Option<Self> { ) -> Option<Self> {
db_run! { conn: { db_run! { conn: {
emergency_access::table emergency_access::table
.filter(emergency_access::uuid.eq(uuid)) .filter(emergency_access::uuid.eq(uuid))
.filter(emergency_access::email.eq(grantee_email)) .filter(emergency_access::email.eq(grantee_email))
.first::<EmergencyAccessDb>(conn) .first::<Self>(conn)
.ok().from_db() .ok()
}} }}
} }
pub async fn find_all_by_grantee_uuid(grantee_uuid: &UserId, conn: &mut DbConn) -> Vec<Self> { pub async fn find_all_by_grantee_uuid(grantee_uuid: &UserId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: { db_run! { conn: {
emergency_access::table emergency_access::table
.filter(emergency_access::grantee_uuid.eq(grantee_uuid)) .filter(emergency_access::grantee_uuid.eq(grantee_uuid))
.load::<EmergencyAccessDb>(conn).expect("Error loading emergency_access").from_db() .load::<Self>(conn)
.expect("Error loading emergency_access")
}} }}
} }
pub async fn find_invited_by_grantee_email(grantee_email: &str, conn: &mut DbConn) -> Option<Self> { pub async fn find_invited_by_grantee_email(grantee_email: &str, conn: &DbConn) -> Option<Self> {
db_run! { conn: { db_run! { conn: {
emergency_access::table emergency_access::table
.filter(emergency_access::email.eq(grantee_email)) .filter(emergency_access::email.eq(grantee_email))
.filter(emergency_access::status.eq(EmergencyAccessStatus::Invited as i32)) .filter(emergency_access::status.eq(EmergencyAccessStatus::Invited as i32))
.first::<EmergencyAccessDb>(conn) .first::<Self>(conn)
.ok().from_db() .ok()
}} }}
} }
pub async fn find_all_invited_by_grantee_email(grantee_email: &str, conn: &mut DbConn) -> Vec<Self> { pub async fn find_all_invited_by_grantee_email(grantee_email: &str, conn: &DbConn) -> Vec<Self> {
db_run! { conn: { db_run! { conn: {
emergency_access::table emergency_access::table
.filter(emergency_access::email.eq(grantee_email)) .filter(emergency_access::email.eq(grantee_email))
.filter(emergency_access::status.eq(EmergencyAccessStatus::Invited as i32)) .filter(emergency_access::status.eq(EmergencyAccessStatus::Invited as i32))
.load::<EmergencyAccessDb>(conn).expect("Error loading emergency_access").from_db() .load::<Self>(conn)
.expect("Error loading emergency_access")
}} }}
} }
pub async fn find_all_by_grantor_uuid(grantor_uuid: &UserId, conn: &mut DbConn) -> Vec<Self> { pub async fn find_all_by_grantor_uuid(grantor_uuid: &UserId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: { db_run! { conn: {
emergency_access::table emergency_access::table
.filter(emergency_access::grantor_uuid.eq(grantor_uuid)) .filter(emergency_access::grantor_uuid.eq(grantor_uuid))
.load::<EmergencyAccessDb>(conn).expect("Error loading emergency_access").from_db() .load::<Self>(conn)
.expect("Error loading emergency_access")
}} }}
} }
pub async fn accept_invite( pub async fn accept_invite(&mut self, grantee_uuid: &UserId, grantee_email: &str, conn: &DbConn) -> EmptyResult {
&mut self,
grantee_uuid: &UserId,
grantee_email: &str,
conn: &mut DbConn,
) -> EmptyResult {
if self.email.is_none() || self.email.as_ref().unwrap() != grantee_email { if self.email.is_none() || self.email.as_ref().unwrap() != grantee_email {
err!("User email does not match invite."); err!("User email does not match invite.");
} }

39
src/db/models/event.rs

@ -3,11 +3,12 @@ use chrono::{NaiveDateTime, TimeDelta, Utc};
use serde_json::Value; use serde_json::Value;
use super::{CipherId, CollectionId, GroupId, MembershipId, OrgPolicyId, OrganizationId, UserId}; use super::{CipherId, CollectionId, GroupId, MembershipId, OrgPolicyId, OrganizationId, UserId};
use crate::db::schema::{event, users_organizations};
use crate::{api::EmptyResult, db::DbConn, error::MapResult, CONFIG}; use crate::{api::EmptyResult, db::DbConn, error::MapResult, CONFIG};
use diesel::prelude::*;
// https://bitwarden.com/help/event-logs/ // https://bitwarden.com/help/event-logs/
db_object! {
// Upstream: https://github.com/bitwarden/server/blob/9ebe16587175b1c0e9208f84397bb75d0d595510/src/Core/AdminConsole/Services/Implementations/EventService.cs // Upstream: https://github.com/bitwarden/server/blob/9ebe16587175b1c0e9208f84397bb75d0d595510/src/Core/AdminConsole/Services/Implementations/EventService.cs
// Upstream: https://github.com/bitwarden/server/blob/9ebe16587175b1c0e9208f84397bb75d0d595510/src/Api/AdminConsole/Public/Models/Response/EventResponseModel.cs // Upstream: https://github.com/bitwarden/server/blob/9ebe16587175b1c0e9208f84397bb75d0d595510/src/Api/AdminConsole/Public/Models/Response/EventResponseModel.cs
// Upstream SQL: https://github.com/bitwarden/server/blob/9ebe16587175b1c0e9208f84397bb75d0d595510/src/Sql/dbo/Tables/Event.sql // Upstream SQL: https://github.com/bitwarden/server/blob/9ebe16587175b1c0e9208f84397bb75d0d595510/src/Sql/dbo/Tables/Event.sql
@ -34,7 +35,6 @@ db_object! {
pub provider_user_uuid: Option<String>, pub provider_user_uuid: Option<String>,
pub provider_org_uuid: Option<String>, pub provider_org_uuid: Option<String>,
} }
}
// Upstream enum: https://github.com/bitwarden/server/blob/9ebe16587175b1c0e9208f84397bb75d0d595510/src/Core/AdminConsole/Enums/EventType.cs // Upstream enum: https://github.com/bitwarden/server/blob/9ebe16587175b1c0e9208f84397bb75d0d595510/src/Core/AdminConsole/Enums/EventType.cs
#[derive(Debug, Copy, Clone)] #[derive(Debug, Copy, Clone)]
@ -193,27 +193,27 @@ impl Event {
/// ############# /// #############
/// Basic Queries /// Basic Queries
pub async fn save(&self, conn: &mut DbConn) -> EmptyResult { pub async fn save(&self, conn: &DbConn) -> EmptyResult {
db_run! { conn: db_run! { conn:
sqlite, mysql { sqlite, mysql {
diesel::replace_into(event::table) diesel::replace_into(event::table)
.values(EventDb::to_db(self)) .values(self)
.execute(conn) .execute(conn)
.map_res("Error saving event") .map_res("Error saving event")
} }
postgresql { postgresql {
diesel::insert_into(event::table) diesel::insert_into(event::table)
.values(EventDb::to_db(self)) .values(self)
.on_conflict(event::uuid) .on_conflict(event::uuid)
.do_update() .do_update()
.set(EventDb::to_db(self)) .set(self)
.execute(conn) .execute(conn)
.map_res("Error saving event") .map_res("Error saving event")
} }
} }
} }
pub async fn save_user_event(events: Vec<Event>, conn: &mut DbConn) -> EmptyResult { pub async fn save_user_event(events: Vec<Event>, conn: &DbConn) -> EmptyResult {
// Special save function which is able to handle multiple events. // Special save function which is able to handle multiple events.
// SQLite doesn't support the DEFAULT argument, and does not support inserting multiple values at the same time. // SQLite doesn't support the DEFAULT argument, and does not support inserting multiple values at the same time.
// MySQL and PostgreSQL do. // MySQL and PostgreSQL do.
@ -224,14 +224,13 @@ impl Event {
sqlite { sqlite {
for event in events { for event in events {
diesel::insert_or_ignore_into(event::table) diesel::insert_or_ignore_into(event::table)
.values(EventDb::to_db(&event)) .values(&event)
.execute(conn) .execute(conn)
.unwrap_or_default(); .unwrap_or_default();
} }
Ok(()) Ok(())
} }
mysql { mysql {
let events: Vec<EventDb> = events.iter().map(EventDb::to_db).collect();
diesel::insert_or_ignore_into(event::table) diesel::insert_or_ignore_into(event::table)
.values(&events) .values(&events)
.execute(conn) .execute(conn)
@ -239,7 +238,6 @@ impl Event {
Ok(()) Ok(())
} }
postgresql { postgresql {
let events: Vec<EventDb> = events.iter().map(EventDb::to_db).collect();
diesel::insert_into(event::table) diesel::insert_into(event::table)
.values(&events) .values(&events)
.on_conflict_do_nothing() .on_conflict_do_nothing()
@ -250,7 +248,7 @@ impl Event {
} }
} }
pub async fn delete(self, conn: &mut DbConn) -> EmptyResult { pub async fn delete(self, conn: &DbConn) -> EmptyResult {
db_run! { conn: { db_run! { conn: {
diesel::delete(event::table.filter(event::uuid.eq(self.uuid))) diesel::delete(event::table.filter(event::uuid.eq(self.uuid)))
.execute(conn) .execute(conn)
@ -264,7 +262,7 @@ impl Event {
org_uuid: &OrganizationId, org_uuid: &OrganizationId,
start: &NaiveDateTime, start: &NaiveDateTime,
end: &NaiveDateTime, end: &NaiveDateTime,
conn: &mut DbConn, conn: &DbConn,
) -> Vec<Self> { ) -> Vec<Self> {
db_run! { conn: { db_run! { conn: {
event::table event::table
@ -272,13 +270,12 @@ impl Event {
.filter(event::event_date.between(start, end)) .filter(event::event_date.between(start, end))
.order_by(event::event_date.desc()) .order_by(event::event_date.desc())
.limit(Self::PAGE_SIZE) .limit(Self::PAGE_SIZE)
.load::<EventDb>(conn) .load::<Self>(conn)
.expect("Error filtering events") .expect("Error filtering events")
.from_db()
}} }}
} }
pub async fn count_by_org(org_uuid: &OrganizationId, conn: &mut DbConn) -> i64 { pub async fn count_by_org(org_uuid: &OrganizationId, conn: &DbConn) -> i64 {
db_run! { conn: { db_run! { conn: {
event::table event::table
.filter(event::org_uuid.eq(org_uuid)) .filter(event::org_uuid.eq(org_uuid))
@ -294,7 +291,7 @@ impl Event {
member_uuid: &MembershipId, member_uuid: &MembershipId,
start: &NaiveDateTime, start: &NaiveDateTime,
end: &NaiveDateTime, end: &NaiveDateTime,
conn: &mut DbConn, conn: &DbConn,
) -> Vec<Self> { ) -> Vec<Self> {
db_run! { conn: { db_run! { conn: {
event::table event::table
@ -305,9 +302,8 @@ impl Event {
.select(event::all_columns) .select(event::all_columns)
.order_by(event::event_date.desc()) .order_by(event::event_date.desc())
.limit(Self::PAGE_SIZE) .limit(Self::PAGE_SIZE)
.load::<EventDb>(conn) .load::<Self>(conn)
.expect("Error filtering events") .expect("Error filtering events")
.from_db()
}} }}
} }
@ -315,7 +311,7 @@ impl Event {
cipher_uuid: &CipherId, cipher_uuid: &CipherId,
start: &NaiveDateTime, start: &NaiveDateTime,
end: &NaiveDateTime, end: &NaiveDateTime,
conn: &mut DbConn, conn: &DbConn,
) -> Vec<Self> { ) -> Vec<Self> {
db_run! { conn: { db_run! { conn: {
event::table event::table
@ -323,13 +319,12 @@ impl Event {
.filter(event::event_date.between(start, end)) .filter(event::event_date.between(start, end))
.order_by(event::event_date.desc()) .order_by(event::event_date.desc())
.limit(Self::PAGE_SIZE) .limit(Self::PAGE_SIZE)
.load::<EventDb>(conn) .load::<Self>(conn)
.expect("Error filtering events") .expect("Error filtering events")
.from_db()
}} }}
} }
pub async fn clean_events(conn: &mut DbConn) -> EmptyResult { pub async fn clean_events(conn: &DbConn) -> EmptyResult {
if let Some(days_to_retain) = CONFIG.events_days_retain() { if let Some(days_to_retain) = CONFIG.events_days_retain() {
let dt = Utc::now().naive_utc() - TimeDelta::try_days(days_to_retain).unwrap(); let dt = Utc::now().naive_utc() - TimeDelta::try_days(days_to_retain).unwrap();
db_run! { conn: { db_run! { conn: {

18
src/db/models/favorite.rs

@ -1,6 +1,7 @@
use super::{CipherId, User, UserId}; use super::{CipherId, User, UserId};
use crate::db::schema::favorites;
use diesel::prelude::*;
db_object! {
#[derive(Identifiable, Queryable, Insertable)] #[derive(Identifiable, Queryable, Insertable)]
#[diesel(table_name = favorites)] #[diesel(table_name = favorites)]
#[diesel(primary_key(user_uuid, cipher_uuid))] #[diesel(primary_key(user_uuid, cipher_uuid))]
@ -8,7 +9,6 @@ db_object! {
pub user_uuid: UserId, pub user_uuid: UserId,
pub cipher_uuid: CipherId, pub cipher_uuid: CipherId,
} }
}
use crate::db::DbConn; use crate::db::DbConn;
@ -17,14 +17,16 @@ use crate::error::MapResult;
impl Favorite { impl Favorite {
// Returns whether the specified cipher is a favorite of the specified user. // Returns whether the specified cipher is a favorite of the specified user.
pub async fn is_favorite(cipher_uuid: &CipherId, user_uuid: &UserId, conn: &mut DbConn) -> bool { pub async fn is_favorite(cipher_uuid: &CipherId, user_uuid: &UserId, conn: &DbConn) -> bool {
db_run! { conn: { db_run! { conn: {
let query = favorites::table let query = favorites::table
.filter(favorites::cipher_uuid.eq(cipher_uuid)) .filter(favorites::cipher_uuid.eq(cipher_uuid))
.filter(favorites::user_uuid.eq(user_uuid)) .filter(favorites::user_uuid.eq(user_uuid))
.count(); .count();
query.first::<i64>(conn).ok().unwrap_or(0) != 0 query.first::<i64>(conn)
.ok()
.unwrap_or(0) != 0
}} }}
} }
@ -33,7 +35,7 @@ impl Favorite {
favorite: bool, favorite: bool,
cipher_uuid: &CipherId, cipher_uuid: &CipherId,
user_uuid: &UserId, user_uuid: &UserId,
conn: &mut DbConn, conn: &DbConn,
) -> EmptyResult { ) -> EmptyResult {
let (old, new) = (Self::is_favorite(cipher_uuid, user_uuid, conn).await, favorite); let (old, new) = (Self::is_favorite(cipher_uuid, user_uuid, conn).await, favorite);
match (old, new) { match (old, new) {
@ -67,7 +69,7 @@ impl Favorite {
} }
// Delete all favorite entries associated with the specified cipher. // Delete all favorite entries associated with the specified cipher.
pub async fn delete_all_by_cipher(cipher_uuid: &CipherId, conn: &mut DbConn) -> EmptyResult { pub async fn delete_all_by_cipher(cipher_uuid: &CipherId, conn: &DbConn) -> EmptyResult {
db_run! { conn: { db_run! { conn: {
diesel::delete(favorites::table.filter(favorites::cipher_uuid.eq(cipher_uuid))) diesel::delete(favorites::table.filter(favorites::cipher_uuid.eq(cipher_uuid)))
.execute(conn) .execute(conn)
@ -76,7 +78,7 @@ impl Favorite {
} }
// Delete all favorite entries associated with the specified user. // Delete all favorite entries associated with the specified user.
pub async fn delete_all_by_user(user_uuid: &UserId, conn: &mut DbConn) -> EmptyResult { pub async fn delete_all_by_user(user_uuid: &UserId, conn: &DbConn) -> EmptyResult {
db_run! { conn: { db_run! { conn: {
diesel::delete(favorites::table.filter(favorites::user_uuid.eq(user_uuid))) diesel::delete(favorites::table.filter(favorites::user_uuid.eq(user_uuid)))
.execute(conn) .execute(conn)
@ -86,7 +88,7 @@ impl Favorite {
/// Return a vec with (cipher_uuid) this will only contain favorite flagged ciphers /// Return a vec with (cipher_uuid) this will only contain favorite flagged ciphers
/// This is used during a full sync so we only need one query for all favorite cipher matches. /// This is used during a full sync so we only need one query for all favorite cipher matches.
pub async fn get_all_cipher_uuid_by_user(user_uuid: &UserId, conn: &mut DbConn) -> Vec<CipherId> { pub async fn get_all_cipher_uuid_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<CipherId> {
db_run! { conn: { db_run! { conn: {
favorites::table favorites::table
.filter(favorites::user_uuid.eq(user_uuid)) .filter(favorites::user_uuid.eq(user_uuid))

53
src/db/models/folder.rs

@ -3,9 +3,10 @@ use derive_more::{AsRef, Deref, Display, From};
use serde_json::Value; use serde_json::Value;
use super::{CipherId, User, UserId}; use super::{CipherId, User, UserId};
use crate::db::schema::{folders, folders_ciphers};
use diesel::prelude::*;
use macros::UuidFromParam; use macros::UuidFromParam;
db_object! {
#[derive(Identifiable, Queryable, Insertable, AsChangeset)] #[derive(Identifiable, Queryable, Insertable, AsChangeset)]
#[diesel(table_name = folders)] #[diesel(table_name = folders)]
#[diesel(primary_key(uuid))] #[diesel(primary_key(uuid))]
@ -24,7 +25,6 @@ db_object! {
pub cipher_uuid: CipherId, pub cipher_uuid: CipherId,
pub folder_uuid: FolderId, pub folder_uuid: FolderId,
} }
}
/// Local methods /// Local methods
impl Folder { impl Folder {
@ -69,14 +69,14 @@ use crate::error::MapResult;
/// Database methods /// Database methods
impl Folder { impl Folder {
pub async fn save(&mut self, conn: &mut DbConn) -> EmptyResult { pub async fn save(&mut self, conn: &DbConn) -> EmptyResult {
User::update_uuid_revision(&self.user_uuid, conn).await; User::update_uuid_revision(&self.user_uuid, conn).await;
self.updated_at = Utc::now().naive_utc(); self.updated_at = Utc::now().naive_utc();
db_run! { conn: db_run! { conn:
sqlite, mysql { sqlite, mysql {
match diesel::replace_into(folders::table) match diesel::replace_into(folders::table)
.values(FolderDb::to_db(self)) .values(&*self)
.execute(conn) .execute(conn)
{ {
Ok(_) => Ok(()), Ok(_) => Ok(()),
@ -84,7 +84,7 @@ impl Folder {
Err(diesel::result::Error::DatabaseError(diesel::result::DatabaseErrorKind::ForeignKeyViolation, _)) => { Err(diesel::result::Error::DatabaseError(diesel::result::DatabaseErrorKind::ForeignKeyViolation, _)) => {
diesel::update(folders::table) diesel::update(folders::table)
.filter(folders::uuid.eq(&self.uuid)) .filter(folders::uuid.eq(&self.uuid))
.set(FolderDb::to_db(self)) .set(&*self)
.execute(conn) .execute(conn)
.map_res("Error saving folder") .map_res("Error saving folder")
} }
@ -92,19 +92,18 @@ impl Folder {
}.map_res("Error saving folder") }.map_res("Error saving folder")
} }
postgresql { postgresql {
let value = FolderDb::to_db(self);
diesel::insert_into(folders::table) diesel::insert_into(folders::table)
.values(&value) .values(&*self)
.on_conflict(folders::uuid) .on_conflict(folders::uuid)
.do_update() .do_update()
.set(&value) .set(&*self)
.execute(conn) .execute(conn)
.map_res("Error saving folder") .map_res("Error saving folder")
} }
} }
} }
pub async fn delete(&self, conn: &mut DbConn) -> EmptyResult { pub async fn delete(&self, conn: &DbConn) -> EmptyResult {
User::update_uuid_revision(&self.user_uuid, conn).await; User::update_uuid_revision(&self.user_uuid, conn).await;
FolderCipher::delete_all_by_folder(&self.uuid, conn).await?; FolderCipher::delete_all_by_folder(&self.uuid, conn).await?;
@ -115,50 +114,48 @@ impl Folder {
}} }}
} }
pub async fn delete_all_by_user(user_uuid: &UserId, conn: &mut DbConn) -> EmptyResult { pub async fn delete_all_by_user(user_uuid: &UserId, conn: &DbConn) -> EmptyResult {
for folder in Self::find_by_user(user_uuid, conn).await { for folder in Self::find_by_user(user_uuid, conn).await {
folder.delete(conn).await?; folder.delete(conn).await?;
} }
Ok(()) Ok(())
} }
pub async fn find_by_uuid_and_user(uuid: &FolderId, user_uuid: &UserId, conn: &mut DbConn) -> Option<Self> { pub async fn find_by_uuid_and_user(uuid: &FolderId, user_uuid: &UserId, conn: &DbConn) -> Option<Self> {
db_run! { conn: { db_run! { conn: {
folders::table folders::table
.filter(folders::uuid.eq(uuid)) .filter(folders::uuid.eq(uuid))
.filter(folders::user_uuid.eq(user_uuid)) .filter(folders::user_uuid.eq(user_uuid))
.first::<FolderDb>(conn) .first::<Self>(conn)
.ok() .ok()
.from_db()
}} }}
} }
pub async fn find_by_user(user_uuid: &UserId, conn: &mut DbConn) -> Vec<Self> { pub async fn find_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: { db_run! { conn: {
folders::table folders::table
.filter(folders::user_uuid.eq(user_uuid)) .filter(folders::user_uuid.eq(user_uuid))
.load::<FolderDb>(conn) .load::<Self>(conn)
.expect("Error loading folders") .expect("Error loading folders")
.from_db()
}} }}
} }
} }
impl FolderCipher { impl FolderCipher {
pub async fn save(&self, conn: &mut DbConn) -> EmptyResult { pub async fn save(&self, conn: &DbConn) -> EmptyResult {
db_run! { conn: db_run! { conn:
sqlite, mysql { sqlite, mysql {
// Not checking for ForeignKey Constraints here. // Not checking for ForeignKey Constraints here.
// Table folders_ciphers does not have ForeignKey Constraints which would cause conflicts. // Table folders_ciphers does not have ForeignKey Constraints which would cause conflicts.
// This table has no constraints pointing to itself, but only to others. // This table has no constraints pointing to itself, but only to others.
diesel::replace_into(folders_ciphers::table) diesel::replace_into(folders_ciphers::table)
.values(FolderCipherDb::to_db(self)) .values(self)
.execute(conn) .execute(conn)
.map_res("Error adding cipher to folder") .map_res("Error adding cipher to folder")
} }
postgresql { postgresql {
diesel::insert_into(folders_ciphers::table) diesel::insert_into(folders_ciphers::table)
.values(FolderCipherDb::to_db(self)) .values(self)
.on_conflict((folders_ciphers::cipher_uuid, folders_ciphers::folder_uuid)) .on_conflict((folders_ciphers::cipher_uuid, folders_ciphers::folder_uuid))
.do_nothing() .do_nothing()
.execute(conn) .execute(conn)
@ -167,7 +164,7 @@ impl FolderCipher {
} }
} }
pub async fn delete(self, conn: &mut DbConn) -> EmptyResult { pub async fn delete(self, conn: &DbConn) -> EmptyResult {
db_run! { conn: { db_run! { conn: {
diesel::delete( diesel::delete(
folders_ciphers::table folders_ciphers::table
@ -179,7 +176,7 @@ impl FolderCipher {
}} }}
} }
pub async fn delete_all_by_cipher(cipher_uuid: &CipherId, conn: &mut DbConn) -> EmptyResult { pub async fn delete_all_by_cipher(cipher_uuid: &CipherId, conn: &DbConn) -> EmptyResult {
db_run! { conn: { db_run! { conn: {
diesel::delete(folders_ciphers::table.filter(folders_ciphers::cipher_uuid.eq(cipher_uuid))) diesel::delete(folders_ciphers::table.filter(folders_ciphers::cipher_uuid.eq(cipher_uuid)))
.execute(conn) .execute(conn)
@ -187,7 +184,7 @@ impl FolderCipher {
}} }}
} }
pub async fn delete_all_by_folder(folder_uuid: &FolderId, conn: &mut DbConn) -> EmptyResult { pub async fn delete_all_by_folder(folder_uuid: &FolderId, conn: &DbConn) -> EmptyResult {
db_run! { conn: { db_run! { conn: {
diesel::delete(folders_ciphers::table.filter(folders_ciphers::folder_uuid.eq(folder_uuid))) diesel::delete(folders_ciphers::table.filter(folders_ciphers::folder_uuid.eq(folder_uuid)))
.execute(conn) .execute(conn)
@ -198,31 +195,29 @@ impl FolderCipher {
pub async fn find_by_folder_and_cipher( pub async fn find_by_folder_and_cipher(
folder_uuid: &FolderId, folder_uuid: &FolderId,
cipher_uuid: &CipherId, cipher_uuid: &CipherId,
conn: &mut DbConn, conn: &DbConn,
) -> Option<Self> { ) -> Option<Self> {
db_run! { conn: { db_run! { conn: {
folders_ciphers::table folders_ciphers::table
.filter(folders_ciphers::folder_uuid.eq(folder_uuid)) .filter(folders_ciphers::folder_uuid.eq(folder_uuid))
.filter(folders_ciphers::cipher_uuid.eq(cipher_uuid)) .filter(folders_ciphers::cipher_uuid.eq(cipher_uuid))
.first::<FolderCipherDb>(conn) .first::<Self>(conn)
.ok() .ok()
.from_db()
}} }}
} }
pub async fn find_by_folder(folder_uuid: &FolderId, conn: &mut DbConn) -> Vec<Self> { pub async fn find_by_folder(folder_uuid: &FolderId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: { db_run! { conn: {
folders_ciphers::table folders_ciphers::table
.filter(folders_ciphers::folder_uuid.eq(folder_uuid)) .filter(folders_ciphers::folder_uuid.eq(folder_uuid))
.load::<FolderCipherDb>(conn) .load::<Self>(conn)
.expect("Error loading folders") .expect("Error loading folders")
.from_db()
}} }}
} }
/// Return a vec with (cipher_uuid, folder_uuid) /// Return a vec with (cipher_uuid, folder_uuid)
/// This is used during a full sync so we only need one query for all folder matches. /// This is used during a full sync so we only need one query for all folder matches.
pub async fn find_by_user(user_uuid: &UserId, conn: &mut DbConn) -> Vec<(CipherId, FolderId)> { pub async fn find_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<(CipherId, FolderId)> {
db_run! { conn: { db_run! { conn: {
folders_ciphers::table folders_ciphers::table
.inner_join(folders::table) .inner_join(folders::table)

95
src/db/models/group.rs

@ -1,13 +1,14 @@
use super::{CollectionId, Membership, MembershipId, OrganizationId, User, UserId}; use super::{CollectionId, Membership, MembershipId, OrganizationId, User, UserId};
use crate::api::EmptyResult; use crate::api::EmptyResult;
use crate::db::schema::{collections_groups, groups, groups_users, users_organizations};
use crate::db::DbConn; use crate::db::DbConn;
use crate::error::MapResult; use crate::error::MapResult;
use chrono::{NaiveDateTime, Utc}; use chrono::{NaiveDateTime, Utc};
use derive_more::{AsRef, Deref, Display, From}; use derive_more::{AsRef, Deref, Display, From};
use diesel::prelude::*;
use macros::UuidFromParam; use macros::UuidFromParam;
use serde_json::Value; use serde_json::Value;
db_object! {
#[derive(Identifiable, Queryable, Insertable, AsChangeset)] #[derive(Identifiable, Queryable, Insertable, AsChangeset)]
#[diesel(table_name = groups)] #[diesel(table_name = groups)]
#[diesel(treat_none_as_null = true)] #[diesel(treat_none_as_null = true)]
@ -38,8 +39,7 @@ db_object! {
#[diesel(primary_key(groups_uuid, users_organizations_uuid))] #[diesel(primary_key(groups_uuid, users_organizations_uuid))]
pub struct GroupUser { pub struct GroupUser {
pub groups_uuid: GroupId, pub groups_uuid: GroupId,
pub users_organizations_uuid: MembershipId pub users_organizations_uuid: MembershipId,
}
} }
/// Local methods /// Local methods
@ -77,7 +77,7 @@ impl Group {
}) })
} }
pub async fn to_json_details(&self, conn: &mut DbConn) -> Value { pub async fn to_json_details(&self, conn: &DbConn) -> Value {
// If both read_only and hide_passwords are false, then manage should be true // If both read_only and hide_passwords are false, then manage should be true
// You can't have an entry with read_only and manage, or hide_passwords and manage // You can't have an entry with read_only and manage, or hide_passwords and manage
// Or an entry with everything to false // Or an entry with everything to false
@ -156,13 +156,13 @@ impl GroupUser {
/// Database methods /// Database methods
impl Group { impl Group {
pub async fn save(&mut self, conn: &mut DbConn) -> EmptyResult { pub async fn save(&mut self, conn: &DbConn) -> EmptyResult {
self.revision_date = Utc::now().naive_utc(); self.revision_date = Utc::now().naive_utc();
db_run! { conn: db_run! { conn:
sqlite, mysql { sqlite, mysql {
match diesel::replace_into(groups::table) match diesel::replace_into(groups::table)
.values(GroupDb::to_db(self)) .values(&*self)
.execute(conn) .execute(conn)
{ {
Ok(_) => Ok(()), Ok(_) => Ok(()),
@ -170,7 +170,7 @@ impl Group {
Err(diesel::result::Error::DatabaseError(diesel::result::DatabaseErrorKind::ForeignKeyViolation, _)) => { Err(diesel::result::Error::DatabaseError(diesel::result::DatabaseErrorKind::ForeignKeyViolation, _)) => {
diesel::update(groups::table) diesel::update(groups::table)
.filter(groups::uuid.eq(&self.uuid)) .filter(groups::uuid.eq(&self.uuid))
.set(GroupDb::to_db(self)) .set(&*self)
.execute(conn) .execute(conn)
.map_res("Error saving group") .map_res("Error saving group")
} }
@ -178,36 +178,34 @@ impl Group {
}.map_res("Error saving group") }.map_res("Error saving group")
} }
postgresql { postgresql {
let value = GroupDb::to_db(self);
diesel::insert_into(groups::table) diesel::insert_into(groups::table)
.values(&value) .values(&*self)
.on_conflict(groups::uuid) .on_conflict(groups::uuid)
.do_update() .do_update()
.set(&value) .set(&*self)
.execute(conn) .execute(conn)
.map_res("Error saving group") .map_res("Error saving group")
} }
} }
} }
pub async fn delete_all_by_organization(org_uuid: &OrganizationId, conn: &mut DbConn) -> EmptyResult { pub async fn delete_all_by_organization(org_uuid: &OrganizationId, conn: &DbConn) -> EmptyResult {
for group in Self::find_by_organization(org_uuid, conn).await { for group in Self::find_by_organization(org_uuid, conn).await {
group.delete(conn).await?; group.delete(conn).await?;
} }
Ok(()) Ok(())
} }
pub async fn find_by_organization(org_uuid: &OrganizationId, conn: &mut DbConn) -> Vec<Self> { pub async fn find_by_organization(org_uuid: &OrganizationId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: { db_run! { conn: {
groups::table groups::table
.filter(groups::organizations_uuid.eq(org_uuid)) .filter(groups::organizations_uuid.eq(org_uuid))
.load::<GroupDb>(conn) .load::<Self>(conn)
.expect("Error loading groups") .expect("Error loading groups")
.from_db()
}} }}
} }
pub async fn count_by_org(org_uuid: &OrganizationId, conn: &mut DbConn) -> i64 { pub async fn count_by_org(org_uuid: &OrganizationId, conn: &DbConn) -> i64 {
db_run! { conn: { db_run! { conn: {
groups::table groups::table
.filter(groups::organizations_uuid.eq(org_uuid)) .filter(groups::organizations_uuid.eq(org_uuid))
@ -218,33 +216,31 @@ impl Group {
}} }}
} }
pub async fn find_by_uuid_and_org(uuid: &GroupId, org_uuid: &OrganizationId, conn: &mut DbConn) -> Option<Self> { pub async fn find_by_uuid_and_org(uuid: &GroupId, org_uuid: &OrganizationId, conn: &DbConn) -> Option<Self> {
db_run! { conn: { db_run! { conn: {
groups::table groups::table
.filter(groups::uuid.eq(uuid)) .filter(groups::uuid.eq(uuid))
.filter(groups::organizations_uuid.eq(org_uuid)) .filter(groups::organizations_uuid.eq(org_uuid))
.first::<GroupDb>(conn) .first::<Self>(conn)
.ok() .ok()
.from_db()
}} }}
} }
pub async fn find_by_external_id_and_org( pub async fn find_by_external_id_and_org(
external_id: &str, external_id: &str,
org_uuid: &OrganizationId, org_uuid: &OrganizationId,
conn: &mut DbConn, conn: &DbConn,
) -> Option<Self> { ) -> Option<Self> {
db_run! { conn: { db_run! { conn: {
groups::table groups::table
.filter(groups::external_id.eq(external_id)) .filter(groups::external_id.eq(external_id))
.filter(groups::organizations_uuid.eq(org_uuid)) .filter(groups::organizations_uuid.eq(org_uuid))
.first::<GroupDb>(conn) .first::<Self>(conn)
.ok() .ok()
.from_db()
}} }}
} }
//Returns all organizations the user has full access to //Returns all organizations the user has full access to
pub async fn get_orgs_by_user_with_full_access(user_uuid: &UserId, conn: &mut DbConn) -> Vec<OrganizationId> { pub async fn get_orgs_by_user_with_full_access(user_uuid: &UserId, conn: &DbConn) -> Vec<OrganizationId> {
db_run! { conn: { db_run! { conn: {
groups_users::table groups_users::table
.inner_join(users_organizations::table.on( .inner_join(users_organizations::table.on(
@ -262,7 +258,7 @@ impl Group {
}} }}
} }
pub async fn is_in_full_access_group(user_uuid: &UserId, org_uuid: &OrganizationId, conn: &mut DbConn) -> bool { pub async fn is_in_full_access_group(user_uuid: &UserId, org_uuid: &OrganizationId, conn: &DbConn) -> bool {
db_run! { conn: { db_run! { conn: {
groups::table groups::table
.inner_join(groups_users::table.on( .inner_join(groups_users::table.on(
@ -280,7 +276,7 @@ impl Group {
}} }}
} }
pub async fn delete(&self, conn: &mut DbConn) -> EmptyResult { pub async fn delete(&self, conn: &DbConn) -> EmptyResult {
CollectionGroup::delete_all_by_group(&self.uuid, conn).await?; CollectionGroup::delete_all_by_group(&self.uuid, conn).await?;
GroupUser::delete_all_by_group(&self.uuid, conn).await?; GroupUser::delete_all_by_group(&self.uuid, conn).await?;
@ -291,13 +287,13 @@ impl Group {
}} }}
} }
pub async fn update_revision(uuid: &GroupId, conn: &mut DbConn) { pub async fn update_revision(uuid: &GroupId, conn: &DbConn) {
if let Err(e) = Self::_update_revision(uuid, &Utc::now().naive_utc(), conn).await { if let Err(e) = Self::_update_revision(uuid, &Utc::now().naive_utc(), conn).await {
warn!("Failed to update revision for {uuid}: {e:#?}"); warn!("Failed to update revision for {uuid}: {e:#?}");
} }
} }
async fn _update_revision(uuid: &GroupId, date: &NaiveDateTime, conn: &mut DbConn) -> EmptyResult { async fn _update_revision(uuid: &GroupId, date: &NaiveDateTime, conn: &DbConn) -> EmptyResult {
db_run! { conn: { db_run! { conn: {
crate::util::retry(|| { crate::util::retry(|| {
diesel::update(groups::table.filter(groups::uuid.eq(uuid))) diesel::update(groups::table.filter(groups::uuid.eq(uuid)))
@ -310,7 +306,7 @@ impl Group {
} }
impl CollectionGroup { impl CollectionGroup {
pub async fn save(&mut self, conn: &mut DbConn) -> EmptyResult { pub async fn save(&mut self, conn: &DbConn) -> EmptyResult {
let group_users = GroupUser::find_by_group(&self.groups_uuid, conn).await; let group_users = GroupUser::find_by_group(&self.groups_uuid, conn).await;
for group_user in group_users { for group_user in group_users {
group_user.update_user_revision(conn).await; group_user.update_user_revision(conn).await;
@ -369,17 +365,16 @@ impl CollectionGroup {
} }
} }
pub async fn find_by_group(group_uuid: &GroupId, conn: &mut DbConn) -> Vec<Self> { pub async fn find_by_group(group_uuid: &GroupId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: { db_run! { conn: {
collections_groups::table collections_groups::table
.filter(collections_groups::groups_uuid.eq(group_uuid)) .filter(collections_groups::groups_uuid.eq(group_uuid))
.load::<CollectionGroupDb>(conn) .load::<Self>(conn)
.expect("Error loading collection groups") .expect("Error loading collection groups")
.from_db()
}} }}
} }
pub async fn find_by_user(user_uuid: &UserId, conn: &mut DbConn) -> Vec<Self> { pub async fn find_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: { db_run! { conn: {
collections_groups::table collections_groups::table
.inner_join(groups_users::table.on( .inner_join(groups_users::table.on(
@ -390,24 +385,22 @@ impl CollectionGroup {
)) ))
.filter(users_organizations::user_uuid.eq(user_uuid)) .filter(users_organizations::user_uuid.eq(user_uuid))
.select(collections_groups::all_columns) .select(collections_groups::all_columns)
.load::<CollectionGroupDb>(conn) .load::<Self>(conn)
.expect("Error loading user collection groups") .expect("Error loading user collection groups")
.from_db()
}} }}
} }
pub async fn find_by_collection(collection_uuid: &CollectionId, conn: &mut DbConn) -> Vec<Self> { pub async fn find_by_collection(collection_uuid: &CollectionId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: { db_run! { conn: {
collections_groups::table collections_groups::table
.filter(collections_groups::collections_uuid.eq(collection_uuid)) .filter(collections_groups::collections_uuid.eq(collection_uuid))
.select(collections_groups::all_columns) .select(collections_groups::all_columns)
.load::<CollectionGroupDb>(conn) .load::<Self>(conn)
.expect("Error loading collection groups") .expect("Error loading collection groups")
.from_db()
}} }}
} }
pub async fn delete(&self, conn: &mut DbConn) -> EmptyResult { pub async fn delete(&self, conn: &DbConn) -> EmptyResult {
let group_users = GroupUser::find_by_group(&self.groups_uuid, conn).await; let group_users = GroupUser::find_by_group(&self.groups_uuid, conn).await;
for group_user in group_users { for group_user in group_users {
group_user.update_user_revision(conn).await; group_user.update_user_revision(conn).await;
@ -422,7 +415,7 @@ impl CollectionGroup {
}} }}
} }
pub async fn delete_all_by_group(group_uuid: &GroupId, conn: &mut DbConn) -> EmptyResult { pub async fn delete_all_by_group(group_uuid: &GroupId, conn: &DbConn) -> EmptyResult {
let group_users = GroupUser::find_by_group(group_uuid, conn).await; let group_users = GroupUser::find_by_group(group_uuid, conn).await;
for group_user in group_users { for group_user in group_users {
group_user.update_user_revision(conn).await; group_user.update_user_revision(conn).await;
@ -436,7 +429,7 @@ impl CollectionGroup {
}} }}
} }
pub async fn delete_all_by_collection(collection_uuid: &CollectionId, conn: &mut DbConn) -> EmptyResult { pub async fn delete_all_by_collection(collection_uuid: &CollectionId, conn: &DbConn) -> EmptyResult {
let collection_assigned_to_groups = CollectionGroup::find_by_collection(collection_uuid, conn).await; let collection_assigned_to_groups = CollectionGroup::find_by_collection(collection_uuid, conn).await;
for collection_assigned_to_group in collection_assigned_to_groups { for collection_assigned_to_group in collection_assigned_to_groups {
let group_users = GroupUser::find_by_group(&collection_assigned_to_group.groups_uuid, conn).await; let group_users = GroupUser::find_by_group(&collection_assigned_to_group.groups_uuid, conn).await;
@ -455,7 +448,7 @@ impl CollectionGroup {
} }
impl GroupUser { impl GroupUser {
pub async fn save(&mut self, conn: &mut DbConn) -> EmptyResult { pub async fn save(&mut self, conn: &DbConn) -> EmptyResult {
self.update_user_revision(conn).await; self.update_user_revision(conn).await;
db_run! { conn: db_run! { conn:
@ -501,30 +494,28 @@ impl GroupUser {
} }
} }
pub async fn find_by_group(group_uuid: &GroupId, conn: &mut DbConn) -> Vec<Self> { pub async fn find_by_group(group_uuid: &GroupId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: { db_run! { conn: {
groups_users::table groups_users::table
.filter(groups_users::groups_uuid.eq(group_uuid)) .filter(groups_users::groups_uuid.eq(group_uuid))
.load::<GroupUserDb>(conn) .load::<Self>(conn)
.expect("Error loading group users") .expect("Error loading group users")
.from_db()
}} }}
} }
pub async fn find_by_member(member_uuid: &MembershipId, conn: &mut DbConn) -> Vec<Self> { pub async fn find_by_member(member_uuid: &MembershipId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: { db_run! { conn: {
groups_users::table groups_users::table
.filter(groups_users::users_organizations_uuid.eq(member_uuid)) .filter(groups_users::users_organizations_uuid.eq(member_uuid))
.load::<GroupUserDb>(conn) .load::<Self>(conn)
.expect("Error loading groups for user") .expect("Error loading groups for user")
.from_db()
}} }}
} }
pub async fn has_access_to_collection_by_member( pub async fn has_access_to_collection_by_member(
collection_uuid: &CollectionId, collection_uuid: &CollectionId,
member_uuid: &MembershipId, member_uuid: &MembershipId,
conn: &mut DbConn, conn: &DbConn,
) -> bool { ) -> bool {
db_run! { conn: { db_run! { conn: {
groups_users::table groups_users::table
@ -542,7 +533,7 @@ impl GroupUser {
pub async fn has_full_access_by_member( pub async fn has_full_access_by_member(
org_uuid: &OrganizationId, org_uuid: &OrganizationId,
member_uuid: &MembershipId, member_uuid: &MembershipId,
conn: &mut DbConn, conn: &DbConn,
) -> bool { ) -> bool {
db_run! { conn: { db_run! { conn: {
groups_users::table groups_users::table
@ -558,7 +549,7 @@ impl GroupUser {
}} }}
} }
pub async fn update_user_revision(&self, conn: &mut DbConn) { pub async fn update_user_revision(&self, conn: &DbConn) {
match Membership::find_by_uuid(&self.users_organizations_uuid, conn).await { match Membership::find_by_uuid(&self.users_organizations_uuid, conn).await {
Some(member) => User::update_uuid_revision(&member.user_uuid, conn).await, Some(member) => User::update_uuid_revision(&member.user_uuid, conn).await,
None => warn!("Member could not be found!"), None => warn!("Member could not be found!"),
@ -568,7 +559,7 @@ impl GroupUser {
pub async fn delete_by_group_and_member( pub async fn delete_by_group_and_member(
group_uuid: &GroupId, group_uuid: &GroupId,
member_uuid: &MembershipId, member_uuid: &MembershipId,
conn: &mut DbConn, conn: &DbConn,
) -> EmptyResult { ) -> EmptyResult {
match Membership::find_by_uuid(member_uuid, conn).await { match Membership::find_by_uuid(member_uuid, conn).await {
Some(member) => User::update_uuid_revision(&member.user_uuid, conn).await, Some(member) => User::update_uuid_revision(&member.user_uuid, conn).await,
@ -584,7 +575,7 @@ impl GroupUser {
}} }}
} }
pub async fn delete_all_by_group(group_uuid: &GroupId, conn: &mut DbConn) -> EmptyResult { pub async fn delete_all_by_group(group_uuid: &GroupId, conn: &DbConn) -> EmptyResult {
let group_users = GroupUser::find_by_group(group_uuid, conn).await; let group_users = GroupUser::find_by_group(group_uuid, conn).await;
for group_user in group_users { for group_user in group_users {
group_user.update_user_revision(conn).await; group_user.update_user_revision(conn).await;
@ -598,7 +589,7 @@ impl GroupUser {
}} }}
} }
pub async fn delete_all_by_member(member_uuid: &MembershipId, conn: &mut DbConn) -> EmptyResult { pub async fn delete_all_by_member(member_uuid: &MembershipId, conn: &DbConn) -> EmptyResult {
match Membership::find_by_uuid(member_uuid, conn).await { match Membership::find_by_uuid(member_uuid, conn).await {
Some(member) => User::update_uuid_revision(&member.user_uuid, conn).await, Some(member) => User::update_uuid_revision(&member.user_uuid, conn).await,
None => warn!("Member could not be found!"), None => warn!("Member could not be found!"),

56
src/db/models/org_policy.rs

@ -3,12 +3,13 @@ use serde::Deserialize;
use serde_json::Value; use serde_json::Value;
use crate::api::EmptyResult; use crate::api::EmptyResult;
use crate::db::schema::{org_policies, users_organizations};
use crate::db::DbConn; use crate::db::DbConn;
use crate::error::MapResult; use crate::error::MapResult;
use diesel::prelude::*;
use super::{Membership, MembershipId, MembershipStatus, MembershipType, OrganizationId, TwoFactor, UserId}; use super::{Membership, MembershipId, MembershipStatus, MembershipType, OrganizationId, TwoFactor, UserId};
db_object! {
#[derive(Identifiable, Queryable, Insertable, AsChangeset)] #[derive(Identifiable, Queryable, Insertable, AsChangeset)]
#[diesel(table_name = org_policies)] #[diesel(table_name = org_policies)]
#[diesel(primary_key(uuid))] #[diesel(primary_key(uuid))]
@ -19,7 +20,6 @@ db_object! {
pub enabled: bool, pub enabled: bool,
pub data: String, pub data: String,
} }
}
// https://github.com/bitwarden/server/blob/9ebe16587175b1c0e9208f84397bb75d0d595510/src/Core/AdminConsole/Enums/PolicyType.cs // https://github.com/bitwarden/server/blob/9ebe16587175b1c0e9208f84397bb75d0d595510/src/Core/AdminConsole/Enums/PolicyType.cs
#[derive(Copy, Clone, Eq, PartialEq, num_derive::FromPrimitive)] #[derive(Copy, Clone, Eq, PartialEq, num_derive::FromPrimitive)]
@ -106,11 +106,11 @@ impl OrgPolicy {
/// Database methods /// Database methods
impl OrgPolicy { impl OrgPolicy {
pub async fn save(&self, conn: &mut DbConn) -> EmptyResult { pub async fn save(&self, conn: &DbConn) -> EmptyResult {
db_run! { conn: db_run! { conn:
sqlite, mysql { sqlite, mysql {
match diesel::replace_into(org_policies::table) match diesel::replace_into(org_policies::table)
.values(OrgPolicyDb::to_db(self)) .values(self)
.execute(conn) .execute(conn)
{ {
Ok(_) => Ok(()), Ok(_) => Ok(()),
@ -118,7 +118,7 @@ impl OrgPolicy {
Err(diesel::result::Error::DatabaseError(diesel::result::DatabaseErrorKind::ForeignKeyViolation, _)) => { Err(diesel::result::Error::DatabaseError(diesel::result::DatabaseErrorKind::ForeignKeyViolation, _)) => {
diesel::update(org_policies::table) diesel::update(org_policies::table)
.filter(org_policies::uuid.eq(&self.uuid)) .filter(org_policies::uuid.eq(&self.uuid))
.set(OrgPolicyDb::to_db(self)) .set(self)
.execute(conn) .execute(conn)
.map_res("Error saving org_policy") .map_res("Error saving org_policy")
} }
@ -126,7 +126,6 @@ impl OrgPolicy {
}.map_res("Error saving org_policy") }.map_res("Error saving org_policy")
} }
postgresql { postgresql {
let value = OrgPolicyDb::to_db(self);
// We need to make sure we're not going to violate the unique constraint on org_uuid and atype. // We need to make sure we're not going to violate the unique constraint on org_uuid and atype.
// This happens automatically on other DBMS backends due to replace_into(). PostgreSQL does // This happens automatically on other DBMS backends due to replace_into(). PostgreSQL does
// not support multiple constraints on ON CONFLICT clauses. // not support multiple constraints on ON CONFLICT clauses.
@ -139,17 +138,17 @@ impl OrgPolicy {
.map_res("Error deleting org_policy for insert")?; .map_res("Error deleting org_policy for insert")?;
diesel::insert_into(org_policies::table) diesel::insert_into(org_policies::table)
.values(&value) .values(self)
.on_conflict(org_policies::uuid) .on_conflict(org_policies::uuid)
.do_update() .do_update()
.set(&value) .set(self)
.execute(conn) .execute(conn)
.map_res("Error saving org_policy") .map_res("Error saving org_policy")
} }
} }
} }
pub async fn delete(self, conn: &mut DbConn) -> EmptyResult { pub async fn delete(self, conn: &DbConn) -> EmptyResult {
db_run! { conn: { db_run! { conn: {
diesel::delete(org_policies::table.filter(org_policies::uuid.eq(self.uuid))) diesel::delete(org_policies::table.filter(org_policies::uuid.eq(self.uuid)))
.execute(conn) .execute(conn)
@ -157,17 +156,16 @@ impl OrgPolicy {
}} }}
} }
pub async fn find_by_org(org_uuid: &OrganizationId, conn: &mut DbConn) -> Vec<Self> { pub async fn find_by_org(org_uuid: &OrganizationId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: { db_run! { conn: {
org_policies::table org_policies::table
.filter(org_policies::org_uuid.eq(org_uuid)) .filter(org_policies::org_uuid.eq(org_uuid))
.load::<OrgPolicyDb>(conn) .load::<Self>(conn)
.expect("Error loading org_policy") .expect("Error loading org_policy")
.from_db()
}} }}
} }
pub async fn find_confirmed_by_user(user_uuid: &UserId, conn: &mut DbConn) -> Vec<Self> { pub async fn find_confirmed_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: { db_run! { conn: {
org_policies::table org_policies::table
.inner_join( .inner_join(
@ -179,28 +177,26 @@ impl OrgPolicy {
users_organizations::status.eq(MembershipStatus::Confirmed as i32) users_organizations::status.eq(MembershipStatus::Confirmed as i32)
) )
.select(org_policies::all_columns) .select(org_policies::all_columns)
.load::<OrgPolicyDb>(conn) .load::<Self>(conn)
.expect("Error loading org_policy") .expect("Error loading org_policy")
.from_db()
}} }}
} }
pub async fn find_by_org_and_type( pub async fn find_by_org_and_type(
org_uuid: &OrganizationId, org_uuid: &OrganizationId,
policy_type: OrgPolicyType, policy_type: OrgPolicyType,
conn: &mut DbConn, conn: &DbConn,
) -> Option<Self> { ) -> Option<Self> {
db_run! { conn: { db_run! { conn: {
org_policies::table org_policies::table
.filter(org_policies::org_uuid.eq(org_uuid)) .filter(org_policies::org_uuid.eq(org_uuid))
.filter(org_policies::atype.eq(policy_type as i32)) .filter(org_policies::atype.eq(policy_type as i32))
.first::<OrgPolicyDb>(conn) .first::<Self>(conn)
.ok() .ok()
.from_db()
}} }}
} }
pub async fn delete_all_by_organization(org_uuid: &OrganizationId, conn: &mut DbConn) -> EmptyResult { pub async fn delete_all_by_organization(org_uuid: &OrganizationId, conn: &DbConn) -> EmptyResult {
db_run! { conn: { db_run! { conn: {
diesel::delete(org_policies::table.filter(org_policies::org_uuid.eq(org_uuid))) diesel::delete(org_policies::table.filter(org_policies::org_uuid.eq(org_uuid)))
.execute(conn) .execute(conn)
@ -229,16 +225,15 @@ impl OrgPolicy {
.filter(org_policies::atype.eq(policy_type as i32)) .filter(org_policies::atype.eq(policy_type as i32))
.filter(org_policies::enabled.eq(true)) .filter(org_policies::enabled.eq(true))
.select(org_policies::all_columns) .select(org_policies::all_columns)
.load::<OrgPolicyDb>(conn) .load::<Self>(conn)
.expect("Error loading org_policy") .expect("Error loading org_policy")
.from_db()
}} }}
} }
pub async fn find_confirmed_by_user_and_active_policy( pub async fn find_confirmed_by_user_and_active_policy(
user_uuid: &UserId, user_uuid: &UserId,
policy_type: OrgPolicyType, policy_type: OrgPolicyType,
conn: &mut DbConn, conn: &DbConn,
) -> Vec<Self> { ) -> Vec<Self> {
db_run! { conn: { db_run! { conn: {
org_policies::table org_policies::table
@ -253,9 +248,8 @@ impl OrgPolicy {
.filter(org_policies::atype.eq(policy_type as i32)) .filter(org_policies::atype.eq(policy_type as i32))
.filter(org_policies::enabled.eq(true)) .filter(org_policies::enabled.eq(true))
.select(org_policies::all_columns) .select(org_policies::all_columns)
.load::<OrgPolicyDb>(conn) .load::<Self>(conn)
.expect("Error loading org_policy") .expect("Error loading org_policy")
.from_db()
}} }}
} }
@ -266,7 +260,7 @@ impl OrgPolicy {
user_uuid: &UserId, user_uuid: &UserId,
policy_type: OrgPolicyType, policy_type: OrgPolicyType,
exclude_org_uuid: Option<&OrganizationId>, exclude_org_uuid: Option<&OrganizationId>,
conn: &mut DbConn, conn: &DbConn,
) -> bool { ) -> bool {
for policy in for policy in
OrgPolicy::find_accepted_and_confirmed_by_user_and_active_policy(user_uuid, policy_type, conn).await OrgPolicy::find_accepted_and_confirmed_by_user_and_active_policy(user_uuid, policy_type, conn).await
@ -289,7 +283,7 @@ impl OrgPolicy {
user_uuid: &UserId, user_uuid: &UserId,
org_uuid: &OrganizationId, org_uuid: &OrganizationId,
exclude_current_org: bool, exclude_current_org: bool,
conn: &mut DbConn, conn: &DbConn,
) -> OrgPolicyResult { ) -> OrgPolicyResult {
// Enforce TwoFactor/TwoStep login // Enforce TwoFactor/TwoStep login
if TwoFactor::find_by_user(user_uuid, conn).await.is_empty() { if TwoFactor::find_by_user(user_uuid, conn).await.is_empty() {
@ -315,7 +309,7 @@ impl OrgPolicy {
Ok(()) Ok(())
} }
pub async fn org_is_reset_password_auto_enroll(org_uuid: &OrganizationId, conn: &mut DbConn) -> bool { pub async fn org_is_reset_password_auto_enroll(org_uuid: &OrganizationId, conn: &DbConn) -> bool {
match OrgPolicy::find_by_org_and_type(org_uuid, OrgPolicyType::ResetPassword, conn).await { match OrgPolicy::find_by_org_and_type(org_uuid, OrgPolicyType::ResetPassword, conn).await {
Some(policy) => match serde_json::from_str::<ResetPasswordDataModel>(&policy.data) { Some(policy) => match serde_json::from_str::<ResetPasswordDataModel>(&policy.data) {
Ok(opts) => { Ok(opts) => {
@ -331,7 +325,7 @@ impl OrgPolicy {
/// Returns true if the user belongs to an org that has enabled the `DisableHideEmail` /// Returns true if the user belongs to an org that has enabled the `DisableHideEmail`
/// option of the `Send Options` policy, and the user is not an owner or admin of that org. /// option of the `Send Options` policy, and the user is not an owner or admin of that org.
pub async fn is_hide_email_disabled(user_uuid: &UserId, conn: &mut DbConn) -> bool { pub async fn is_hide_email_disabled(user_uuid: &UserId, conn: &DbConn) -> bool {
for policy in for policy in
OrgPolicy::find_confirmed_by_user_and_active_policy(user_uuid, OrgPolicyType::SendOptions, conn).await OrgPolicy::find_confirmed_by_user_and_active_policy(user_uuid, OrgPolicyType::SendOptions, conn).await
{ {
@ -351,11 +345,7 @@ impl OrgPolicy {
false false
} }
pub async fn is_enabled_for_member( pub async fn is_enabled_for_member(member_uuid: &MembershipId, policy_type: OrgPolicyType, conn: &DbConn) -> bool {
member_uuid: &MembershipId,
policy_type: OrgPolicyType,
conn: &mut DbConn,
) -> bool {
if let Some(member) = Membership::find_by_uuid(member_uuid, conn).await { if let Some(member) = Membership::find_by_uuid(member_uuid, conn).await {
if let Some(policy) = OrgPolicy::find_by_org_and_type(&member.org_uuid, policy_type, conn).await { if let Some(policy) = OrgPolicy::find_by_org_and_type(&member.org_uuid, policy_type, conn).await {
return policy.enabled; return policy.enabled;

247
src/db/models/organization.rs

@ -1,5 +1,6 @@
use chrono::{NaiveDateTime, Utc}; use chrono::{NaiveDateTime, Utc};
use derive_more::{AsRef, Deref, Display, From}; use derive_more::{AsRef, Deref, Display, From};
use diesel::prelude::*;
use num_traits::FromPrimitive; use num_traits::FromPrimitive;
use serde_json::Value; use serde_json::Value;
use std::{ use std::{
@ -11,10 +12,13 @@ use super::{
CipherId, Collection, CollectionGroup, CollectionId, CollectionUser, Group, GroupId, GroupUser, OrgPolicy, CipherId, Collection, CollectionGroup, CollectionId, CollectionUser, Group, GroupId, GroupUser, OrgPolicy,
OrgPolicyType, TwoFactor, User, UserId, OrgPolicyType, TwoFactor, User, UserId,
}; };
use crate::db::schema::{
ciphers, ciphers_collections, collections_groups, groups, groups_users, org_policies, organization_api_key,
organizations, users, users_collections, users_organizations,
};
use crate::CONFIG; use crate::CONFIG;
use macros::UuidFromParam; use macros::UuidFromParam;
db_object! {
#[derive(Identifiable, Queryable, Insertable, AsChangeset)] #[derive(Identifiable, Queryable, Insertable, AsChangeset)]
#[diesel(table_name = organizations)] #[diesel(table_name = organizations)]
#[diesel(treat_none_as_null = true)] #[diesel(treat_none_as_null = true)]
@ -56,7 +60,6 @@ db_object! {
pub api_key: String, pub api_key: String,
pub revision_date: NaiveDateTime, pub revision_date: NaiveDateTime,
} }
}
// https://github.com/bitwarden/server/blob/9ebe16587175b1c0e9208f84397bb75d0d595510/src/Core/AdminConsole/Enums/OrganizationUserStatusType.cs // https://github.com/bitwarden/server/blob/9ebe16587175b1c0e9208f84397bb75d0d595510/src/Core/AdminConsole/Enums/OrganizationUserStatusType.cs
#[derive(PartialEq)] #[derive(PartialEq)]
@ -325,7 +328,7 @@ use crate::error::MapResult;
/// Database methods /// Database methods
impl Organization { impl Organization {
pub async fn save(&self, conn: &mut DbConn) -> EmptyResult { pub async fn save(&self, conn: &DbConn) -> EmptyResult {
if !crate::util::is_valid_email(&self.billing_email) { if !crate::util::is_valid_email(&self.billing_email) {
err!(format!("BillingEmail {} is not a valid email address", self.billing_email)) err!(format!("BillingEmail {} is not a valid email address", self.billing_email))
} }
@ -337,7 +340,7 @@ impl Organization {
db_run! { conn: db_run! { conn:
sqlite, mysql { sqlite, mysql {
match diesel::replace_into(organizations::table) match diesel::replace_into(organizations::table)
.values(OrganizationDb::to_db(self)) .values(self)
.execute(conn) .execute(conn)
{ {
Ok(_) => Ok(()), Ok(_) => Ok(()),
@ -345,7 +348,7 @@ impl Organization {
Err(diesel::result::Error::DatabaseError(diesel::result::DatabaseErrorKind::ForeignKeyViolation, _)) => { Err(diesel::result::Error::DatabaseError(diesel::result::DatabaseErrorKind::ForeignKeyViolation, _)) => {
diesel::update(organizations::table) diesel::update(organizations::table)
.filter(organizations::uuid.eq(&self.uuid)) .filter(organizations::uuid.eq(&self.uuid))
.set(OrganizationDb::to_db(self)) .set(self)
.execute(conn) .execute(conn)
.map_res("Error saving organization") .map_res("Error saving organization")
} }
@ -354,19 +357,18 @@ impl Organization {
} }
postgresql { postgresql {
let value = OrganizationDb::to_db(self);
diesel::insert_into(organizations::table) diesel::insert_into(organizations::table)
.values(&value) .values(self)
.on_conflict(organizations::uuid) .on_conflict(organizations::uuid)
.do_update() .do_update()
.set(&value) .set(self)
.execute(conn) .execute(conn)
.map_res("Error saving organization") .map_res("Error saving organization")
} }
} }
} }
pub async fn delete(self, conn: &mut DbConn) -> EmptyResult { pub async fn delete(self, conn: &DbConn) -> EmptyResult {
use super::{Cipher, Collection}; use super::{Cipher, Collection};
Cipher::delete_all_by_organization(&self.uuid, conn).await?; Cipher::delete_all_by_organization(&self.uuid, conn).await?;
@ -383,31 +385,33 @@ impl Organization {
}} }}
} }
pub async fn find_by_uuid(uuid: &OrganizationId, conn: &mut DbConn) -> Option<Self> { pub async fn find_by_uuid(uuid: &OrganizationId, conn: &DbConn) -> Option<Self> {
db_run! { conn: { db_run! { conn: {
organizations::table organizations::table
.filter(organizations::uuid.eq(uuid)) .filter(organizations::uuid.eq(uuid))
.first::<OrganizationDb>(conn) .first::<Self>(conn)
.ok().from_db() .ok()
}} }}
} }
pub async fn find_by_name(name: &str, conn: &mut DbConn) -> Option<Self> { pub async fn find_by_name(name: &str, conn: &DbConn) -> Option<Self> {
db_run! { conn: { db_run! { conn: {
organizations::table organizations::table
.filter(organizations::name.eq(name)) .filter(organizations::name.eq(name))
.first::<OrganizationDb>(conn) .first::<Self>(conn)
.ok().from_db() .ok()
}} }}
} }
pub async fn get_all(conn: &mut DbConn) -> Vec<Self> { pub async fn get_all(conn: &DbConn) -> Vec<Self> {
db_run! { conn: { db_run! { conn: {
organizations::table.load::<OrganizationDb>(conn).expect("Error loading organizations").from_db() organizations::table
.load::<Self>(conn)
.expect("Error loading organizations")
}} }}
} }
pub async fn find_main_org_user_email(user_email: &str, conn: &mut DbConn) -> Option<Organization> { pub async fn find_main_org_user_email(user_email: &str, conn: &DbConn) -> Option<Self> {
let lower_mail = user_email.to_lowercase(); let lower_mail = user_email.to_lowercase();
db_run! { conn: { db_run! { conn: {
@ -418,12 +422,12 @@ impl Organization {
.filter(users_organizations::status.ne(MembershipStatus::Revoked as i32)) .filter(users_organizations::status.ne(MembershipStatus::Revoked as i32))
.order(users_organizations::atype.asc()) .order(users_organizations::atype.asc())
.select(organizations::all_columns) .select(organizations::all_columns)
.first::<OrganizationDb>(conn) .first::<Self>(conn)
.ok().from_db() .ok()
}} }}
} }
pub async fn find_org_user_email(user_email: &str, conn: &mut DbConn) -> Vec<Organization> { pub async fn find_org_user_email(user_email: &str, conn: &DbConn) -> Vec<Self> {
let lower_mail = user_email.to_lowercase(); let lower_mail = user_email.to_lowercase();
db_run! { conn: { db_run! { conn: {
@ -434,15 +438,14 @@ impl Organization {
.filter(users_organizations::status.ne(MembershipStatus::Revoked as i32)) .filter(users_organizations::status.ne(MembershipStatus::Revoked as i32))
.order(users_organizations::atype.asc()) .order(users_organizations::atype.asc())
.select(organizations::all_columns) .select(organizations::all_columns)
.load::<OrganizationDb>(conn) .load::<Self>(conn)
.expect("Error loading user orgs") .expect("Error loading user orgs")
.from_db()
}} }}
} }
} }
impl Membership { impl Membership {
pub async fn to_json(&self, conn: &mut DbConn) -> Value { pub async fn to_json(&self, conn: &DbConn) -> Value {
let org = Organization::find_by_uuid(&self.org_uuid, conn).await.unwrap(); let org = Organization::find_by_uuid(&self.org_uuid, conn).await.unwrap();
// HACK: Convert the manager type to a custom type // HACK: Convert the manager type to a custom type
@ -533,12 +536,7 @@ impl Membership {
}) })
} }
pub async fn to_json_user_details( pub async fn to_json_user_details(&self, include_collections: bool, include_groups: bool, conn: &DbConn) -> Value {
&self,
include_collections: bool,
include_groups: bool,
conn: &mut DbConn,
) -> Value {
let user = User::find_by_uuid(&self.user_uuid, conn).await.unwrap(); let user = User::find_by_uuid(&self.user_uuid, conn).await.unwrap();
// Because BitWarden want the status to be -1 for revoked users we need to catch that here. // Because BitWarden want the status to be -1 for revoked users we need to catch that here.
@ -680,7 +678,7 @@ impl Membership {
}) })
} }
pub async fn to_json_details(&self, conn: &mut DbConn) -> Value { pub async fn to_json_details(&self, conn: &DbConn) -> Value {
let coll_uuids = if self.access_all { let coll_uuids = if self.access_all {
vec![] // If we have complete access, no need to fill the array vec![] // If we have complete access, no need to fill the array
} else { } else {
@ -720,7 +718,7 @@ impl Membership {
}) })
} }
pub async fn to_json_mini_details(&self, conn: &mut DbConn) -> Value { pub async fn to_json_mini_details(&self, conn: &DbConn) -> Value {
let user = User::find_by_uuid(&self.user_uuid, conn).await.unwrap(); let user = User::find_by_uuid(&self.user_uuid, conn).await.unwrap();
// Because Bitwarden wants the status to be -1 for revoked users we need to catch that here. // Because Bitwarden wants the status to be -1 for revoked users we need to catch that here.
@ -742,13 +740,13 @@ impl Membership {
}) })
} }
pub async fn save(&self, conn: &mut DbConn) -> EmptyResult { pub async fn save(&self, conn: &DbConn) -> EmptyResult {
User::update_uuid_revision(&self.user_uuid, conn).await; User::update_uuid_revision(&self.user_uuid, conn).await;
db_run! { conn: db_run! { conn:
sqlite, mysql { sqlite, mysql {
match diesel::replace_into(users_organizations::table) match diesel::replace_into(users_organizations::table)
.values(MembershipDb::to_db(self)) .values(self)
.execute(conn) .execute(conn)
{ {
Ok(_) => Ok(()), Ok(_) => Ok(()),
@ -756,7 +754,7 @@ impl Membership {
Err(diesel::result::Error::DatabaseError(diesel::result::DatabaseErrorKind::ForeignKeyViolation, _)) => { Err(diesel::result::Error::DatabaseError(diesel::result::DatabaseErrorKind::ForeignKeyViolation, _)) => {
diesel::update(users_organizations::table) diesel::update(users_organizations::table)
.filter(users_organizations::uuid.eq(&self.uuid)) .filter(users_organizations::uuid.eq(&self.uuid))
.set(MembershipDb::to_db(self)) .set(self)
.execute(conn) .execute(conn)
.map_res("Error adding user to organization") .map_res("Error adding user to organization")
}, },
@ -764,19 +762,18 @@ impl Membership {
}.map_res("Error adding user to organization") }.map_res("Error adding user to organization")
} }
postgresql { postgresql {
let value = MembershipDb::to_db(self);
diesel::insert_into(users_organizations::table) diesel::insert_into(users_organizations::table)
.values(&value) .values(self)
.on_conflict(users_organizations::uuid) .on_conflict(users_organizations::uuid)
.do_update() .do_update()
.set(&value) .set(self)
.execute(conn) .execute(conn)
.map_res("Error adding user to organization") .map_res("Error adding user to organization")
} }
} }
} }
pub async fn delete(self, conn: &mut DbConn) -> EmptyResult { pub async fn delete(self, conn: &DbConn) -> EmptyResult {
User::update_uuid_revision(&self.user_uuid, conn).await; User::update_uuid_revision(&self.user_uuid, conn).await;
CollectionUser::delete_all_by_user_and_org(&self.user_uuid, &self.org_uuid, conn).await?; CollectionUser::delete_all_by_user_and_org(&self.user_uuid, &self.org_uuid, conn).await?;
@ -789,25 +786,21 @@ impl Membership {
}} }}
} }
pub async fn delete_all_by_organization(org_uuid: &OrganizationId, conn: &mut DbConn) -> EmptyResult { pub async fn delete_all_by_organization(org_uuid: &OrganizationId, conn: &DbConn) -> EmptyResult {
for member in Self::find_by_org(org_uuid, conn).await { for member in Self::find_by_org(org_uuid, conn).await {
member.delete(conn).await?; member.delete(conn).await?;
} }
Ok(()) Ok(())
} }
pub async fn delete_all_by_user(user_uuid: &UserId, conn: &mut DbConn) -> EmptyResult { pub async fn delete_all_by_user(user_uuid: &UserId, conn: &DbConn) -> EmptyResult {
for member in Self::find_any_state_by_user(user_uuid, conn).await { for member in Self::find_any_state_by_user(user_uuid, conn).await {
member.delete(conn).await?; member.delete(conn).await?;
} }
Ok(()) Ok(())
} }
pub async fn find_by_email_and_org( pub async fn find_by_email_and_org(email: &str, org_uuid: &OrganizationId, conn: &DbConn) -> Option<Membership> {
email: &str,
org_uuid: &OrganizationId,
conn: &mut DbConn,
) -> Option<Membership> {
if let Some(user) = User::find_by_mail(email, conn).await { if let Some(user) = User::find_by_mail(email, conn).await {
if let Some(member) = Membership::find_by_user_and_org(&user.uuid, org_uuid, conn).await { if let Some(member) = Membership::find_by_user_and_org(&user.uuid, org_uuid, conn).await {
return Some(member); return Some(member);
@ -829,52 +822,48 @@ impl Membership {
(self.access_all || self.atype >= MembershipType::Admin) && self.has_status(MembershipStatus::Confirmed) (self.access_all || self.atype >= MembershipType::Admin) && self.has_status(MembershipStatus::Confirmed)
} }
pub async fn find_by_uuid(uuid: &MembershipId, conn: &mut DbConn) -> Option<Self> { pub async fn find_by_uuid(uuid: &MembershipId, conn: &DbConn) -> Option<Self> {
db_run! { conn: { db_run! { conn: {
users_organizations::table users_organizations::table
.filter(users_organizations::uuid.eq(uuid)) .filter(users_organizations::uuid.eq(uuid))
.first::<MembershipDb>(conn) .first::<Self>(conn)
.ok().from_db() .ok()
}} }}
} }
pub async fn find_by_uuid_and_org( pub async fn find_by_uuid_and_org(uuid: &MembershipId, org_uuid: &OrganizationId, conn: &DbConn) -> Option<Self> {
uuid: &MembershipId,
org_uuid: &OrganizationId,
conn: &mut DbConn,
) -> Option<Self> {
db_run! { conn: { db_run! { conn: {
users_organizations::table users_organizations::table
.filter(users_organizations::uuid.eq(uuid)) .filter(users_organizations::uuid.eq(uuid))
.filter(users_organizations::org_uuid.eq(org_uuid)) .filter(users_organizations::org_uuid.eq(org_uuid))
.first::<MembershipDb>(conn) .first::<Self>(conn)
.ok().from_db() .ok()
}} }}
} }
pub async fn find_confirmed_by_user(user_uuid: &UserId, conn: &mut DbConn) -> Vec<Self> { pub async fn find_confirmed_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: { db_run! { conn: {
users_organizations::table users_organizations::table
.filter(users_organizations::user_uuid.eq(user_uuid)) .filter(users_organizations::user_uuid.eq(user_uuid))
.filter(users_organizations::status.eq(MembershipStatus::Confirmed as i32)) .filter(users_organizations::status.eq(MembershipStatus::Confirmed as i32))
.load::<MembershipDb>(conn) .load::<Self>(conn)
.unwrap_or_default().from_db() .unwrap_or_default()
}} }}
} }
pub async fn find_invited_by_user(user_uuid: &UserId, conn: &mut DbConn) -> Vec<Self> { pub async fn find_invited_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: { db_run! { conn: {
users_organizations::table users_organizations::table
.filter(users_organizations::user_uuid.eq(user_uuid)) .filter(users_organizations::user_uuid.eq(user_uuid))
.filter(users_organizations::status.eq(MembershipStatus::Invited as i32)) .filter(users_organizations::status.eq(MembershipStatus::Invited as i32))
.load::<MembershipDb>(conn) .load::<Self>(conn)
.unwrap_or_default().from_db() .unwrap_or_default()
}} }}
} }
// Should be used only when email are disabled. // Should be used only when email are disabled.
// In Organizations::send_invite status is set to Accepted only if the user has a password. // In Organizations::send_invite status is set to Accepted only if the user has a password.
pub async fn accept_user_invitations(user_uuid: &UserId, conn: &mut DbConn) -> EmptyResult { pub async fn accept_user_invitations(user_uuid: &UserId, conn: &DbConn) -> EmptyResult {
db_run! { conn: { db_run! { conn: {
diesel::update(users_organizations::table) diesel::update(users_organizations::table)
.filter(users_organizations::user_uuid.eq(user_uuid)) .filter(users_organizations::user_uuid.eq(user_uuid))
@ -885,16 +874,16 @@ impl Membership {
}} }}
} }
pub async fn find_any_state_by_user(user_uuid: &UserId, conn: &mut DbConn) -> Vec<Self> { pub async fn find_any_state_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: { db_run! { conn: {
users_organizations::table users_organizations::table
.filter(users_organizations::user_uuid.eq(user_uuid)) .filter(users_organizations::user_uuid.eq(user_uuid))
.load::<MembershipDb>(conn) .load::<Self>(conn)
.unwrap_or_default().from_db() .unwrap_or_default()
}} }}
} }
pub async fn count_accepted_and_confirmed_by_user(user_uuid: &UserId, conn: &mut DbConn) -> i64 { pub async fn count_accepted_and_confirmed_by_user(user_uuid: &UserId, conn: &DbConn) -> i64 {
db_run! { conn: { db_run! { conn: {
users_organizations::table users_organizations::table
.filter(users_organizations::user_uuid.eq(user_uuid)) .filter(users_organizations::user_uuid.eq(user_uuid))
@ -905,27 +894,27 @@ impl Membership {
}} }}
} }
pub async fn find_by_org(org_uuid: &OrganizationId, conn: &mut DbConn) -> Vec<Self> { pub async fn find_by_org(org_uuid: &OrganizationId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: { db_run! { conn: {
users_organizations::table users_organizations::table
.filter(users_organizations::org_uuid.eq(org_uuid)) .filter(users_organizations::org_uuid.eq(org_uuid))
.load::<MembershipDb>(conn) .load::<Self>(conn)
.expect("Error loading user organizations").from_db() .expect("Error loading user organizations")
}} }}
} }
pub async fn find_confirmed_by_org(org_uuid: &OrganizationId, conn: &mut DbConn) -> Vec<Self> { pub async fn find_confirmed_by_org(org_uuid: &OrganizationId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: { db_run! { conn: {
users_organizations::table users_organizations::table
.filter(users_organizations::org_uuid.eq(org_uuid)) .filter(users_organizations::org_uuid.eq(org_uuid))
.filter(users_organizations::status.eq(MembershipStatus::Confirmed as i32)) .filter(users_organizations::status.eq(MembershipStatus::Confirmed as i32))
.load::<MembershipDb>(conn) .load::<Self>(conn)
.unwrap_or_default().from_db() .unwrap_or_default()
}} }}
} }
// Get all users which are either owner or admin, or a manager which can manage/access all // Get all users which are either owner or admin, or a manager which can manage/access all
pub async fn find_confirmed_and_manage_all_by_org(org_uuid: &OrganizationId, conn: &mut DbConn) -> Vec<Self> { pub async fn find_confirmed_and_manage_all_by_org(org_uuid: &OrganizationId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: { db_run! { conn: {
users_organizations::table users_organizations::table
.filter(users_organizations::org_uuid.eq(org_uuid)) .filter(users_organizations::org_uuid.eq(org_uuid))
@ -934,12 +923,12 @@ impl Membership {
users_organizations::atype.eq_any(vec![MembershipType::Owner as i32, MembershipType::Admin as i32]) users_organizations::atype.eq_any(vec![MembershipType::Owner as i32, MembershipType::Admin as i32])
.or(users_organizations::atype.eq(MembershipType::Manager as i32).and(users_organizations::access_all.eq(true))) .or(users_organizations::atype.eq(MembershipType::Manager as i32).and(users_organizations::access_all.eq(true)))
) )
.load::<MembershipDb>(conn) .load::<Self>(conn)
.unwrap_or_default().from_db() .unwrap_or_default()
}} }}
} }
pub async fn count_by_org(org_uuid: &OrganizationId, conn: &mut DbConn) -> i64 { pub async fn count_by_org(org_uuid: &OrganizationId, conn: &DbConn) -> i64 {
db_run! { conn: { db_run! { conn: {
users_organizations::table users_organizations::table
.filter(users_organizations::org_uuid.eq(org_uuid)) .filter(users_organizations::org_uuid.eq(org_uuid))
@ -950,24 +939,20 @@ impl Membership {
}} }}
} }
pub async fn find_by_org_and_type( pub async fn find_by_org_and_type(org_uuid: &OrganizationId, atype: MembershipType, conn: &DbConn) -> Vec<Self> {
org_uuid: &OrganizationId,
atype: MembershipType,
conn: &mut DbConn,
) -> Vec<Self> {
db_run! { conn: { db_run! { conn: {
users_organizations::table users_organizations::table
.filter(users_organizations::org_uuid.eq(org_uuid)) .filter(users_organizations::org_uuid.eq(org_uuid))
.filter(users_organizations::atype.eq(atype as i32)) .filter(users_organizations::atype.eq(atype as i32))
.load::<MembershipDb>(conn) .load::<Self>(conn)
.expect("Error loading user organizations").from_db() .expect("Error loading user organizations")
}} }}
} }
pub async fn count_confirmed_by_org_and_type( pub async fn count_confirmed_by_org_and_type(
org_uuid: &OrganizationId, org_uuid: &OrganizationId,
atype: MembershipType, atype: MembershipType,
conn: &mut DbConn, conn: &DbConn,
) -> i64 { ) -> i64 {
db_run! { conn: { db_run! { conn: {
users_organizations::table users_organizations::table
@ -980,24 +965,20 @@ impl Membership {
}} }}
} }
pub async fn find_by_user_and_org( pub async fn find_by_user_and_org(user_uuid: &UserId, org_uuid: &OrganizationId, conn: &DbConn) -> Option<Self> {
user_uuid: &UserId,
org_uuid: &OrganizationId,
conn: &mut DbConn,
) -> Option<Self> {
db_run! { conn: { db_run! { conn: {
users_organizations::table users_organizations::table
.filter(users_organizations::user_uuid.eq(user_uuid)) .filter(users_organizations::user_uuid.eq(user_uuid))
.filter(users_organizations::org_uuid.eq(org_uuid)) .filter(users_organizations::org_uuid.eq(org_uuid))
.first::<MembershipDb>(conn) .first::<Self>(conn)
.ok().from_db() .ok()
}} }}
} }
pub async fn find_confirmed_by_user_and_org( pub async fn find_confirmed_by_user_and_org(
user_uuid: &UserId, user_uuid: &UserId,
org_uuid: &OrganizationId, org_uuid: &OrganizationId,
conn: &mut DbConn, conn: &DbConn,
) -> Option<Self> { ) -> Option<Self> {
db_run! { conn: { db_run! { conn: {
users_organizations::table users_organizations::table
@ -1006,21 +987,21 @@ impl Membership {
.filter( .filter(
users_organizations::status.eq(MembershipStatus::Confirmed as i32) users_organizations::status.eq(MembershipStatus::Confirmed as i32)
) )
.first::<MembershipDb>(conn) .first::<Self>(conn)
.ok().from_db() .ok()
}} }}
} }
pub async fn find_by_user(user_uuid: &UserId, conn: &mut DbConn) -> Vec<Self> { pub async fn find_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: { db_run! { conn: {
users_organizations::table users_organizations::table
.filter(users_organizations::user_uuid.eq(user_uuid)) .filter(users_organizations::user_uuid.eq(user_uuid))
.load::<MembershipDb>(conn) .load::<Self>(conn)
.expect("Error loading user organizations").from_db() .expect("Error loading user organizations")
}} }}
} }
pub async fn get_orgs_by_user(user_uuid: &UserId, conn: &mut DbConn) -> Vec<OrganizationId> { pub async fn get_orgs_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<OrganizationId> {
db_run! { conn: { db_run! { conn: {
users_organizations::table users_organizations::table
.filter(users_organizations::user_uuid.eq(user_uuid)) .filter(users_organizations::user_uuid.eq(user_uuid))
@ -1030,11 +1011,7 @@ impl Membership {
}} }}
} }
pub async fn find_by_user_and_policy( pub async fn find_by_user_and_policy(user_uuid: &UserId, policy_type: OrgPolicyType, conn: &DbConn) -> Vec<Self> {
user_uuid: &UserId,
policy_type: OrgPolicyType,
conn: &mut DbConn,
) -> Vec<Self> {
db_run! { conn: { db_run! { conn: {
users_organizations::table users_organizations::table
.inner_join( .inner_join(
@ -1048,16 +1025,12 @@ impl Membership {
users_organizations::status.eq(MembershipStatus::Confirmed as i32) users_organizations::status.eq(MembershipStatus::Confirmed as i32)
) )
.select(users_organizations::all_columns) .select(users_organizations::all_columns)
.load::<MembershipDb>(conn) .load::<Self>(conn)
.unwrap_or_default().from_db() .unwrap_or_default()
}} }}
} }
pub async fn find_by_cipher_and_org( pub async fn find_by_cipher_and_org(cipher_uuid: &CipherId, org_uuid: &OrganizationId, conn: &DbConn) -> Vec<Self> {
cipher_uuid: &CipherId,
org_uuid: &OrganizationId,
conn: &mut DbConn,
) -> Vec<Self> {
db_run! { conn: { db_run! { conn: {
users_organizations::table users_organizations::table
.filter(users_organizations::org_uuid.eq(org_uuid)) .filter(users_organizations::org_uuid.eq(org_uuid))
@ -1076,14 +1049,15 @@ impl Membership {
) )
.select(users_organizations::all_columns) .select(users_organizations::all_columns)
.distinct() .distinct()
.load::<MembershipDb>(conn).expect("Error loading user organizations").from_db() .load::<Self>(conn)
.expect("Error loading user organizations")
}} }}
} }
pub async fn find_by_cipher_and_org_with_group( pub async fn find_by_cipher_and_org_with_group(
cipher_uuid: &CipherId, cipher_uuid: &CipherId,
org_uuid: &OrganizationId, org_uuid: &OrganizationId,
conn: &mut DbConn, conn: &DbConn,
) -> Vec<Self> { ) -> Vec<Self> {
db_run! { conn: { db_run! { conn: {
users_organizations::table users_organizations::table
@ -1106,15 +1080,12 @@ impl Membership {
) )
.select(users_organizations::all_columns) .select(users_organizations::all_columns)
.distinct() .distinct()
.load::<MembershipDb>(conn).expect("Error loading user organizations with groups").from_db() .load::<Self>(conn)
.expect("Error loading user organizations with groups")
}} }}
} }
pub async fn user_has_ge_admin_access_to_cipher( pub async fn user_has_ge_admin_access_to_cipher(user_uuid: &UserId, cipher_uuid: &CipherId, conn: &DbConn) -> bool {
user_uuid: &UserId,
cipher_uuid: &CipherId,
conn: &mut DbConn,
) -> bool {
db_run! { conn: { db_run! { conn: {
users_organizations::table users_organizations::table
.inner_join(ciphers::table.on(ciphers::uuid.eq(cipher_uuid).and(ciphers::organization_uuid.eq(users_organizations::org_uuid.nullable())))) .inner_join(ciphers::table.on(ciphers::uuid.eq(cipher_uuid).and(ciphers::organization_uuid.eq(users_organizations::org_uuid.nullable()))))
@ -1122,14 +1093,15 @@ impl Membership {
.filter(users_organizations::atype.eq_any(vec![MembershipType::Owner as i32, MembershipType::Admin as i32])) .filter(users_organizations::atype.eq_any(vec![MembershipType::Owner as i32, MembershipType::Admin as i32]))
.count() .count()
.first::<i64>(conn) .first::<i64>(conn)
.ok().unwrap_or(0) != 0 .ok()
.unwrap_or(0) != 0
}} }}
} }
pub async fn find_by_collection_and_org( pub async fn find_by_collection_and_org(
collection_uuid: &CollectionId, collection_uuid: &CollectionId,
org_uuid: &OrganizationId, org_uuid: &OrganizationId,
conn: &mut DbConn, conn: &DbConn,
) -> Vec<Self> { ) -> Vec<Self> {
db_run! { conn: { db_run! { conn: {
users_organizations::table users_organizations::table
@ -1143,33 +1115,31 @@ impl Membership {
) )
) )
.select(users_organizations::all_columns) .select(users_organizations::all_columns)
.load::<MembershipDb>(conn).expect("Error loading user organizations").from_db() .load::<Self>(conn)
.expect("Error loading user organizations")
}} }}
} }
pub async fn find_by_external_id_and_org( pub async fn find_by_external_id_and_org(ext_id: &str, org_uuid: &OrganizationId, conn: &DbConn) -> Option<Self> {
ext_id: &str,
org_uuid: &OrganizationId,
conn: &mut DbConn,
) -> Option<Self> {
db_run! { conn: { db_run! { conn: {
users_organizations::table users_organizations::table
.filter( .filter(
users_organizations::external_id.eq(ext_id) users_organizations::external_id.eq(ext_id)
.and(users_organizations::org_uuid.eq(org_uuid)) .and(users_organizations::org_uuid.eq(org_uuid))
) )
.first::<MembershipDb>(conn).ok().from_db() .first::<Self>(conn)
.ok()
}} }}
} }
pub async fn find_main_user_org(user_uuid: &str, conn: &mut DbConn) -> Option<Self> { pub async fn find_main_user_org(user_uuid: &str, conn: &DbConn) -> Option<Self> {
db_run! { conn: { db_run! { conn: {
users_organizations::table users_organizations::table
.filter(users_organizations::user_uuid.eq(user_uuid)) .filter(users_organizations::user_uuid.eq(user_uuid))
.filter(users_organizations::status.ne(MembershipStatus::Revoked as i32)) .filter(users_organizations::status.ne(MembershipStatus::Revoked as i32))
.order(users_organizations::atype.asc()) .order(users_organizations::atype.asc())
.first::<MembershipDb>(conn) .first::<Self>(conn)
.ok().from_db() .ok()
}} }}
} }
} }
@ -1179,7 +1149,7 @@ impl OrganizationApiKey {
db_run! { conn: db_run! { conn:
sqlite, mysql { sqlite, mysql {
match diesel::replace_into(organization_api_key::table) match diesel::replace_into(organization_api_key::table)
.values(OrganizationApiKeyDb::to_db(self)) .values(self)
.execute(conn) .execute(conn)
{ {
Ok(_) => Ok(()), Ok(_) => Ok(()),
@ -1187,7 +1157,7 @@ impl OrganizationApiKey {
Err(diesel::result::Error::DatabaseError(diesel::result::DatabaseErrorKind::ForeignKeyViolation, _)) => { Err(diesel::result::Error::DatabaseError(diesel::result::DatabaseErrorKind::ForeignKeyViolation, _)) => {
diesel::update(organization_api_key::table) diesel::update(organization_api_key::table)
.filter(organization_api_key::uuid.eq(&self.uuid)) .filter(organization_api_key::uuid.eq(&self.uuid))
.set(OrganizationApiKeyDb::to_db(self)) .set(self)
.execute(conn) .execute(conn)
.map_res("Error saving organization") .map_res("Error saving organization")
} }
@ -1196,12 +1166,11 @@ impl OrganizationApiKey {
} }
postgresql { postgresql {
let value = OrganizationApiKeyDb::to_db(self);
diesel::insert_into(organization_api_key::table) diesel::insert_into(organization_api_key::table)
.values(&value) .values(self)
.on_conflict((organization_api_key::uuid, organization_api_key::org_uuid)) .on_conflict((organization_api_key::uuid, organization_api_key::org_uuid))
.do_update() .do_update()
.set(&value) .set(self)
.execute(conn) .execute(conn)
.map_res("Error saving organization") .map_res("Error saving organization")
} }
@ -1212,12 +1181,12 @@ impl OrganizationApiKey {
db_run! { conn: { db_run! { conn: {
organization_api_key::table organization_api_key::table
.filter(organization_api_key::org_uuid.eq(org_uuid)) .filter(organization_api_key::org_uuid.eq(org_uuid))
.first::<OrganizationApiKeyDb>(conn) .first::<Self>(conn)
.ok().from_db() .ok()
}} }}
} }
pub async fn delete_all_by_organization(org_uuid: &OrganizationId, conn: &mut DbConn) -> EmptyResult { pub async fn delete_all_by_organization(org_uuid: &OrganizationId, conn: &DbConn) -> EmptyResult {
db_run! { conn: { db_run! { conn: {
diesel::delete(organization_api_key::table.filter(organization_api_key::org_uuid.eq(org_uuid))) diesel::delete(organization_api_key::table.filter(organization_api_key::org_uuid.eq(org_uuid)))
.execute(conn) .execute(conn)

56
src/db/models/send.rs

@ -4,9 +4,10 @@ use serde_json::Value;
use crate::{config::PathType, util::LowerCase, CONFIG}; use crate::{config::PathType, util::LowerCase, CONFIG};
use super::{OrganizationId, User, UserId}; use super::{OrganizationId, User, UserId};
use crate::db::schema::sends;
use diesel::prelude::*;
use id::SendId; use id::SendId;
db_object! {
#[derive(Identifiable, Queryable, Insertable, AsChangeset)] #[derive(Identifiable, Queryable, Insertable, AsChangeset)]
#[diesel(table_name = sends)] #[diesel(table_name = sends)]
#[diesel(treat_none_as_null = true)] #[diesel(treat_none_as_null = true)]
@ -38,7 +39,6 @@ db_object! {
pub disabled: bool, pub disabled: bool,
pub hide_email: Option<bool>, pub hide_email: Option<bool>,
} }
}
#[derive(Copy, Clone, PartialEq, Eq, num_derive::FromPrimitive)] #[derive(Copy, Clone, PartialEq, Eq, num_derive::FromPrimitive)]
pub enum SendType { pub enum SendType {
@ -103,7 +103,7 @@ impl Send {
} }
} }
pub async fn creator_identifier(&self, conn: &mut DbConn) -> Option<String> { pub async fn creator_identifier(&self, conn: &DbConn) -> Option<String> {
if let Some(hide_email) = self.hide_email { if let Some(hide_email) = self.hide_email {
if hide_email { if hide_email {
return None; return None;
@ -155,7 +155,7 @@ impl Send {
}) })
} }
pub async fn to_json_access(&self, conn: &mut DbConn) -> Value { pub async fn to_json_access(&self, conn: &DbConn) -> Value {
use crate::util::format_date; use crate::util::format_date;
let mut data = serde_json::from_str::<LowerCase<Value>>(&self.data).map(|d| d.data).unwrap_or_default(); let mut data = serde_json::from_str::<LowerCase<Value>>(&self.data).map(|d| d.data).unwrap_or_default();
@ -187,14 +187,14 @@ use crate::error::MapResult;
use crate::util::NumberOrString; use crate::util::NumberOrString;
impl Send { impl Send {
pub async fn save(&mut self, conn: &mut DbConn) -> EmptyResult { pub async fn save(&mut self, conn: &DbConn) -> EmptyResult {
self.update_users_revision(conn).await; self.update_users_revision(conn).await;
self.revision_date = Utc::now().naive_utc(); self.revision_date = Utc::now().naive_utc();
db_run! { conn: db_run! { conn:
sqlite, mysql { sqlite, mysql {
match diesel::replace_into(sends::table) match diesel::replace_into(sends::table)
.values(SendDb::to_db(self)) .values(&*self)
.execute(conn) .execute(conn)
{ {
Ok(_) => Ok(()), Ok(_) => Ok(()),
@ -202,7 +202,7 @@ impl Send {
Err(diesel::result::Error::DatabaseError(diesel::result::DatabaseErrorKind::ForeignKeyViolation, _)) => { Err(diesel::result::Error::DatabaseError(diesel::result::DatabaseErrorKind::ForeignKeyViolation, _)) => {
diesel::update(sends::table) diesel::update(sends::table)
.filter(sends::uuid.eq(&self.uuid)) .filter(sends::uuid.eq(&self.uuid))
.set(SendDb::to_db(self)) .set(&*self)
.execute(conn) .execute(conn)
.map_res("Error saving send") .map_res("Error saving send")
} }
@ -210,19 +210,18 @@ impl Send {
}.map_res("Error saving send") }.map_res("Error saving send")
} }
postgresql { postgresql {
let value = SendDb::to_db(self);
diesel::insert_into(sends::table) diesel::insert_into(sends::table)
.values(&value) .values(&*self)
.on_conflict(sends::uuid) .on_conflict(sends::uuid)
.do_update() .do_update()
.set(&value) .set(&*self)
.execute(conn) .execute(conn)
.map_res("Error saving send") .map_res("Error saving send")
} }
} }
} }
pub async fn delete(&self, conn: &mut DbConn) -> EmptyResult { pub async fn delete(&self, conn: &DbConn) -> EmptyResult {
self.update_users_revision(conn).await; self.update_users_revision(conn).await;
if self.atype == SendType::File as i32 { if self.atype == SendType::File as i32 {
@ -238,13 +237,13 @@ impl Send {
} }
/// Purge all sends that are past their deletion date. /// Purge all sends that are past their deletion date.
pub async fn purge(conn: &mut DbConn) { pub async fn purge(conn: &DbConn) {
for send in Self::find_by_past_deletion_date(conn).await { for send in Self::find_by_past_deletion_date(conn).await {
send.delete(conn).await.ok(); send.delete(conn).await.ok();
} }
} }
pub async fn update_users_revision(&self, conn: &mut DbConn) -> Vec<UserId> { pub async fn update_users_revision(&self, conn: &DbConn) -> Vec<UserId> {
let mut user_uuids = Vec::new(); let mut user_uuids = Vec::new();
match &self.user_uuid { match &self.user_uuid {
Some(user_uuid) => { Some(user_uuid) => {
@ -258,14 +257,14 @@ impl Send {
user_uuids user_uuids
} }
pub async fn delete_all_by_user(user_uuid: &UserId, conn: &mut DbConn) -> EmptyResult { pub async fn delete_all_by_user(user_uuid: &UserId, conn: &DbConn) -> EmptyResult {
for send in Self::find_by_user(user_uuid, conn).await { for send in Self::find_by_user(user_uuid, conn).await {
send.delete(conn).await?; send.delete(conn).await?;
} }
Ok(()) Ok(())
} }
pub async fn find_by_access_id(access_id: &str, conn: &mut DbConn) -> Option<Self> { pub async fn find_by_access_id(access_id: &str, conn: &DbConn) -> Option<Self> {
use data_encoding::BASE64URL_NOPAD; use data_encoding::BASE64URL_NOPAD;
use uuid::Uuid; use uuid::Uuid;
@ -281,36 +280,35 @@ impl Send {
Self::find_by_uuid(&uuid, conn).await Self::find_by_uuid(&uuid, conn).await
} }
pub async fn find_by_uuid(uuid: &SendId, conn: &mut DbConn) -> Option<Self> { pub async fn find_by_uuid(uuid: &SendId, conn: &DbConn) -> Option<Self> {
db_run! { conn: { db_run! { conn: {
sends::table sends::table
.filter(sends::uuid.eq(uuid)) .filter(sends::uuid.eq(uuid))
.first::<SendDb>(conn) .first::<Self>(conn)
.ok() .ok()
.from_db()
}} }}
} }
pub async fn find_by_uuid_and_user(uuid: &SendId, user_uuid: &UserId, conn: &mut DbConn) -> Option<Self> { pub async fn find_by_uuid_and_user(uuid: &SendId, user_uuid: &UserId, conn: &DbConn) -> Option<Self> {
db_run! { conn: { db_run! { conn: {
sends::table sends::table
.filter(sends::uuid.eq(uuid)) .filter(sends::uuid.eq(uuid))
.filter(sends::user_uuid.eq(user_uuid)) .filter(sends::user_uuid.eq(user_uuid))
.first::<SendDb>(conn) .first::<Self>(conn)
.ok() .ok()
.from_db()
}} }}
} }
pub async fn find_by_user(user_uuid: &UserId, conn: &mut DbConn) -> Vec<Self> { pub async fn find_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: { db_run! { conn: {
sends::table sends::table
.filter(sends::user_uuid.eq(user_uuid)) .filter(sends::user_uuid.eq(user_uuid))
.load::<SendDb>(conn).expect("Error loading sends").from_db() .load::<Self>(conn)
.expect("Error loading sends")
}} }}
} }
pub async fn size_by_user(user_uuid: &UserId, conn: &mut DbConn) -> Option<i64> { pub async fn size_by_user(user_uuid: &UserId, conn: &DbConn) -> Option<i64> {
let sends = Self::find_by_user(user_uuid, conn).await; let sends = Self::find_by_user(user_uuid, conn).await;
#[derive(serde::Deserialize)] #[derive(serde::Deserialize)]
@ -333,20 +331,22 @@ impl Send {
Some(total) Some(total)
} }
pub async fn find_by_org(org_uuid: &OrganizationId, conn: &mut DbConn) -> Vec<Self> { pub async fn find_by_org(org_uuid: &OrganizationId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: { db_run! { conn: {
sends::table sends::table
.filter(sends::organization_uuid.eq(org_uuid)) .filter(sends::organization_uuid.eq(org_uuid))
.load::<SendDb>(conn).expect("Error loading sends").from_db() .load::<Self>(conn)
.expect("Error loading sends")
}} }}
} }
pub async fn find_by_past_deletion_date(conn: &mut DbConn) -> Vec<Self> { pub async fn find_by_past_deletion_date(conn: &DbConn) -> Vec<Self> {
let now = Utc::now().naive_utc(); let now = Utc::now().naive_utc();
db_run! { conn: { db_run! { conn: {
sends::table sends::table
.filter(sends::deletion_date.lt(now)) .filter(sends::deletion_date.lt(now))
.load::<SendDb>(conn).expect("Error loading sends").from_db() .load::<Self>(conn)
.expect("Error loading sends")
}} }}
} }
} }

16
src/db/models/sso_nonce.rs

@ -1,11 +1,12 @@
use chrono::{NaiveDateTime, Utc}; use chrono::{NaiveDateTime, Utc};
use crate::api::EmptyResult; use crate::api::EmptyResult;
use crate::db::schema::sso_nonce;
use crate::db::{DbConn, DbPool}; use crate::db::{DbConn, DbPool};
use crate::error::MapResult; use crate::error::MapResult;
use crate::sso::{OIDCState, NONCE_EXPIRATION}; use crate::sso::{OIDCState, NONCE_EXPIRATION};
use diesel::prelude::*;
db_object! {
#[derive(Identifiable, Queryable, Insertable)] #[derive(Identifiable, Queryable, Insertable)]
#[diesel(table_name = sso_nonce)] #[diesel(table_name = sso_nonce)]
#[diesel(primary_key(state))] #[diesel(primary_key(state))]
@ -16,7 +17,6 @@ db_object! {
pub redirect_uri: String, pub redirect_uri: String,
pub created_at: NaiveDateTime, pub created_at: NaiveDateTime,
} }
}
/// Local methods /// Local methods
impl SsoNonce { impl SsoNonce {
@ -35,25 +35,24 @@ impl SsoNonce {
/// Database methods /// Database methods
impl SsoNonce { impl SsoNonce {
pub async fn save(&self, conn: &mut DbConn) -> EmptyResult { pub async fn save(&self, conn: &DbConn) -> EmptyResult {
db_run! { conn: db_run! { conn:
sqlite, mysql { sqlite, mysql {
diesel::replace_into(sso_nonce::table) diesel::replace_into(sso_nonce::table)
.values(SsoNonceDb::to_db(self)) .values(self)
.execute(conn) .execute(conn)
.map_res("Error saving SSO nonce") .map_res("Error saving SSO nonce")
} }
postgresql { postgresql {
let value = SsoNonceDb::to_db(self);
diesel::insert_into(sso_nonce::table) diesel::insert_into(sso_nonce::table)
.values(&value) .values(self)
.execute(conn) .execute(conn)
.map_res("Error saving SSO nonce") .map_res("Error saving SSO nonce")
} }
} }
} }
pub async fn delete(state: &OIDCState, conn: &mut DbConn) -> EmptyResult { pub async fn delete(state: &OIDCState, conn: &DbConn) -> EmptyResult {
db_run! { conn: { db_run! { conn: {
diesel::delete(sso_nonce::table.filter(sso_nonce::state.eq(state))) diesel::delete(sso_nonce::table.filter(sso_nonce::state.eq(state)))
.execute(conn) .execute(conn)
@ -67,9 +66,8 @@ impl SsoNonce {
sso_nonce::table sso_nonce::table
.filter(sso_nonce::state.eq(state)) .filter(sso_nonce::state.eq(state))
.filter(sso_nonce::created_at.ge(oldest)) .filter(sso_nonce::created_at.ge(oldest))
.first::<SsoNonceDb>(conn) .first::<Self>(conn)
.ok() .ok()
.from_db()
}} }}
} }

39
src/db/models/two_factor.rs

@ -1,12 +1,13 @@
use super::UserId; use super::UserId;
use crate::api::core::two_factor::webauthn::WebauthnRegistration; use crate::api::core::two_factor::webauthn::WebauthnRegistration;
use crate::db::schema::twofactor;
use crate::{api::EmptyResult, db::DbConn, error::MapResult}; use crate::{api::EmptyResult, db::DbConn, error::MapResult};
use diesel::prelude::*;
use serde_json::Value; use serde_json::Value;
use webauthn_rs::prelude::{Credential, ParsedAttestation}; use webauthn_rs::prelude::{Credential, ParsedAttestation};
use webauthn_rs_core::proto::CredentialV3; use webauthn_rs_core::proto::CredentialV3;
use webauthn_rs_proto::{AttestationFormat, RegisteredExtensions}; use webauthn_rs_proto::{AttestationFormat, RegisteredExtensions};
db_object! {
#[derive(Identifiable, Queryable, Insertable, AsChangeset)] #[derive(Identifiable, Queryable, Insertable, AsChangeset)]
#[diesel(table_name = twofactor)] #[diesel(table_name = twofactor)]
#[diesel(primary_key(uuid))] #[diesel(primary_key(uuid))]
@ -18,7 +19,6 @@ db_object! {
pub data: String, pub data: String,
pub last_used: i64, pub last_used: i64,
} }
}
#[allow(dead_code)] #[allow(dead_code)]
#[derive(num_derive::FromPrimitive)] #[derive(num_derive::FromPrimitive)]
@ -76,11 +76,11 @@ impl TwoFactor {
/// Database methods /// Database methods
impl TwoFactor { impl TwoFactor {
pub async fn save(&self, conn: &mut DbConn) -> EmptyResult { pub async fn save(&self, conn: &DbConn) -> EmptyResult {
db_run! { conn: db_run! { conn:
sqlite, mysql { sqlite, mysql {
match diesel::replace_into(twofactor::table) match diesel::replace_into(twofactor::table)
.values(TwoFactorDb::to_db(self)) .values(self)
.execute(conn) .execute(conn)
{ {
Ok(_) => Ok(()), Ok(_) => Ok(()),
@ -88,7 +88,7 @@ impl TwoFactor {
Err(diesel::result::Error::DatabaseError(diesel::result::DatabaseErrorKind::ForeignKeyViolation, _)) => { Err(diesel::result::Error::DatabaseError(diesel::result::DatabaseErrorKind::ForeignKeyViolation, _)) => {
diesel::update(twofactor::table) diesel::update(twofactor::table)
.filter(twofactor::uuid.eq(&self.uuid)) .filter(twofactor::uuid.eq(&self.uuid))
.set(TwoFactorDb::to_db(self)) .set(self)
.execute(conn) .execute(conn)
.map_res("Error saving twofactor") .map_res("Error saving twofactor")
} }
@ -96,7 +96,6 @@ impl TwoFactor {
}.map_res("Error saving twofactor") }.map_res("Error saving twofactor")
} }
postgresql { postgresql {
let value = TwoFactorDb::to_db(self);
// We need to make sure we're not going to violate the unique constraint on user_uuid and atype. // We need to make sure we're not going to violate the unique constraint on user_uuid and atype.
// This happens automatically on other DBMS backends due to replace_into(). PostgreSQL does // This happens automatically on other DBMS backends due to replace_into(). PostgreSQL does
// not support multiple constraints on ON CONFLICT clauses. // not support multiple constraints on ON CONFLICT clauses.
@ -105,17 +104,17 @@ impl TwoFactor {
.map_res("Error deleting twofactor for insert")?; .map_res("Error deleting twofactor for insert")?;
diesel::insert_into(twofactor::table) diesel::insert_into(twofactor::table)
.values(&value) .values(self)
.on_conflict(twofactor::uuid) .on_conflict(twofactor::uuid)
.do_update() .do_update()
.set(&value) .set(self)
.execute(conn) .execute(conn)
.map_res("Error saving twofactor") .map_res("Error saving twofactor")
} }
} }
} }
pub async fn delete(self, conn: &mut DbConn) -> EmptyResult { pub async fn delete(self, conn: &DbConn) -> EmptyResult {
db_run! { conn: { db_run! { conn: {
diesel::delete(twofactor::table.filter(twofactor::uuid.eq(self.uuid))) diesel::delete(twofactor::table.filter(twofactor::uuid.eq(self.uuid)))
.execute(conn) .execute(conn)
@ -123,29 +122,27 @@ impl TwoFactor {
}} }}
} }
pub async fn find_by_user(user_uuid: &UserId, conn: &mut DbConn) -> Vec<Self> { pub async fn find_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: { db_run! { conn: {
twofactor::table twofactor::table
.filter(twofactor::user_uuid.eq(user_uuid)) .filter(twofactor::user_uuid.eq(user_uuid))
.filter(twofactor::atype.lt(1000)) // Filter implementation types .filter(twofactor::atype.lt(1000)) // Filter implementation types
.load::<TwoFactorDb>(conn) .load::<Self>(conn)
.expect("Error loading twofactor") .expect("Error loading twofactor")
.from_db()
}} }}
} }
pub async fn find_by_user_and_type(user_uuid: &UserId, atype: i32, conn: &mut DbConn) -> Option<Self> { pub async fn find_by_user_and_type(user_uuid: &UserId, atype: i32, conn: &DbConn) -> Option<Self> {
db_run! { conn: { db_run! { conn: {
twofactor::table twofactor::table
.filter(twofactor::user_uuid.eq(user_uuid)) .filter(twofactor::user_uuid.eq(user_uuid))
.filter(twofactor::atype.eq(atype)) .filter(twofactor::atype.eq(atype))
.first::<TwoFactorDb>(conn) .first::<Self>(conn)
.ok() .ok()
.from_db()
}} }}
} }
pub async fn delete_all_by_user(user_uuid: &UserId, conn: &mut DbConn) -> EmptyResult { pub async fn delete_all_by_user(user_uuid: &UserId, conn: &DbConn) -> EmptyResult {
db_run! { conn: { db_run! { conn: {
diesel::delete(twofactor::table.filter(twofactor::user_uuid.eq(user_uuid))) diesel::delete(twofactor::table.filter(twofactor::user_uuid.eq(user_uuid)))
.execute(conn) .execute(conn)
@ -153,13 +150,12 @@ impl TwoFactor {
}} }}
} }
pub async fn migrate_u2f_to_webauthn(conn: &mut DbConn) -> EmptyResult { pub async fn migrate_u2f_to_webauthn(conn: &DbConn) -> EmptyResult {
let u2f_factors = db_run! { conn: { let u2f_factors = db_run! { conn: {
twofactor::table twofactor::table
.filter(twofactor::atype.eq(TwoFactorType::U2f as i32)) .filter(twofactor::atype.eq(TwoFactorType::U2f as i32))
.load::<TwoFactorDb>(conn) .load::<Self>(conn)
.expect("Error loading twofactor") .expect("Error loading twofactor")
.from_db()
}}; }};
use crate::api::core::two_factor::webauthn::U2FRegistration; use crate::api::core::two_factor::webauthn::U2FRegistration;
@ -231,13 +227,12 @@ impl TwoFactor {
Ok(()) Ok(())
} }
pub async fn migrate_credential_to_passkey(conn: &mut DbConn) -> EmptyResult { pub async fn migrate_credential_to_passkey(conn: &DbConn) -> EmptyResult {
let webauthn_factors = db_run! { conn: { let webauthn_factors = db_run! { conn: {
twofactor::table twofactor::table
.filter(twofactor::atype.eq(TwoFactorType::Webauthn as i32)) .filter(twofactor::atype.eq(TwoFactorType::Webauthn as i32))
.load::<TwoFactorDb>(conn) .load::<Self>(conn)
.expect("Error loading twofactor") .expect("Error loading twofactor")
.from_db()
}}; }};
for webauthn_factor in webauthn_factors { for webauthn_factor in webauthn_factors {

44
src/db/models/two_factor_duo_context.rs

@ -1,8 +1,9 @@
use chrono::Utc; use chrono::Utc;
use crate::db::schema::twofactor_duo_ctx;
use crate::{api::EmptyResult, db::DbConn, error::MapResult}; use crate::{api::EmptyResult, db::DbConn, error::MapResult};
use diesel::prelude::*;
db_object! {
#[derive(Identifiable, Queryable, Insertable, AsChangeset)] #[derive(Identifiable, Queryable, Insertable, AsChangeset)]
#[diesel(table_name = twofactor_duo_ctx)] #[diesel(table_name = twofactor_duo_ctx)]
#[diesel(primary_key(state))] #[diesel(primary_key(state))]
@ -12,22 +13,18 @@ db_object! {
pub nonce: String, pub nonce: String,
pub exp: i64, pub exp: i64,
} }
}
impl TwoFactorDuoContext { impl TwoFactorDuoContext {
pub async fn find_by_state(state: &str, conn: &mut DbConn) -> Option<Self> { pub async fn find_by_state(state: &str, conn: &DbConn) -> Option<Self> {
db_run! { db_run! { conn: {
conn: {
twofactor_duo_ctx::table twofactor_duo_ctx::table
.filter(twofactor_duo_ctx::state.eq(state)) .filter(twofactor_duo_ctx::state.eq(state))
.first::<TwoFactorDuoContextDb>(conn) .first::<Self>(conn)
.ok() .ok()
.from_db() }}
}
}
} }
pub async fn save(state: &str, user_email: &str, nonce: &str, ttl: i64, conn: &mut DbConn) -> EmptyResult { pub async fn save(state: &str, user_email: &str, nonce: &str, ttl: i64, conn: &DbConn) -> EmptyResult {
// A saved context should never be changed, only created or deleted. // A saved context should never be changed, only created or deleted.
let exists = Self::find_by_state(state, conn).await; let exists = Self::find_by_state(state, conn).await;
if exists.is_some() { if exists.is_some() {
@ -36,8 +33,7 @@ impl TwoFactorDuoContext {
let exp = Utc::now().timestamp() + ttl; let exp = Utc::now().timestamp() + ttl;
db_run! { db_run! { conn: {
conn: {
diesel::insert_into(twofactor_duo_ctx::table) diesel::insert_into(twofactor_duo_ctx::table)
.values(( .values((
twofactor_duo_ctx::state.eq(state), twofactor_duo_ctx::state.eq(state),
@ -47,36 +43,30 @@ impl TwoFactorDuoContext {
)) ))
.execute(conn) .execute(conn)
.map_res("Error saving context to twofactor_duo_ctx") .map_res("Error saving context to twofactor_duo_ctx")
} }}
}
} }
pub async fn find_expired(conn: &mut DbConn) -> Vec<Self> { pub async fn find_expired(conn: &DbConn) -> Vec<Self> {
let now = Utc::now().timestamp(); let now = Utc::now().timestamp();
db_run! { db_run! { conn: {
conn: {
twofactor_duo_ctx::table twofactor_duo_ctx::table
.filter(twofactor_duo_ctx::exp.lt(now)) .filter(twofactor_duo_ctx::exp.lt(now))
.load::<TwoFactorDuoContextDb>(conn) .load::<Self>(conn)
.expect("Error finding expired contexts in twofactor_duo_ctx") .expect("Error finding expired contexts in twofactor_duo_ctx")
.from_db() }}
}
}
} }
pub async fn delete(&self, conn: &mut DbConn) -> EmptyResult { pub async fn delete(&self, conn: &DbConn) -> EmptyResult {
db_run! { db_run! { conn: {
conn: {
diesel::delete( diesel::delete(
twofactor_duo_ctx::table twofactor_duo_ctx::table
.filter(twofactor_duo_ctx::state.eq(&self.state))) .filter(twofactor_duo_ctx::state.eq(&self.state)))
.execute(conn) .execute(conn)
.map_res("Error deleting from twofactor_duo_ctx") .map_res("Error deleting from twofactor_duo_ctx")
} }}
}
} }
pub async fn purge_expired_duo_contexts(conn: &mut DbConn) { pub async fn purge_expired_duo_contexts(conn: &DbConn) {
for context in Self::find_expired(conn).await { for context in Self::find_expired(conn).await {
context.delete(conn).await.ok(); context.delete(conn).await.ok();
} }

32
src/db/models/two_factor_incomplete.rs

@ -1,5 +1,6 @@
use chrono::{NaiveDateTime, Utc}; use chrono::{NaiveDateTime, Utc};
use crate::db::schema::twofactor_incomplete;
use crate::{ use crate::{
api::EmptyResult, api::EmptyResult,
auth::ClientIp, auth::ClientIp,
@ -10,8 +11,8 @@ use crate::{
error::MapResult, error::MapResult,
CONFIG, CONFIG,
}; };
use diesel::prelude::*;
db_object! {
#[derive(Identifiable, Queryable, Insertable, AsChangeset)] #[derive(Identifiable, Queryable, Insertable, AsChangeset)]
#[diesel(table_name = twofactor_incomplete)] #[diesel(table_name = twofactor_incomplete)]
#[diesel(primary_key(user_uuid, device_uuid))] #[diesel(primary_key(user_uuid, device_uuid))]
@ -26,7 +27,6 @@ db_object! {
pub login_time: NaiveDateTime, pub login_time: NaiveDateTime,
pub ip_address: String, pub ip_address: String,
} }
}
impl TwoFactorIncomplete { impl TwoFactorIncomplete {
pub async fn mark_incomplete( pub async fn mark_incomplete(
@ -35,7 +35,7 @@ impl TwoFactorIncomplete {
device_name: &str, device_name: &str,
device_type: i32, device_type: i32,
ip: &ClientIp, ip: &ClientIp,
conn: &mut DbConn, conn: &DbConn,
) -> EmptyResult { ) -> EmptyResult {
if CONFIG.incomplete_2fa_time_limit() <= 0 || !CONFIG.mail_enabled() { if CONFIG.incomplete_2fa_time_limit() <= 0 || !CONFIG.mail_enabled() {
return Ok(()); return Ok(());
@ -64,7 +64,7 @@ impl TwoFactorIncomplete {
}} }}
} }
pub async fn mark_complete(user_uuid: &UserId, device_uuid: &DeviceId, conn: &mut DbConn) -> EmptyResult { pub async fn mark_complete(user_uuid: &UserId, device_uuid: &DeviceId, conn: &DbConn) -> EmptyResult {
if CONFIG.incomplete_2fa_time_limit() <= 0 || !CONFIG.mail_enabled() { if CONFIG.incomplete_2fa_time_limit() <= 0 || !CONFIG.mail_enabled() {
return Ok(()); return Ok(());
} }
@ -72,40 +72,30 @@ impl TwoFactorIncomplete {
Self::delete_by_user_and_device(user_uuid, device_uuid, conn).await Self::delete_by_user_and_device(user_uuid, device_uuid, conn).await
} }
pub async fn find_by_user_and_device( pub async fn find_by_user_and_device(user_uuid: &UserId, device_uuid: &DeviceId, conn: &DbConn) -> Option<Self> {
user_uuid: &UserId,
device_uuid: &DeviceId,
conn: &mut DbConn,
) -> Option<Self> {
db_run! { conn: { db_run! { conn: {
twofactor_incomplete::table twofactor_incomplete::table
.filter(twofactor_incomplete::user_uuid.eq(user_uuid)) .filter(twofactor_incomplete::user_uuid.eq(user_uuid))
.filter(twofactor_incomplete::device_uuid.eq(device_uuid)) .filter(twofactor_incomplete::device_uuid.eq(device_uuid))
.first::<TwoFactorIncompleteDb>(conn) .first::<Self>(conn)
.ok() .ok()
.from_db()
}} }}
} }
pub async fn find_logins_before(dt: &NaiveDateTime, conn: &mut DbConn) -> Vec<Self> { pub async fn find_logins_before(dt: &NaiveDateTime, conn: &DbConn) -> Vec<Self> {
db_run! { conn: { db_run! { conn: {
twofactor_incomplete::table twofactor_incomplete::table
.filter(twofactor_incomplete::login_time.lt(dt)) .filter(twofactor_incomplete::login_time.lt(dt))
.load::<TwoFactorIncompleteDb>(conn) .load::<Self>(conn)
.expect("Error loading twofactor_incomplete") .expect("Error loading twofactor_incomplete")
.from_db()
}} }}
} }
pub async fn delete(self, conn: &mut DbConn) -> EmptyResult { pub async fn delete(self, conn: &DbConn) -> EmptyResult {
Self::delete_by_user_and_device(&self.user_uuid, &self.device_uuid, conn).await Self::delete_by_user_and_device(&self.user_uuid, &self.device_uuid, conn).await
} }
pub async fn delete_by_user_and_device( pub async fn delete_by_user_and_device(user_uuid: &UserId, device_uuid: &DeviceId, conn: &DbConn) -> EmptyResult {
user_uuid: &UserId,
device_uuid: &DeviceId,
conn: &mut DbConn,
) -> EmptyResult {
db_run! { conn: { db_run! { conn: {
diesel::delete(twofactor_incomplete::table diesel::delete(twofactor_incomplete::table
.filter(twofactor_incomplete::user_uuid.eq(user_uuid)) .filter(twofactor_incomplete::user_uuid.eq(user_uuid))
@ -115,7 +105,7 @@ impl TwoFactorIncomplete {
}} }}
} }
pub async fn delete_all_by_user(user_uuid: &UserId, conn: &mut DbConn) -> EmptyResult { pub async fn delete_all_by_user(user_uuid: &UserId, conn: &DbConn) -> EmptyResult {
db_run! { conn: { db_run! { conn: {
diesel::delete(twofactor_incomplete::table.filter(twofactor_incomplete::user_uuid.eq(user_uuid))) diesel::delete(twofactor_incomplete::table.filter(twofactor_incomplete::user_uuid.eq(user_uuid)))
.execute(conn) .execute(conn)

89
src/db/models/user.rs

@ -1,5 +1,7 @@
use crate::db::schema::{devices, invitations, sso_users, users};
use chrono::{NaiveDateTime, TimeDelta, Utc}; use chrono::{NaiveDateTime, TimeDelta, Utc};
use derive_more::{AsRef, Deref, Display, From}; use derive_more::{AsRef, Deref, Display, From};
use diesel::prelude::*;
use serde_json::Value; use serde_json::Value;
use super::{ use super::{
@ -17,7 +19,6 @@ use crate::{
}; };
use macros::UuidFromParam; use macros::UuidFromParam;
db_object! {
#[derive(Identifiable, Queryable, Insertable, AsChangeset, Selectable)] #[derive(Identifiable, Queryable, Insertable, AsChangeset, Selectable)]
#[diesel(table_name = users)] #[diesel(table_name = users)]
#[diesel(treat_none_as_null = true)] #[diesel(treat_none_as_null = true)]
@ -81,7 +82,6 @@ db_object! {
pub user_uuid: UserId, pub user_uuid: UserId,
pub identifier: OIDCIdentifier, pub identifier: OIDCIdentifier,
} }
}
pub enum UserKdfType { pub enum UserKdfType {
Pbkdf2 = 0, Pbkdf2 = 0,
@ -236,7 +236,7 @@ impl User {
/// Database methods /// Database methods
impl User { impl User {
pub async fn to_json(&self, conn: &mut DbConn) -> Value { pub async fn to_json(&self, conn: &DbConn) -> Value {
let mut orgs_json = Vec::new(); let mut orgs_json = Vec::new();
for c in Membership::find_confirmed_by_user(&self.uuid, conn).await { for c in Membership::find_confirmed_by_user(&self.uuid, conn).await {
orgs_json.push(c.to_json(conn).await); orgs_json.push(c.to_json(conn).await);
@ -275,7 +275,7 @@ impl User {
}) })
} }
pub async fn save(&mut self, conn: &mut DbConn) -> EmptyResult { pub async fn save(&mut self, conn: &DbConn) -> EmptyResult {
if !crate::util::is_valid_email(&self.email) { if !crate::util::is_valid_email(&self.email) {
err!(format!("User email {} is not a valid email address", self.email)) err!(format!("User email {} is not a valid email address", self.email))
} }
@ -285,7 +285,7 @@ impl User {
db_run! { conn: db_run! { conn:
sqlite, mysql { sqlite, mysql {
match diesel::replace_into(users::table) match diesel::replace_into(users::table)
.values(UserDb::to_db(self)) .values(&*self)
.execute(conn) .execute(conn)
{ {
Ok(_) => Ok(()), Ok(_) => Ok(()),
@ -293,7 +293,7 @@ impl User {
Err(diesel::result::Error::DatabaseError(diesel::result::DatabaseErrorKind::ForeignKeyViolation, _)) => { Err(diesel::result::Error::DatabaseError(diesel::result::DatabaseErrorKind::ForeignKeyViolation, _)) => {
diesel::update(users::table) diesel::update(users::table)
.filter(users::uuid.eq(&self.uuid)) .filter(users::uuid.eq(&self.uuid))
.set(UserDb::to_db(self)) .set(&*self)
.execute(conn) .execute(conn)
.map_res("Error saving user") .map_res("Error saving user")
} }
@ -301,19 +301,18 @@ impl User {
}.map_res("Error saving user") }.map_res("Error saving user")
} }
postgresql { postgresql {
let value = UserDb::to_db(self);
diesel::insert_into(users::table) // Insert or update diesel::insert_into(users::table) // Insert or update
.values(&value) .values(&*self)
.on_conflict(users::uuid) .on_conflict(users::uuid)
.do_update() .do_update()
.set(&value) .set(&*self)
.execute(conn) .execute(conn)
.map_res("Error saving user") .map_res("Error saving user")
} }
} }
} }
pub async fn delete(self, conn: &mut DbConn) -> EmptyResult { pub async fn delete(self, conn: &DbConn) -> EmptyResult {
for member in Membership::find_confirmed_by_user(&self.uuid, conn).await { for member in Membership::find_confirmed_by_user(&self.uuid, conn).await {
if member.atype == MembershipType::Owner if member.atype == MembershipType::Owner
&& Membership::count_confirmed_by_org_and_type(&member.org_uuid, MembershipType::Owner, conn).await <= 1 && Membership::count_confirmed_by_org_and_type(&member.org_uuid, MembershipType::Owner, conn).await <= 1
@ -341,13 +340,13 @@ impl User {
}} }}
} }
pub async fn update_uuid_revision(uuid: &UserId, conn: &mut DbConn) { pub async fn update_uuid_revision(uuid: &UserId, conn: &DbConn) {
if let Err(e) = Self::_update_revision(uuid, &Utc::now().naive_utc(), conn).await { if let Err(e) = Self::_update_revision(uuid, &Utc::now().naive_utc(), conn).await {
warn!("Failed to update revision for {uuid}: {e:#?}"); warn!("Failed to update revision for {uuid}: {e:#?}");
} }
} }
pub async fn update_all_revisions(conn: &mut DbConn) -> EmptyResult { pub async fn update_all_revisions(conn: &DbConn) -> EmptyResult {
let updated_at = Utc::now().naive_utc(); let updated_at = Utc::now().naive_utc();
db_run! { conn: { db_run! { conn: {
@ -360,13 +359,13 @@ impl User {
}} }}
} }
pub async fn update_revision(&mut self, conn: &mut DbConn) -> EmptyResult { pub async fn update_revision(&mut self, conn: &DbConn) -> EmptyResult {
self.updated_at = Utc::now().naive_utc(); self.updated_at = Utc::now().naive_utc();
Self::_update_revision(&self.uuid, &self.updated_at, conn).await Self::_update_revision(&self.uuid, &self.updated_at, conn).await
} }
async fn _update_revision(uuid: &UserId, date: &NaiveDateTime, conn: &mut DbConn) -> EmptyResult { async fn _update_revision(uuid: &UserId, date: &NaiveDateTime, conn: &DbConn) -> EmptyResult {
db_run! { conn: { db_run! { conn: {
retry(|| { retry(|| {
diesel::update(users::table.filter(users::uuid.eq(uuid))) diesel::update(users::table.filter(users::uuid.eq(uuid)))
@ -377,49 +376,49 @@ impl User {
}} }}
} }
pub async fn find_by_mail(mail: &str, conn: &mut DbConn) -> Option<Self> { pub async fn find_by_mail(mail: &str, conn: &DbConn) -> Option<Self> {
let lower_mail = mail.to_lowercase(); let lower_mail = mail.to_lowercase();
db_run! { conn: { db_run! { conn: {
users::table users::table
.filter(users::email.eq(lower_mail)) .filter(users::email.eq(lower_mail))
.first::<UserDb>(conn) .first::<Self>(conn)
.ok() .ok()
.from_db()
}} }}
} }
pub async fn find_by_uuid(uuid: &UserId, conn: &mut DbConn) -> Option<Self> { pub async fn find_by_uuid(uuid: &UserId, conn: &DbConn) -> Option<Self> {
db_run! { conn: { db_run! { conn: {
users::table.filter(users::uuid.eq(uuid)).first::<UserDb>(conn).ok().from_db() users::table
.filter(users::uuid.eq(uuid))
.first::<Self>(conn)
.ok()
}} }}
} }
pub async fn find_by_device_id(device_uuid: &DeviceId, conn: &mut DbConn) -> Option<Self> { pub async fn find_by_device_id(device_uuid: &DeviceId, conn: &DbConn) -> Option<Self> {
db_run! { conn: { db_run! { conn: {
users::table users::table
.inner_join(devices::table.on(devices::user_uuid.eq(users::uuid))) .inner_join(devices::table.on(devices::user_uuid.eq(users::uuid)))
.filter(devices::uuid.eq(device_uuid)) .filter(devices::uuid.eq(device_uuid))
.select(users::all_columns) .select(users::all_columns)
.first::<UserDb>(conn) .first::<Self>(conn)
.ok() .ok()
.from_db()
}} }}
} }
pub async fn get_all(conn: &mut DbConn) -> Vec<(User, Option<SsoUser>)> { pub async fn get_all(conn: &DbConn) -> Vec<(Self, Option<SsoUser>)> {
db_run! { conn: { db_run! { conn: {
users::table users::table
.left_join(sso_users::table) .left_join(sso_users::table)
.select(<(UserDb, Option<SsoUserDb>)>::as_select()) .select(<(Self, Option<SsoUser>)>::as_select())
.load(conn) .load(conn)
.expect("Error loading groups for user") .expect("Error loading groups for user")
.into_iter() .into_iter()
.map(|(user, sso_user)| { (user.from_db(), sso_user.from_db()) })
.collect() .collect()
}} }}
} }
pub async fn last_active(&self, conn: &mut DbConn) -> Option<NaiveDateTime> { pub async fn last_active(&self, conn: &DbConn) -> Option<NaiveDateTime> {
match Device::find_latest_active_by_user(&self.uuid, conn).await { match Device::find_latest_active_by_user(&self.uuid, conn).await {
Some(device) => Some(device.updated_at), Some(device) => Some(device.updated_at),
None => None, None => None,
@ -435,7 +434,7 @@ impl Invitation {
} }
} }
pub async fn save(&self, conn: &mut DbConn) -> EmptyResult { pub async fn save(&self, conn: &DbConn) -> EmptyResult {
if !crate::util::is_valid_email(&self.email) { if !crate::util::is_valid_email(&self.email) {
err!(format!("Invitation email {} is not a valid email address", self.email)) err!(format!("Invitation email {} is not a valid email address", self.email))
} }
@ -445,13 +444,13 @@ impl Invitation {
// Not checking for ForeignKey Constraints here // Not checking for ForeignKey Constraints here
// Table invitations does not have any ForeignKey Constraints. // Table invitations does not have any ForeignKey Constraints.
diesel::replace_into(invitations::table) diesel::replace_into(invitations::table)
.values(InvitationDb::to_db(self)) .values(self)
.execute(conn) .execute(conn)
.map_res("Error saving invitation") .map_res("Error saving invitation")
} }
postgresql { postgresql {
diesel::insert_into(invitations::table) diesel::insert_into(invitations::table)
.values(InvitationDb::to_db(self)) .values(self)
.on_conflict(invitations::email) .on_conflict(invitations::email)
.do_nothing() .do_nothing()
.execute(conn) .execute(conn)
@ -460,7 +459,7 @@ impl Invitation {
} }
} }
pub async fn delete(self, conn: &mut DbConn) -> EmptyResult { pub async fn delete(self, conn: &DbConn) -> EmptyResult {
db_run! { conn: { db_run! { conn: {
diesel::delete(invitations::table.filter(invitations::email.eq(self.email))) diesel::delete(invitations::table.filter(invitations::email.eq(self.email)))
.execute(conn) .execute(conn)
@ -468,18 +467,17 @@ impl Invitation {
}} }}
} }
pub async fn find_by_mail(mail: &str, conn: &mut DbConn) -> Option<Self> { pub async fn find_by_mail(mail: &str, conn: &DbConn) -> Option<Self> {
let lower_mail = mail.to_lowercase(); let lower_mail = mail.to_lowercase();
db_run! { conn: { db_run! { conn: {
invitations::table invitations::table
.filter(invitations::email.eq(lower_mail)) .filter(invitations::email.eq(lower_mail))
.first::<InvitationDb>(conn) .first::<Self>(conn)
.ok() .ok()
.from_db()
}} }}
} }
pub async fn take(mail: &str, conn: &mut DbConn) -> bool { pub async fn take(mail: &str, conn: &DbConn) -> bool {
match Self::find_by_mail(mail, conn).await { match Self::find_by_mail(mail, conn).await {
Some(invitation) => invitation.delete(conn).await.is_ok(), Some(invitation) => invitation.delete(conn).await.is_ok(),
None => false, None => false,
@ -508,51 +506,48 @@ impl Invitation {
pub struct UserId(String); pub struct UserId(String);
impl SsoUser { impl SsoUser {
pub async fn save(&self, conn: &mut DbConn) -> EmptyResult { pub async fn save(&self, conn: &DbConn) -> EmptyResult {
db_run! { conn: db_run! { conn:
sqlite, mysql { sqlite, mysql {
diesel::replace_into(sso_users::table) diesel::replace_into(sso_users::table)
.values(SsoUserDb::to_db(self)) .values(self)
.execute(conn) .execute(conn)
.map_res("Error saving SSO user") .map_res("Error saving SSO user")
} }
postgresql { postgresql {
let value = SsoUserDb::to_db(self);
diesel::insert_into(sso_users::table) diesel::insert_into(sso_users::table)
.values(&value) .values(self)
.execute(conn) .execute(conn)
.map_res("Error saving SSO user") .map_res("Error saving SSO user")
} }
} }
} }
pub async fn find_by_identifier(identifier: &str, conn: &DbConn) -> Option<(User, SsoUser)> { pub async fn find_by_identifier(identifier: &str, conn: &DbConn) -> Option<(User, Self)> {
db_run! { conn: { db_run! { conn: {
users::table users::table
.inner_join(sso_users::table) .inner_join(sso_users::table)
.select(<(UserDb, SsoUserDb)>::as_select()) .select(<(User, Self)>::as_select())
.filter(sso_users::identifier.eq(identifier)) .filter(sso_users::identifier.eq(identifier))
.first::<(UserDb, SsoUserDb)>(conn) .first::<(User, Self)>(conn)
.ok() .ok()
.map(|(user, sso_user)| { (user.from_db(), sso_user.from_db()) })
}} }}
} }
pub async fn find_by_mail(mail: &str, conn: &DbConn) -> Option<(User, Option<SsoUser>)> { pub async fn find_by_mail(mail: &str, conn: &DbConn) -> Option<(User, Option<Self>)> {
let lower_mail = mail.to_lowercase(); let lower_mail = mail.to_lowercase();
db_run! { conn: { db_run! { conn: {
users::table users::table
.left_join(sso_users::table) .left_join(sso_users::table)
.select(<(UserDb, Option<SsoUserDb>)>::as_select()) .select(<(User, Option<Self>)>::as_select())
.filter(users::email.eq(lower_mail)) .filter(users::email.eq(lower_mail))
.first::<(UserDb, Option<SsoUserDb>)>(conn) .first::<(User, Option<Self>)>(conn)
.ok() .ok()
.map(|(user, sso_user)| { (user.from_db(), sso_user.from_db()) })
}} }}
} }
pub async fn delete(user_uuid: &UserId, conn: &mut DbConn) -> EmptyResult { pub async fn delete(user_uuid: &UserId, conn: &DbConn) -> EmptyResult {
db_run! { conn: { db_run! { conn: {
diesel::delete(sso_users::table.filter(sso_users::user_uuid.eq(user_uuid))) diesel::delete(sso_users::table.filter(sso_users::user_uuid.eq(user_uuid)))
.execute(conn) .execute(conn)

57
src/db/query_logger.rs

@ -0,0 +1,57 @@
use dashmap::DashMap;
use diesel::connection::{Instrumentation, InstrumentationEvent};
use std::{
sync::{Arc, LazyLock},
thread,
time::Instant,
};
/// Per-thread map of in-flight queries: keyed by `(thread id, SQL debug string)`,
/// holding the `Instant` at which `StartQuery` fired so the matching `FinishQuery`
/// event can compute and log the query's duration.
/// NOTE(review): entries are only removed on `FinishQuery`; presumably Diesel emits
/// it for every `StartQuery` — confirm, otherwise this map can grow unbounded.
pub static QUERY_PERF_TRACKER: LazyLock<Arc<DashMap<(thread::ThreadId, String), Instant>>> =
    LazyLock::new(|| Arc::new(DashMap::new()));
/// Builds the `Instrumentation` callback that is installed through Diesel's
/// `set_default_instrumentation()` to log connection and query events.
///
/// Connection establishment is logged at debug level (error level on failure).
/// Query timings are tracked via `QUERY_PERF_TRACKER`, keyed per thread, and
/// logged when the query finishes:
///   - `>= 5s` -> `warn!`  ("SLOW QUERY")
///   - `>= 1s` -> `info!`  ("SLOW QUERY")
///   - otherwise -> `debug!` with the exact duration
pub fn simple_logger() -> Option<Box<dyn Instrumentation>> {
    Some(Box::new(|event: InstrumentationEvent<'_>| match event {
        InstrumentationEvent::StartEstablishConnection {
            url,
            ..
        } => {
            debug!("Establishing connection: {url}")
        }
        InstrumentationEvent::FinishEstablishConnection {
            url,
            error,
            ..
        } => {
            if let Some(e) = error {
                error!("Error during establishing a connection with {url}: {e:?}")
            } else {
                debug!("Connection established: {url}")
            }
        }
        InstrumentationEvent::StartQuery {
            query,
            ..
        } => {
            // Record the start time, keyed by thread id so identical SQL running
            // concurrently on different threads does not collide.
            QUERY_PERF_TRACKER.insert((thread::current().id(), format!("{query:?}")), Instant::now());
        }
        InstrumentationEvent::FinishQuery {
            query,
            ..
        } => {
            // Build the key once and reuse its owned String for logging below,
            // avoiding the extra clone a separate lookup key would require.
            let key = (thread::current().id(), format!("{query:?}"));
            if let Some((_, start)) = QUERY_PERF_TRACKER.remove(&key) {
                let duration = start.elapsed();
                let query_string = key.1;
                if duration.as_secs() >= 5 {
                    warn!("SLOW QUERY [{:.2}s]: {}", duration.as_secs_f32(), query_string);
                } else if duration.as_secs() >= 1 {
                    info!("SLOW QUERY [{:.2}s]: {}", duration.as_secs_f32(), query_string);
                } else {
                    debug!("QUERY [{:?}]: {}", duration, query_string);
                }
            }
            // NOTE(review): if Diesel ever emits StartQuery without a matching
            // FinishQuery, the tracker entry is never removed — confirm this
            // cannot happen in practice.
        }
        // Other instrumentation events (e.g. cache events) are intentionally ignored.
        _ => {}
    }))
}

0
src/db/schemas/postgresql/schema.rs → src/db/schema.rs

395
src/db/schemas/mysql/schema.rs

@ -1,395 +0,0 @@
table! {
attachments (id) {
id -> Text,
cipher_uuid -> Text,
file_name -> Text,
file_size -> BigInt,
akey -> Nullable<Text>,
}
}
table! {
ciphers (uuid) {
uuid -> Text,
created_at -> Datetime,
updated_at -> Datetime,
user_uuid -> Nullable<Text>,
organization_uuid -> Nullable<Text>,
key -> Nullable<Text>,
atype -> Integer,
name -> Text,
notes -> Nullable<Text>,
fields -> Nullable<Text>,
data -> Text,
password_history -> Nullable<Text>,
deleted_at -> Nullable<Datetime>,
reprompt -> Nullable<Integer>,
}
}
table! {
ciphers_collections (cipher_uuid, collection_uuid) {
cipher_uuid -> Text,
collection_uuid -> Text,
}
}
table! {
collections (uuid) {
uuid -> Text,
org_uuid -> Text,
name -> Text,
external_id -> Nullable<Text>,
}
}
table! {
devices (uuid, user_uuid) {
uuid -> Text,
created_at -> Datetime,
updated_at -> Datetime,
user_uuid -> Text,
name -> Text,
atype -> Integer,
push_uuid -> Nullable<Text>,
push_token -> Nullable<Text>,
refresh_token -> Text,
twofactor_remember -> Nullable<Text>,
}
}
table! {
event (uuid) {
uuid -> Varchar,
event_type -> Integer,
user_uuid -> Nullable<Varchar>,
org_uuid -> Nullable<Varchar>,
cipher_uuid -> Nullable<Varchar>,
collection_uuid -> Nullable<Varchar>,
group_uuid -> Nullable<Varchar>,
org_user_uuid -> Nullable<Varchar>,
act_user_uuid -> Nullable<Varchar>,
device_type -> Nullable<Integer>,
ip_address -> Nullable<Text>,
event_date -> Timestamp,
policy_uuid -> Nullable<Varchar>,
provider_uuid -> Nullable<Varchar>,
provider_user_uuid -> Nullable<Varchar>,
provider_org_uuid -> Nullable<Varchar>,
}
}
table! {
favorites (user_uuid, cipher_uuid) {
user_uuid -> Text,
cipher_uuid -> Text,
}
}
table! {
folders (uuid) {
uuid -> Text,
created_at -> Datetime,
updated_at -> Datetime,
user_uuid -> Text,
name -> Text,
}
}
table! {
folders_ciphers (cipher_uuid, folder_uuid) {
cipher_uuid -> Text,
folder_uuid -> Text,
}
}
table! {
invitations (email) {
email -> Text,
}
}
table! {
org_policies (uuid) {
uuid -> Text,
org_uuid -> Text,
atype -> Integer,
enabled -> Bool,
data -> Text,
}
}
table! {
organizations (uuid) {
uuid -> Text,
name -> Text,
billing_email -> Text,
private_key -> Nullable<Text>,
public_key -> Nullable<Text>,
}
}
table! {
sends (uuid) {
uuid -> Text,
user_uuid -> Nullable<Text>,
organization_uuid -> Nullable<Text>,
name -> Text,
notes -> Nullable<Text>,
atype -> Integer,
data -> Text,
akey -> Text,
password_hash -> Nullable<Binary>,
password_salt -> Nullable<Binary>,
password_iter -> Nullable<Integer>,
max_access_count -> Nullable<Integer>,
access_count -> Integer,
creation_date -> Datetime,
revision_date -> Datetime,
expiration_date -> Nullable<Datetime>,
deletion_date -> Datetime,
disabled -> Bool,
hide_email -> Nullable<Bool>,
}
}
table! {
twofactor (uuid) {
uuid -> Text,
user_uuid -> Text,
atype -> Integer,
enabled -> Bool,
data -> Text,
last_used -> BigInt,
}
}
table! {
twofactor_incomplete (user_uuid, device_uuid) {
user_uuid -> Text,
device_uuid -> Text,
device_name -> Text,
device_type -> Integer,
login_time -> Timestamp,
ip_address -> Text,
}
}
table! {
twofactor_duo_ctx (state) {
state -> Text,
user_email -> Text,
nonce -> Text,
exp -> BigInt,
}
}
table! {
users (uuid) {
uuid -> Text,
enabled -> Bool,
created_at -> Datetime,
updated_at -> Datetime,
verified_at -> Nullable<Datetime>,
last_verifying_at -> Nullable<Datetime>,
login_verify_count -> Integer,
email -> Text,
email_new -> Nullable<Text>,
email_new_token -> Nullable<Text>,
name -> Text,
password_hash -> Binary,
salt -> Binary,
password_iterations -> Integer,
password_hint -> Nullable<Text>,
akey -> Text,
private_key -> Nullable<Text>,
public_key -> Nullable<Text>,
totp_secret -> Nullable<Text>,
totp_recover -> Nullable<Text>,
security_stamp -> Text,
stamp_exception -> Nullable<Text>,
equivalent_domains -> Text,
excluded_globals -> Text,
client_kdf_type -> Integer,
client_kdf_iter -> Integer,
client_kdf_memory -> Nullable<Integer>,
client_kdf_parallelism -> Nullable<Integer>,
api_key -> Nullable<Text>,
avatar_color -> Nullable<Text>,
external_id -> Nullable<Text>,
}
}
table! {
users_collections (user_uuid, collection_uuid) {
user_uuid -> Text,
collection_uuid -> Text,
read_only -> Bool,
hide_passwords -> Bool,
manage -> Bool,
}
}
table! {
users_organizations (uuid) {
uuid -> Text,
user_uuid -> Text,
org_uuid -> Text,
invited_by_email -> Nullable<Text>,
access_all -> Bool,
akey -> Text,
status -> Integer,
atype -> Integer,
reset_password_key -> Nullable<Text>,
external_id -> Nullable<Text>,
}
}
table! {
organization_api_key (uuid, org_uuid) {
uuid -> Text,
org_uuid -> Text,
atype -> Integer,
api_key -> Text,
revision_date -> Timestamp,
}
}
table! {
sso_nonce (state) {
state -> Text,
nonce -> Text,
verifier -> Nullable<Text>,
redirect_uri -> Text,
created_at -> Timestamp,
}
}
table! {
sso_users (user_uuid) {
user_uuid -> Text,
identifier -> Text,
}
}
table! {
emergency_access (uuid) {
uuid -> Text,
grantor_uuid -> Text,
grantee_uuid -> Nullable<Text>,
email -> Nullable<Text>,
key_encrypted -> Nullable<Text>,
atype -> Integer,
status -> Integer,
wait_time_days -> Integer,
recovery_initiated_at -> Nullable<Timestamp>,
last_notification_at -> Nullable<Timestamp>,
updated_at -> Timestamp,
created_at -> Timestamp,
}
}
table! {
groups (uuid) {
uuid -> Text,
organizations_uuid -> Text,
name -> Text,
access_all -> Bool,
external_id -> Nullable<Text>,
creation_date -> Timestamp,
revision_date -> Timestamp,
}
}
table! {
groups_users (groups_uuid, users_organizations_uuid) {
groups_uuid -> Text,
users_organizations_uuid -> Text,
}
}
table! {
collections_groups (collections_uuid, groups_uuid) {
collections_uuid -> Text,
groups_uuid -> Text,
read_only -> Bool,
hide_passwords -> Bool,
manage -> Bool,
}
}
table! {
auth_requests (uuid) {
uuid -> Text,
user_uuid -> Text,
organization_uuid -> Nullable<Text>,
request_device_identifier -> Text,
device_type -> Integer,
request_ip -> Text,
response_device_id -> Nullable<Text>,
access_code -> Text,
public_key -> Text,
enc_key -> Nullable<Text>,
master_password_hash -> Nullable<Text>,
approved -> Nullable<Bool>,
creation_date -> Timestamp,
response_date -> Nullable<Timestamp>,
authentication_date -> Nullable<Timestamp>,
}
}
joinable!(attachments -> ciphers (cipher_uuid));
joinable!(ciphers -> organizations (organization_uuid));
joinable!(ciphers -> users (user_uuid));
joinable!(ciphers_collections -> ciphers (cipher_uuid));
joinable!(ciphers_collections -> collections (collection_uuid));
joinable!(collections -> organizations (org_uuid));
joinable!(devices -> users (user_uuid));
joinable!(folders -> users (user_uuid));
joinable!(folders_ciphers -> ciphers (cipher_uuid));
joinable!(folders_ciphers -> folders (folder_uuid));
joinable!(org_policies -> organizations (org_uuid));
joinable!(sends -> organizations (organization_uuid));
joinable!(sends -> users (user_uuid));
joinable!(twofactor -> users (user_uuid));
joinable!(users_collections -> collections (collection_uuid));
joinable!(users_collections -> users (user_uuid));
joinable!(users_organizations -> organizations (org_uuid));
joinable!(users_organizations -> users (user_uuid));
joinable!(users_organizations -> ciphers (org_uuid));
joinable!(organization_api_key -> organizations (org_uuid));
joinable!(emergency_access -> users (grantor_uuid));
joinable!(groups -> organizations (organizations_uuid));
joinable!(groups_users -> users_organizations (users_organizations_uuid));
joinable!(groups_users -> groups (groups_uuid));
joinable!(collections_groups -> collections (collections_uuid));
joinable!(collections_groups -> groups (groups_uuid));
joinable!(event -> users_organizations (uuid));
joinable!(auth_requests -> users (user_uuid));
joinable!(sso_users -> users (user_uuid));
allow_tables_to_appear_in_same_query!(
attachments,
ciphers,
ciphers_collections,
collections,
devices,
folders,
folders_ciphers,
invitations,
org_policies,
organizations,
sends,
sso_users,
twofactor,
users,
users_collections,
users_organizations,
organization_api_key,
emergency_access,
groups,
groups_users,
collections_groups,
event,
auth_requests,
);

395
src/db/schemas/sqlite/schema.rs

@@ -1,395 +0,0 @@
// Diesel `table!` definitions for the SQLite schema. Each block declares a
// table name, its primary-key column(s) in parentheses, and every column
// with its SQL-side Diesel type (`Text`, `Integer`, `Nullable<...>`, ...).
table! {
// File attachments; `cipher_uuid` references the owning cipher (see the
// matching `joinable!` below).
attachments (id) {
id -> Text,
cipher_uuid -> Text,
file_name -> Text,
file_size -> BigInt,
akey -> Nullable<Text>,
}
}
table! {
// Vault items; both `user_uuid` and `organization_uuid` are nullable, so a
// cipher may reference a user, an organization, or neither at schema level.
ciphers (uuid) {
uuid -> Text,
created_at -> Timestamp,
updated_at -> Timestamp,
user_uuid -> Nullable<Text>,
organization_uuid -> Nullable<Text>,
key -> Nullable<Text>,
atype -> Integer,
name -> Text,
notes -> Nullable<Text>,
fields -> Nullable<Text>,
data -> Text,
password_history -> Nullable<Text>,
deleted_at -> Nullable<Timestamp>,
reprompt -> Nullable<Integer>,
}
}
table! {
// Many-to-many link between ciphers and collections (composite primary key).
ciphers_collections (cipher_uuid, collection_uuid) {
cipher_uuid -> Text,
collection_uuid -> Text,
}
}
table! {
collections (uuid) {
uuid -> Text,
org_uuid -> Text,
name -> Text,
external_id -> Nullable<Text>,
}
}
table! {
// Composite primary key: the same device uuid may exist per user.
devices (uuid, user_uuid) {
uuid -> Text,
created_at -> Timestamp,
updated_at -> Timestamp,
user_uuid -> Text,
name -> Text,
atype -> Integer,
push_uuid -> Nullable<Text>,
push_token -> Nullable<Text>,
refresh_token -> Text,
twofactor_remember -> Nullable<Text>,
}
}
table! {
// Event records; all reference columns are nullable, only `event_type` and
// `event_date` are mandatory.
event (uuid) {
uuid -> Text,
event_type -> Integer,
user_uuid -> Nullable<Text>,
org_uuid -> Nullable<Text>,
cipher_uuid -> Nullable<Text>,
collection_uuid -> Nullable<Text>,
group_uuid -> Nullable<Text>,
org_user_uuid -> Nullable<Text>,
act_user_uuid -> Nullable<Text>,
device_type -> Nullable<Integer>,
ip_address -> Nullable<Text>,
event_date -> Timestamp,
policy_uuid -> Nullable<Text>,
provider_uuid -> Nullable<Text>,
provider_user_uuid -> Nullable<Text>,
provider_org_uuid -> Nullable<Text>,
}
}
table! {
// Many-to-many link between users and ciphers (composite primary key).
favorites (user_uuid, cipher_uuid) {
user_uuid -> Text,
cipher_uuid -> Text,
}
}
table! {
folders (uuid) {
uuid -> Text,
created_at -> Timestamp,
updated_at -> Timestamp,
user_uuid -> Text,
name -> Text,
}
}
table! {
// Many-to-many link between folders and ciphers (composite primary key).
folders_ciphers (cipher_uuid, folder_uuid) {
cipher_uuid -> Text,
folder_uuid -> Text,
}
}
table! {
// Single-column table keyed by the invited e-mail address.
invitations (email) {
email -> Text,
}
}
// Diesel `table!` definitions (continued): policies, organizations, sends,
// two-factor state, and the main `users` table.
table! {
org_policies (uuid) {
uuid -> Text,
org_uuid -> Text,
atype -> Integer,
enabled -> Bool,
data -> Text,
}
}
table! {
organizations (uuid) {
uuid -> Text,
name -> Text,
billing_email -> Text,
private_key -> Nullable<Text>,
public_key -> Nullable<Text>,
}
}
table! {
// Bitwarden "Send" items; password columns are raw bytes (`Binary`), and a
// send may belong to a user, an organization, or neither (both nullable).
sends (uuid) {
uuid -> Text,
user_uuid -> Nullable<Text>,
organization_uuid -> Nullable<Text>,
name -> Text,
notes -> Nullable<Text>,
atype -> Integer,
data -> Text,
akey -> Text,
password_hash -> Nullable<Binary>,
password_salt -> Nullable<Binary>,
password_iter -> Nullable<Integer>,
max_access_count -> Nullable<Integer>,
access_count -> Integer,
creation_date -> Timestamp,
revision_date -> Timestamp,
expiration_date -> Nullable<Timestamp>,
deletion_date -> Timestamp,
disabled -> Bool,
hide_email -> Nullable<Bool>,
}
}
table! {
twofactor (uuid) {
uuid -> Text,
user_uuid -> Text,
atype -> Integer,
enabled -> Bool,
data -> Text,
last_used -> BigInt,
}
}
table! {
// In-progress 2FA logins, keyed per user+device (composite primary key).
twofactor_incomplete (user_uuid, device_uuid) {
user_uuid -> Text,
device_uuid -> Text,
device_name -> Text,
device_type -> Integer,
login_time -> Timestamp,
ip_address -> Text,
}
}
table! {
// Duo 2FA context, keyed by the opaque `state` string; `exp` is an epoch-style
// integer — presumably an expiry timestamp, confirm against the Duo code.
twofactor_duo_ctx (state) {
state -> Text,
user_email -> Text,
nonce -> Text,
exp -> BigInt,
}
}
table! {
// Main account table; `password_hash` and `salt` are raw bytes (`Binary`),
// KDF tuning lives in the `client_kdf_*` columns.
users (uuid) {
uuid -> Text,
enabled -> Bool,
created_at -> Timestamp,
updated_at -> Timestamp,
verified_at -> Nullable<Timestamp>,
last_verifying_at -> Nullable<Timestamp>,
login_verify_count -> Integer,
email -> Text,
email_new -> Nullable<Text>,
email_new_token -> Nullable<Text>,
name -> Text,
password_hash -> Binary,
salt -> Binary,
password_iterations -> Integer,
password_hint -> Nullable<Text>,
akey -> Text,
private_key -> Nullable<Text>,
public_key -> Nullable<Text>,
totp_secret -> Nullable<Text>,
totp_recover -> Nullable<Text>,
security_stamp -> Text,
stamp_exception -> Nullable<Text>,
equivalent_domains -> Text,
excluded_globals -> Text,
client_kdf_type -> Integer,
client_kdf_iter -> Integer,
client_kdf_memory -> Nullable<Integer>,
client_kdf_parallelism -> Nullable<Integer>,
api_key -> Nullable<Text>,
avatar_color -> Nullable<Text>,
external_id -> Nullable<Text>,
}
}
// Diesel `table!` definitions (continued): membership/ACL link tables, SSO
// state, emergency access, groups, and auth requests.
table! {
// Per-user collection access flags (composite primary key).
users_collections (user_uuid, collection_uuid) {
user_uuid -> Text,
collection_uuid -> Text,
read_only -> Bool,
hide_passwords -> Bool,
manage -> Bool,
}
}
table! {
// Organization membership; `status` and `atype` are integer codes whose
// meanings are defined by the application models, not this schema.
users_organizations (uuid) {
uuid -> Text,
user_uuid -> Text,
org_uuid -> Text,
invited_by_email -> Nullable<Text>,
access_all -> Bool,
akey -> Text,
status -> Integer,
atype -> Integer,
reset_password_key -> Nullable<Text>,
external_id -> Nullable<Text>,
}
}
table! {
// Composite primary key: key uuid scoped to its organization.
organization_api_key (uuid, org_uuid) {
uuid -> Text,
org_uuid -> Text,
atype -> Integer,
api_key -> Text,
revision_date -> Timestamp,
}
}
table! {
// OIDC login nonce, keyed by the opaque `state` string.
sso_nonce (state) {
state -> Text,
nonce -> Text,
verifier -> Nullable<Text>,
redirect_uri -> Text,
created_at -> Timestamp,
}
}
table! {
// One SSO identifier per user (`user_uuid` is the primary key).
sso_users (user_uuid) {
user_uuid -> Text,
identifier -> Text,
}
}
table! {
// Emergency-access grants; the grantee may be recorded by uuid or, before
// acceptance, by e-mail only (both columns nullable).
emergency_access (uuid) {
uuid -> Text,
grantor_uuid -> Text,
grantee_uuid -> Nullable<Text>,
email -> Nullable<Text>,
key_encrypted -> Nullable<Text>,
atype -> Integer,
status -> Integer,
wait_time_days -> Integer,
recovery_initiated_at -> Nullable<Timestamp>,
last_notification_at -> Nullable<Timestamp>,
updated_at -> Timestamp,
created_at -> Timestamp,
}
}
table! {
// Note the column name is `organizations_uuid` (plural), unlike the
// `org_uuid` used by most other tables here.
groups (uuid) {
uuid -> Text,
organizations_uuid -> Text,
name -> Text,
access_all -> Bool,
external_id -> Nullable<Text>,
creation_date -> Timestamp,
revision_date -> Timestamp,
}
}
table! {
// Many-to-many link between groups and org memberships (composite key).
groups_users (groups_uuid, users_organizations_uuid) {
groups_uuid -> Text,
users_organizations_uuid -> Text,
}
}
table! {
// Per-group collection access flags (composite primary key), mirroring the
// flag columns of `users_collections`.
collections_groups (collections_uuid, groups_uuid) {
collections_uuid -> Text,
groups_uuid -> Text,
read_only -> Bool,
hide_passwords -> Bool,
manage -> Bool,
}
}
table! {
// Device-approval login requests; `approved` is tri-state
// (`Nullable<Bool>`: pending / approved / denied at schema level).
auth_requests (uuid) {
uuid -> Text,
user_uuid -> Text,
organization_uuid -> Nullable<Text>,
request_device_identifier -> Text,
device_type -> Integer,
request_ip -> Text,
response_device_id -> Nullable<Text>,
access_code -> Text,
public_key -> Text,
enc_key -> Nullable<Text>,
master_password_hash -> Nullable<Text>,
approved -> Nullable<Bool>,
creation_date -> Timestamp,
response_date -> Nullable<Timestamp>,
authentication_date -> Nullable<Timestamp>,
}
}
// Foreign-key declarations for the Diesel query builder: each `joinable!`
// names a child table, its parent table, and the foreign-key column, which
// lets queries use `.inner_join()` / `.left_join()` without an explicit ON
// clause.
joinable!(attachments -> ciphers (cipher_uuid));
joinable!(ciphers -> organizations (organization_uuid));
joinable!(ciphers -> users (user_uuid));
joinable!(ciphers_collections -> ciphers (cipher_uuid));
joinable!(ciphers_collections -> collections (collection_uuid));
joinable!(collections -> organizations (org_uuid));
joinable!(devices -> users (user_uuid));
joinable!(folders -> users (user_uuid));
joinable!(folders_ciphers -> ciphers (cipher_uuid));
joinable!(folders_ciphers -> folders (folder_uuid));
joinable!(org_policies -> organizations (org_uuid));
joinable!(sends -> organizations (organization_uuid));
joinable!(sends -> users (user_uuid));
joinable!(twofactor -> users (user_uuid));
joinable!(users_collections -> collections (collection_uuid));
joinable!(users_collections -> users (user_uuid));
joinable!(users_organizations -> organizations (org_uuid));
joinable!(users_organizations -> users (user_uuid));
// NOTE(review): joins `users_organizations.org_uuid` against `ciphers`, whose
// own organization column is named `organization_uuid` — confirm intended.
joinable!(users_organizations -> ciphers (org_uuid));
joinable!(organization_api_key -> organizations (org_uuid));
// Only the grantor side is declared joinable; `grantee_uuid` has no
// `joinable!` declaration here.
joinable!(emergency_access -> users (grantor_uuid));
joinable!(groups -> organizations (organizations_uuid));
joinable!(groups_users -> users_organizations (users_organizations_uuid));
joinable!(groups_users -> groups (groups_uuid));
joinable!(collections_groups -> collections (collections_uuid));
joinable!(collections_groups -> groups (groups_uuid));
// NOTE(review): `event` is joined on its own primary key `uuid` rather than
// its `org_user_uuid` column — TODO confirm this is the intended join column.
joinable!(event -> users_organizations (uuid));
joinable!(auth_requests -> users (user_uuid));
joinable!(sso_users -> users (user_uuid));
// Tables permitted to appear together in a single SQL statement; Diesel
// rejects multi-table queries whose tables are not co-listed here.
// NOTE(review): `favorites`, `twofactor_incomplete`, `twofactor_duo_ctx` and
// `sso_nonce` are absent from this list — confirm they are never joined.
allow_tables_to_appear_in_same_query!(
attachments,
ciphers,
ciphers_collections,
collections,
devices,
folders,
folders_ciphers,
invitations,
org_policies,
organizations,
sends,
sso_users,
twofactor,
users,
users_collections,
users_organizations,
organization_api_key,
emergency_access,
groups,
groups_users,
collections_groups,
event,
auth_requests,
);

50
src/main.rs

@ -71,7 +71,7 @@ pub use util::is_running_in_container;
#[rocket::main] #[rocket::main]
async fn main() -> Result<(), Error> { async fn main() -> Result<(), Error> {
parse_args().await; parse_args();
launch_info(); launch_info();
let level = init_logging()?; let level = init_logging()?;
@ -87,8 +87,8 @@ async fn main() -> Result<(), Error> {
let pool = create_db_pool().await; let pool = create_db_pool().await;
schedule_jobs(pool.clone()); schedule_jobs(pool.clone());
db::models::TwoFactor::migrate_u2f_to_webauthn(&mut pool.get().await.unwrap()).await.unwrap(); db::models::TwoFactor::migrate_u2f_to_webauthn(&pool.get().await.unwrap()).await.unwrap();
db::models::TwoFactor::migrate_credential_to_passkey(&mut pool.get().await.unwrap()).await.unwrap(); db::models::TwoFactor::migrate_credential_to_passkey(&pool.get().await.unwrap()).await.unwrap();
let extra_debug = matches!(level, log::LevelFilter::Trace | log::LevelFilter::Debug); let extra_debug = matches!(level, log::LevelFilter::Trace | log::LevelFilter::Debug);
launch_rocket(pool, extra_debug).await // Blocks until program termination. launch_rocket(pool, extra_debug).await // Blocks until program termination.
@ -117,7 +117,7 @@ PRESETS: m= t= p=
pub const VERSION: Option<&str> = option_env!("VW_VERSION"); pub const VERSION: Option<&str> = option_env!("VW_VERSION");
async fn parse_args() { fn parse_args() {
let mut pargs = pico_args::Arguments::from_env(); let mut pargs = pico_args::Arguments::from_env();
let version = VERSION.unwrap_or("(Version info from Git not present)"); let version = VERSION.unwrap_or("(Version info from Git not present)");
@ -188,7 +188,7 @@ async fn parse_args() {
exit(1); exit(1);
} }
} else if command == "backup" { } else if command == "backup" {
match backup_sqlite().await { match db::backup_sqlite() {
Ok(f) => { Ok(f) => {
println!("Backup to '{f}' was successful"); println!("Backup to '{f}' was successful");
exit(0); exit(0);
@ -203,23 +203,6 @@ async fn parse_args() {
} }
} }
async fn backup_sqlite() -> Result<String, Error> {
use crate::db::{backup_database, DbConnType};
if DbConnType::from_url(&CONFIG.database_url()).map(|t| t == DbConnType::sqlite).unwrap_or(false) {
// Establish a connection to the sqlite database
let mut conn = db::DbPool::from_config()
.expect("SQLite database connection failed")
.get()
.await
.expect("Unable to get SQLite db pool");
let backup_file = backup_database(&mut conn).await?;
Ok(backup_file)
} else {
err_silent!("The database type is not SQLite. Backups only works for SQLite databases")
}
}
fn launch_info() { fn launch_info() {
println!( println!(
"\ "\
@ -285,13 +268,6 @@ fn init_logging() -> Result<log::LevelFilter, Error> {
log::LevelFilter::Off log::LevelFilter::Off
}; };
let diesel_logger_level: log::LevelFilter =
if cfg!(feature = "query_logger") && std::env::var("QUERY_LOGGER").is_ok() {
log::LevelFilter::Debug
} else {
log::LevelFilter::Off
};
// Only show Rocket underscore `_` logs when the level is Debug or higher // Only show Rocket underscore `_` logs when the level is Debug or higher
// Else this will bloat the log output with useless messages. // Else this will bloat the log output with useless messages.
let rocket_underscore_level = if level >= log::LevelFilter::Debug { let rocket_underscore_level = if level >= log::LevelFilter::Debug {
@ -342,9 +318,15 @@ fn init_logging() -> Result<log::LevelFilter, Error> {
// Variable level for hickory used by reqwest // Variable level for hickory used by reqwest
("hickory_resolver::name_server::name_server", hickory_level), ("hickory_resolver::name_server::name_server", hickory_level),
("hickory_proto::xfer", hickory_level), ("hickory_proto::xfer", hickory_level),
("diesel_logger", diesel_logger_level),
// SMTP // SMTP
("lettre::transport::smtp", smtp_log_level), ("lettre::transport::smtp", smtp_log_level),
// Set query_logger default to Off, but can be overwritten manually
// You can set LOG_LEVEL=info,vaultwarden::db::query_logger=<LEVEL> to overwrite it.
// This makes it possible to do the following:
// warn = Print slow queries only, 5 seconds or longer
// info = Print slow queries only, 1 second or longer
// debug = Print all queries
("vaultwarden::db::query_logger", log::LevelFilter::Off),
]); ]);
for (path, level) in levels_override.into_iter() { for (path, level) in levels_override.into_iter() {
@ -614,21 +596,25 @@ async fn launch_rocket(pool: db::DbPool, extra_debug: bool) -> Result<(), Error>
CONFIG.shutdown(); CONFIG.shutdown();
}); });
#[cfg(unix)] #[cfg(all(unix, sqlite))]
{ {
if db::ACTIVE_DB_TYPE.get() != Some(&db::DbConnType::Sqlite) {
debug!("PostgreSQL and MySQL/MariaDB do not support this backup feature, skip adding USR1 signal.");
} else {
tokio::spawn(async move { tokio::spawn(async move {
let mut signal_user1 = tokio::signal::unix::signal(SignalKind::user_defined1()).unwrap(); let mut signal_user1 = tokio::signal::unix::signal(SignalKind::user_defined1()).unwrap();
loop { loop {
// If we need more signals to act upon, we might want to use select! here. // If we need more signals to act upon, we might want to use select! here.
// With only one item to listen for this is enough. // With only one item to listen for this is enough.
let _ = signal_user1.recv().await; let _ = signal_user1.recv().await;
match backup_sqlite().await { match db::backup_sqlite() {
Ok(f) => info!("Backup to '{f}' was successful"), Ok(f) => info!("Backup to '{f}' was successful"),
Err(e) => error!("Backup failed. {e:?}"), Err(e) => error!("Backup failed. {e:?}"),
} }
} }
}); });
} }
}
instance.launch().await?; instance.launch().await?;

15
src/sso.rs

@ -165,12 +165,7 @@ pub fn decode_state(base64_state: String) -> ApiResult<OIDCState> {
// The `nonce` allow to protect against replay attacks // The `nonce` allow to protect against replay attacks
// redirect_uri from: https://github.com/bitwarden/server/blob/main/src/Identity/IdentityServer/ApiClient.cs // redirect_uri from: https://github.com/bitwarden/server/blob/main/src/Identity/IdentityServer/ApiClient.cs
pub async fn authorize_url( pub async fn authorize_url(state: OIDCState, client_id: &str, raw_redirect_uri: &str, conn: DbConn) -> ApiResult<Url> {
state: OIDCState,
client_id: &str,
raw_redirect_uri: &str,
mut conn: DbConn,
) -> ApiResult<Url> {
let redirect_uri = match client_id { let redirect_uri = match client_id {
"web" | "browser" => format!("{}/sso-connector.html", CONFIG.domain()), "web" | "browser" => format!("{}/sso-connector.html", CONFIG.domain()),
"desktop" | "mobile" => "bitwarden://sso-callback".to_string(), "desktop" | "mobile" => "bitwarden://sso-callback".to_string(),
@ -185,7 +180,7 @@ pub async fn authorize_url(
}; };
let (auth_url, nonce) = Client::authorize_url(state, redirect_uri).await?; let (auth_url, nonce) = Client::authorize_url(state, redirect_uri).await?;
nonce.save(&mut conn).await?; nonce.save(&conn).await?;
Ok(auth_url) Ok(auth_url)
} }
@ -235,7 +230,7 @@ pub struct UserInformation {
pub user_name: Option<String>, pub user_name: Option<String>,
} }
async fn decode_code_claims(code: &str, conn: &mut DbConn) -> ApiResult<(OIDCCode, OIDCState)> { async fn decode_code_claims(code: &str, conn: &DbConn) -> ApiResult<(OIDCCode, OIDCState)> {
match auth::decode_jwt::<OIDCCodeClaims>(code, SSO_JWT_ISSUER.to_string()) { match auth::decode_jwt::<OIDCCodeClaims>(code, SSO_JWT_ISSUER.to_string()) {
Ok(code_claims) => match code_claims.code { Ok(code_claims) => match code_claims.code {
OIDCCodeWrapper::Ok { OIDCCodeWrapper::Ok {
@ -265,7 +260,7 @@ async fn decode_code_claims(code: &str, conn: &mut DbConn) -> ApiResult<(OIDCCod
// - second time we will rely on the `AC_CACHE` since the `code` has already been exchanged. // - second time we will rely on the `AC_CACHE` since the `code` has already been exchanged.
// The `nonce` will ensure that the user is authorized only once. // The `nonce` will ensure that the user is authorized only once.
// We return only the `UserInformation` to force calling `redeem` to obtain the `refresh_token`. // We return only the `UserInformation` to force calling `redeem` to obtain the `refresh_token`.
pub async fn exchange_code(wrapped_code: &str, conn: &mut DbConn) -> ApiResult<UserInformation> { pub async fn exchange_code(wrapped_code: &str, conn: &DbConn) -> ApiResult<UserInformation> {
use openidconnect::OAuth2TokenResponse; use openidconnect::OAuth2TokenResponse;
let (code, state) = decode_code_claims(wrapped_code, conn).await?; let (code, state) = decode_code_claims(wrapped_code, conn).await?;
@ -330,7 +325,7 @@ pub async fn exchange_code(wrapped_code: &str, conn: &mut DbConn) -> ApiResult<U
} }
// User has passed 2FA flow we can delete `nonce` and clear the cache. // User has passed 2FA flow we can delete `nonce` and clear the cache.
pub async fn redeem(state: &OIDCState, conn: &mut DbConn) -> ApiResult<AuthenticatedUser> { pub async fn redeem(state: &OIDCState, conn: &DbConn) -> ApiResult<AuthenticatedUser> {
if let Err(err) = SsoNonce::delete(state, conn).await { if let Err(err) = SsoNonce::delete(state, conn).await {
error!("Failed to delete database sso_nonce using {state}: {err}") error!("Failed to delete database sso_nonce using {state}: {err}")
} }

Loading…
Cancel
Save