Browse Source

Merge branch 'main' into permit-subpath-admin-page-fix

pull/2713/head
GeekCorner 3 years ago
committed by GitHub
parent
commit
e0354077ba
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
  1. 17
      .dockerignore
  2. 27
      .env.template
  3. 186
      .github/workflows/build.yml
  4. 6
      .github/workflows/hadolint.yml
  5. 6
      .github/workflows/release.yml
  6. 12
      .pre-commit-config.yaml
  7. 2787
      Cargo.lock
  8. 162
      Cargo.toml
  9. 2
      Rocket.toml
  10. 1
      docker/Dockerfile.buildx
  11. 79
      docker/Dockerfile.j2
  12. 26
      docker/amd64/Dockerfile
  13. 28
      docker/amd64/Dockerfile.alpine
  14. 26
      docker/amd64/Dockerfile.buildx
  15. 28
      docker/amd64/Dockerfile.buildx.alpine
  16. 26
      docker/arm64/Dockerfile
  17. 28
      docker/arm64/Dockerfile.alpine
  18. 26
      docker/arm64/Dockerfile.buildx
  19. 28
      docker/arm64/Dockerfile.buildx.alpine
  20. 31
      docker/armv6/Dockerfile
  21. 30
      docker/armv6/Dockerfile.alpine
  22. 31
      docker/armv6/Dockerfile.buildx
  23. 30
      docker/armv6/Dockerfile.buildx.alpine
  24. 26
      docker/armv7/Dockerfile
  25. 31
      docker/armv7/Dockerfile.alpine
  26. 26
      docker/armv7/Dockerfile.buildx
  27. 31
      docker/armv7/Dockerfile.buildx.alpine
  28. 4
      docker/healthcheck.sh
  29. 8
      docker/start.sh
  30. 0
      migrations/mysql/2022-03-02-210038_update_devices_primary_key/down.sql
  31. 4
      migrations/mysql/2022-03-02-210038_update_devices_primary_key/up.sql
  32. 0
      migrations/postgresql/2022-03-02-210038_update_devices_primary_key/down.sql
  33. 4
      migrations/postgresql/2022-03-02-210038_update_devices_primary_key/up.sql
  34. 0
      migrations/sqlite/2022-03-02-210038_update_devices_primary_key/down.sql
  35. 23
      migrations/sqlite/2022-03-02-210038_update_devices_primary_key/up.sql
  36. 2
      rust-toolchain
  37. 8
      rustfmt.toml
  38. 300
      src/api/admin.rs
  39. 212
      src/api/core/accounts.rs
  40. 845
      src/api/core/ciphers.rs
  41. 257
      src/api/core/emergency_access.rs
  42. 53
      src/api/core/folders.rs
  43. 46
      src/api/core/mod.rs
  44. 458
      src/api/core/organizations.rs
  45. 189
      src/api/core/sends.rs
  46. 45
      src/api/core/two_factor/authenticator.rs
  47. 46
      src/api/core/two_factor/duo.rs
  48. 78
      src/api/core/two_factor/email.rs
  49. 59
      src/api/core/two_factor/mod.rs
  50. 352
      src/api/core/two_factor/u2f.rs
  51. 91
      src/api/core/two_factor/webauthn.rs
  52. 27
      src/api/core/two_factor/yubikey.rs
  53. 627
      src/api/icons.rs
  54. 209
      src/api/identity.rs
  55. 2
      src/api/mod.rs
  56. 391
      src/api/notifications.rs
  57. 55
      src/api/web.rs
  58. 297
      src/auth.rs
  59. 126
      src/config.rs
  60. 28
      src/crypto.rs
  61. 220
      src/db/mod.rs
  62. 37
      src/db/models/attachment.rs
  63. 280
      src/db/models/cipher.rs
  64. 143
      src/db/models/collection.rs
  65. 50
      src/db/models/device.rs
  66. 45
      src/db/models/emergency_access.rs
  67. 32
      src/db/models/favorite.rs
  68. 54
      src/db/models/folder.rs
  69. 41
      src/db/models/org_policy.rs
  70. 96
      src/db/models/organization.rs
  71. 50
      src/db/models/send.rs
  72. 26
      src/db/models/two_factor.rs
  73. 25
      src/db/models/two_factor_incomplete.rs
  74. 80
      src/db/models/user.rs
  75. 2
      src/db/schemas/mysql/schema.rs
  76. 2
      src/db/schemas/postgresql/schema.rs
  77. 2
      src/db/schemas/sqlite/schema.rs
  78. 34
      src/error.rs
  79. 136
      src/mail.rs
  80. 206
      src/main.rs
  81. 17
      src/static/global_domains.json
  82. 6042
      src/static/scripts/bootstrap-native.js
  83. 3674
      src/static/scripts/bootstrap.css
  84. 276
      src/static/scripts/datatables.css
  85. 460
      src/static/scripts/datatables.js
  86. 7
      src/static/templates/admin/base.hbs
  87. 4
      src/static/templates/admin/login.hbs
  88. 63
      src/static/templates/admin/settings.hbs
  89. 172
      src/util.rs

17
.dockerignore

@ -3,13 +3,18 @@ target
# Data folder # Data folder
data data
# Misc
.env .env
.env.template .env.template
.gitattributes .gitattributes
.gitignore
rustfmt.toml
# IDE files # IDE files
.vscode .vscode
.idea .idea
.editorconfig
*.iml *.iml
# Documentation # Documentation
@ -19,9 +24,17 @@ data
*.yml *.yml
*.yaml *.yaml
# Docker folders # Docker
hooks hooks
tools tools
Dockerfile
.dockerignore
docker/**
!docker/healthcheck.sh
!docker/start.sh
# Web vault # Web vault
web-vault web-vault
# Vaultwarden Resources
resources

27
.env.template

@ -3,6 +3,11 @@
## ##
## Be aware that most of these settings will be overridden if they were changed ## Be aware that most of these settings will be overridden if they were changed
## in the admin interface. Those overrides are stored within DATA_FOLDER/config.json . ## in the admin interface. Those overrides are stored within DATA_FOLDER/config.json .
##
## By default, vaultwarden expects for this file to be named ".env" and located
## in the current working directory. If this is not the case, the environment
## variable ENV_FILE can be set to the location of this file prior to starting
## vaultwarden.
## Main data folder ## Main data folder
# DATA_FOLDER=data # DATA_FOLDER=data
@ -24,11 +29,21 @@
## Define the size of the connection pool used for connecting to the database. ## Define the size of the connection pool used for connecting to the database.
# DATABASE_MAX_CONNS=10 # DATABASE_MAX_CONNS=10
## Database connection initialization
## Allows SQL statements to be run whenever a new database connection is created.
## This is mainly useful for connection-scoped pragmas.
## If empty, a database-specific default is used:
## - SQLite: "PRAGMA busy_timeout = 5000; PRAGMA synchronous = NORMAL;"
## - MySQL: ""
## - PostgreSQL: ""
# DATABASE_CONN_INIT=""
## Individual folders, these override %DATA_FOLDER% ## Individual folders, these override %DATA_FOLDER%
# RSA_KEY_FILENAME=data/rsa_key # RSA_KEY_FILENAME=data/rsa_key
# ICON_CACHE_FOLDER=data/icon_cache # ICON_CACHE_FOLDER=data/icon_cache
# ATTACHMENTS_FOLDER=data/attachments # ATTACHMENTS_FOLDER=data/attachments
# SENDS_FOLDER=data/sends # SENDS_FOLDER=data/sends
# TMP_FOLDER=data/tmp
## Templates data folder, by default uses embedded templates ## Templates data folder, by default uses embedded templates
## Check source code to see the format ## Check source code to see the format
@ -102,12 +117,10 @@
# LOG_TIMESTAMP_FORMAT="%Y-%m-%d %H:%M:%S.%3f" # LOG_TIMESTAMP_FORMAT="%Y-%m-%d %H:%M:%S.%3f"
## Logging to file ## Logging to file
## It's recommended to also set 'ROCKET_CLI_COLORS=off'
# LOG_FILE=/path/to/log # LOG_FILE=/path/to/log
## Logging to Syslog ## Logging to Syslog
## This requires extended logging ## This requires extended logging
## It's recommended to also set 'ROCKET_CLI_COLORS=off'
# USE_SYSLOG=false # USE_SYSLOG=false
## Log level ## Log level
@ -185,7 +198,7 @@
# EMAIL_EXPIRATION_TIME=600 # EMAIL_EXPIRATION_TIME=600
## Email token size ## Email token size
## Number of digits in an email token (min: 6, max: 19). ## Number of digits in an email 2FA token (min: 6, max: 255).
## Note that the Bitwarden clients are hardcoded to mention 6 digit codes regardless of this setting! ## Note that the Bitwarden clients are hardcoded to mention 6 digit codes regardless of this setting!
# EMAIL_TOKEN_SIZE=6 # EMAIL_TOKEN_SIZE=6
@ -257,6 +270,9 @@
## The change only applies when the password is changed ## The change only applies when the password is changed
# PASSWORD_ITERATIONS=100000 # PASSWORD_ITERATIONS=100000
## Controls whether users can set password hints. This setting applies globally to all users.
# PASSWORD_HINTS_ALLOWED=true
## Controls whether a password hint should be shown directly in the web page if ## Controls whether a password hint should be shown directly in the web page if
## SMTP service is not configured. Not recommended for publicly-accessible instances ## SMTP service is not configured. Not recommended for publicly-accessible instances
## as this provides unauthenticated access to potentially sensitive data. ## as this provides unauthenticated access to potentially sensitive data.
@ -267,7 +283,7 @@
## It's recommended to configure this value, otherwise certain functionality might not work, ## It's recommended to configure this value, otherwise certain functionality might not work,
## like attachment downloads, email links and U2F. ## like attachment downloads, email links and U2F.
## For U2F to work, the server must use HTTPS, you can use Let's Encrypt for free certs ## For U2F to work, the server must use HTTPS, you can use Let's Encrypt for free certs
# DOMAIN=https://bw.domain.tld:8443 # DOMAIN=https://vw.domain.tld:8443
## Allowed iframe ancestors (Know the risks!) ## Allowed iframe ancestors (Know the risks!)
## https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Security-Policy/frame-ancestors ## https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Security-Policy/frame-ancestors
@ -331,9 +347,8 @@
# SMTP_HOST=smtp.domain.tld # SMTP_HOST=smtp.domain.tld
# SMTP_FROM=vaultwarden@domain.tld # SMTP_FROM=vaultwarden@domain.tld
# SMTP_FROM_NAME=Vaultwarden # SMTP_FROM_NAME=Vaultwarden
# SMTP_SECURITY=starttls # ("starttls", "force_tls", "off") Enable a secure connection. Default is "starttls" (Explicit - ports 587 or 25), "force_tls" (Implicit - port 465) or "off", no encryption (port 25)
# SMTP_PORT=587 # Ports 587 (submission) and 25 (smtp) are standard without encryption and with encryption via STARTTLS (Explicit TLS). Port 465 is outdated and used with Implicit TLS. # SMTP_PORT=587 # Ports 587 (submission) and 25 (smtp) are standard without encryption and with encryption via STARTTLS (Explicit TLS). Port 465 is outdated and used with Implicit TLS.
# SMTP_SSL=true # (Explicit) - This variable by default configures Explicit STARTTLS, it will upgrade an insecure connection to a secure one. Unless SMTP_EXPLICIT_TLS is set to true. Either port 587 or 25 are default.
# SMTP_EXPLICIT_TLS=true # (Implicit) - N.B. This variable configures Implicit TLS. It's currently mislabelled (see bug #851) - SMTP_SSL Needs to be set to true for this option to work. Usually port 465 is used here.
# SMTP_USERNAME=username # SMTP_USERNAME=username
# SMTP_PASSWORD=password # SMTP_PASSWORD=password
# SMTP_TIMEOUT=15 # SMTP_TIMEOUT=15

186
.github/workflows/build.yml

@ -8,7 +8,6 @@ on:
- "migrations/**" - "migrations/**"
- "Cargo.*" - "Cargo.*"
- "build.rs" - "build.rs"
- "diesel.toml"
- "rust-toolchain" - "rust-toolchain"
pull_request: pull_request:
paths: paths:
@ -17,11 +16,11 @@ on:
- "migrations/**" - "migrations/**"
- "Cargo.*" - "Cargo.*"
- "build.rs" - "build.rs"
- "diesel.toml"
- "rust-toolchain" - "rust-toolchain"
jobs: jobs:
build: build:
runs-on: ubuntu-20.04
# Make warnings errors, this is to prevent warnings slipping through. # Make warnings errors, this is to prevent warnings slipping through.
# This is done globally to prevent rebuilds when the RUSTFLAGS env variable changes. # This is done globally to prevent rebuilds when the RUSTFLAGS env variable changes.
env: env:
@ -30,118 +29,169 @@ jobs:
fail-fast: false fail-fast: false
matrix: matrix:
channel: channel:
- nightly - "rust-toolchain" # The version defined in rust-toolchain
target-triple: - "1.60.0" # The supported MSRV
- x86_64-unknown-linux-gnu
include: name: Build and Test ${{ matrix.channel }}
- target-triple: x86_64-unknown-linux-gnu
host-triple: x86_64-unknown-linux-gnu
features: [sqlite,mysql,postgresql] # Remember to update the `cargo test` to match the amount of features
channel: nightly
os: ubuntu-20.04
ext: ""
name: Building ${{ matrix.channel }}-${{ matrix.target-triple }}
runs-on: ${{ matrix.os }}
steps: steps:
# Checkout the repo # Checkout the repo
- name: Checkout - name: "Checkout"
uses: actions/checkout@5a4ac9002d0be2fb38bd78e4b4dbde5606d7042f # v2.3.4 uses: actions/checkout@2541b1294d2704b0964813337f33b291d3f8596b # v3.0.2
# End Checkout the repo # End Checkout the repo
# Install musl-tools when needed
- name: Install musl tools
run: sudo apt-get update && sudo apt-get install -y --no-install-recommends musl-dev musl-tools cmake
if: matrix.target-triple == 'x86_64-unknown-linux-musl'
# End Install musl-tools when needed
# Install dependencies # Install dependencies
- name: Install dependencies Ubuntu - name: "Install dependencies Ubuntu"
run: sudo apt-get update && sudo apt-get install -y --no-install-recommends openssl sqlite build-essential libmariadb-dev-compat libpq-dev libssl-dev pkgconf run: sudo apt-get update && sudo apt-get install -y --no-install-recommends openssl sqlite build-essential libmariadb-dev-compat libpq-dev libssl-dev pkg-config
if: startsWith( matrix.os, 'ubuntu' )
# End Install dependencies # End Install dependencies
# Enable Rust Caching
- uses: Swatinem/rust-cache@842ef286fff290e445b90b4002cc9807c3669641 # v1.3.0
# End Enable Rust Caching
# Uses the rust-toolchain file to determine version # Uses the rust-toolchain file to determine version
- name: 'Install ${{ matrix.channel }}-${{ matrix.host-triple }} for target: ${{ matrix.target-triple }}' - name: "Install rust-toolchain version"
uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f # v1.0.6 uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f # v1.0.6
if: ${{ matrix.channel == 'rust-toolchain' }}
with: with:
profile: minimal profile: minimal
target: ${{ matrix.target-triple }}
components: clippy, rustfmt components: clippy, rustfmt
# End Uses the rust-toolchain file to determine version # End Uses the rust-toolchain file to determine version
# Install the MSRV channel to be used
- name: "Install MSRV version"
uses: actions-rs/toolchain@b2417cde72dcf67f306c0ae8e0828a81bf0b189f # v1.0.6
if: ${{ matrix.channel != 'rust-toolchain' }}
with:
profile: minimal
override: true
toolchain: ${{ matrix.channel }}
# End Install the MSRV channel to be used
# Enable Rust Caching
- uses: Swatinem/rust-cache@6720f05bc48b77f96918929a9019fb2203ff71f8 # v2.0.0
# End Enable Rust Caching
# Show environment
- name: "Show environment"
run: |
rustc -vV
cargo -vV
# End Show environment
# Run cargo tests (In release mode to speed up future builds) # Run cargo tests (In release mode to speed up future builds)
# First test all features together, afterwards test them separately. # First test all features together, afterwards test them separately.
- name: "`cargo test --release --features ${{ join(matrix.features, ',') }} --target ${{ matrix.target-triple }}`" - name: "test features: sqlite,mysql,postgresql,enable_mimalloc"
uses: actions-rs/cargo@ae10961054e4aa8b4aa7dffede299aaf087aa33b # v1.0.1 id: test_sqlite_mysql_postgresql_mimalloc
uses: actions-rs/cargo@844f36862e911db73fe0815f00a4a2602c279505 # v1.0.3
if: $${{ always() }}
with: with:
command: test command: test
args: --release --features ${{ join(matrix.features, ',') }} --target ${{ matrix.target-triple }} args: --release --features sqlite,mysql,postgresql,enable_mimalloc
# Test single features
# 0: sqlite - name: "test features: sqlite,mysql,postgresql"
- name: "`cargo test --release --features ${{ matrix.features[0] }} --target ${{ matrix.target-triple }}`" id: test_sqlite_mysql_postgresql
uses: actions-rs/cargo@ae10961054e4aa8b4aa7dffede299aaf087aa33b # v1.0.1 uses: actions-rs/cargo@844f36862e911db73fe0815f00a4a2602c279505 # v1.0.3
if: $${{ always() }}
with:
command: test
args: --release --features sqlite,mysql,postgresql
- name: "test features: sqlite"
id: test_sqlite
uses: actions-rs/cargo@844f36862e911db73fe0815f00a4a2602c279505 # v1.0.3
if: $${{ always() }}
with: with:
command: test command: test
args: --release --features ${{ matrix.features[0] }} --target ${{ matrix.target-triple }} args: --release --features sqlite
if: ${{ matrix.features[0] != '' }}
# 1: mysql - name: "test features: mysql"
- name: "`cargo test --release --features ${{ matrix.features[1] }} --target ${{ matrix.target-triple }}`" id: test_mysql
uses: actions-rs/cargo@ae10961054e4aa8b4aa7dffede299aaf087aa33b # v1.0.1 uses: actions-rs/cargo@844f36862e911db73fe0815f00a4a2602c279505 # v1.0.3
if: $${{ always() }}
with: with:
command: test command: test
args: --release --features ${{ matrix.features[1] }} --target ${{ matrix.target-triple }} args: --release --features mysql
if: ${{ matrix.features[1] != '' }}
# 2: postgresql - name: "test features: postgresql"
- name: "`cargo test --release --features ${{ matrix.features[2] }} --target ${{ matrix.target-triple }}`" id: test_postgresql
uses: actions-rs/cargo@ae10961054e4aa8b4aa7dffede299aaf087aa33b # v1.0.1 uses: actions-rs/cargo@844f36862e911db73fe0815f00a4a2602c279505 # v1.0.3
if: $${{ always() }}
with: with:
command: test command: test
args: --release --features ${{ matrix.features[2] }} --target ${{ matrix.target-triple }} args: --release --features postgresql
if: ${{ matrix.features[2] != '' }}
# End Run cargo tests # End Run cargo tests
# Run cargo clippy, and fail on warnings (In release mode to speed up future builds) # Run cargo clippy, and fail on warnings (In release mode to speed up future builds)
- name: "`cargo clippy --release --features ${{ join(matrix.features, ',') }} --target ${{ matrix.target-triple }}`" - name: "clippy features: sqlite,mysql,postgresql,enable_mimalloc"
uses: actions-rs/cargo@ae10961054e4aa8b4aa7dffede299aaf087aa33b # v1.0.1 id: clippy
uses: actions-rs/cargo@844f36862e911db73fe0815f00a4a2602c279505 # v1.0.3
if: ${{ always() && matrix.channel == 'rust-toolchain' }}
with: with:
command: clippy command: clippy
args: --release --features ${{ join(matrix.features, ',') }} --target ${{ matrix.target-triple }} -- -D warnings args: --release --features sqlite,mysql,postgresql,enable_mimalloc -- -D warnings
# End Run cargo clippy # End Run cargo clippy
# Run cargo fmt # Run cargo fmt (Only run on rust-toolchain defined version)
- name: '`cargo fmt`' - name: "check formatting"
uses: actions-rs/cargo@ae10961054e4aa8b4aa7dffede299aaf087aa33b # v1.0.1 id: formatting
uses: actions-rs/cargo@844f36862e911db73fe0815f00a4a2602c279505 # v1.0.3
if: ${{ always() && matrix.channel == 'rust-toolchain' }}
with: with:
command: fmt command: fmt
args: --all -- --check args: --all -- --check
# End Run cargo fmt # End Run cargo fmt
# Build the binary # Check for any previous failures, if there are stop, else continue.
- name: "`cargo build --release --features ${{ join(matrix.features, ',') }} --target ${{ matrix.target-triple }}`" # This is useful so all test/clippy/fmt actions are done, and they can all be addressed
uses: actions-rs/cargo@ae10961054e4aa8b4aa7dffede299aaf087aa33b # v1.0.1 - name: "Some checks failed"
if: ${{ failure() }}
run: |
echo "### :x: Checks Failed!" >> $GITHUB_STEP_SUMMARY
echo "" >> $GITHUB_STEP_SUMMARY
echo "|Job|Status|" >> $GITHUB_STEP_SUMMARY
echo "|---|------|" >> $GITHUB_STEP_SUMMARY
echo "|test (sqlite,mysql,postgresql,enable_mimalloc)|${{ steps.test_sqlite_mysql_postgresql_mimalloc.outcome }}|" >> $GITHUB_STEP_SUMMARY
echo "|test (sqlite,mysql,postgresql)|${{ steps.test_sqlite_mysql_postgresql.outcome }}|" >> $GITHUB_STEP_SUMMARY
echo "|test (sqlite)|${{ steps.test_sqlite.outcome }}|" >> $GITHUB_STEP_SUMMARY
echo "|test (mysql)|${{ steps.test_mysql.outcome }}|" >> $GITHUB_STEP_SUMMARY
echo "|test (postgresql)|${{ steps.test_postgresql.outcome }}|" >> $GITHUB_STEP_SUMMARY
echo "|clippy (sqlite,mysql,postgresql,enable_mimalloc)|${{ steps.clippy.outcome }}|" >> $GITHUB_STEP_SUMMARY
echo "|fmt|${{ steps.formatting.outcome }}|" >> $GITHUB_STEP_SUMMARY
echo "" >> $GITHUB_STEP_SUMMARY
echo "Please check the failed jobs and fix where needed." >> $GITHUB_STEP_SUMMARY
echo "" >> $GITHUB_STEP_SUMMARY
exit 1
# Check for any previous failures, if there are stop, else continue.
# This is useful so all test/clippy/fmt actions are done, and they can all be addressed
- name: "All checks passed"
if: ${{ success() }}
run: |
echo "### :tada: Checks Passed!" >> $GITHUB_STEP_SUMMARY
echo "" >> $GITHUB_STEP_SUMMARY
# Build the binary to upload to the artifacts
- name: "build features: sqlite,mysql,postgresql"
uses: actions-rs/cargo@844f36862e911db73fe0815f00a4a2602c279505 # v1.0.3
if: ${{ matrix.channel == 'rust-toolchain' }}
with: with:
command: build command: build
args: --release --features ${{ join(matrix.features, ',') }} --target ${{ matrix.target-triple }} args: --release --features sqlite,mysql,postgresql
# End Build the binary # End Build the binary
# Upload artifact to Github Actions # Upload artifact to Github Actions
- name: Upload artifact - name: "Upload artifact"
uses: actions/upload-artifact@27121b0bdffd731efa15d66772be8dc71245d074 # v2.2.4 uses: actions/upload-artifact@3cea5372237819ed00197afe530f5a7ea3e805c8 # v3.1.0
if: ${{ matrix.channel == 'rust-toolchain' }}
with: with:
name: vaultwarden-${{ matrix.target-triple }}${{ matrix.ext }} name: vaultwarden
path: target/${{ matrix.target-triple }}/release/vaultwarden${{ matrix.ext }} path: target/${{ matrix.target-triple }}/release/vaultwarden
# End Upload artifact to Github Actions # End Upload artifact to Github Actions

6
.github/workflows/hadolint.yml

@ -16,18 +16,18 @@ jobs:
steps: steps:
# Checkout the repo # Checkout the repo
- name: Checkout - name: Checkout
uses: actions/checkout@5a4ac9002d0be2fb38bd78e4b4dbde5606d7042f # v2.3.4 uses: actions/checkout@2541b1294d2704b0964813337f33b291d3f8596b # v3.0.2
# End Checkout the repo # End Checkout the repo
# Download hadolint # Download hadolint - https://github.com/hadolint/hadolint/releases
- name: Download hadolint - name: Download hadolint
shell: bash shell: bash
run: | run: |
sudo curl -L https://github.com/hadolint/hadolint/releases/download/v${HADOLINT_VERSION}/hadolint-$(uname -s)-$(uname -m) -o /usr/local/bin/hadolint && \ sudo curl -L https://github.com/hadolint/hadolint/releases/download/v${HADOLINT_VERSION}/hadolint-$(uname -s)-$(uname -m) -o /usr/local/bin/hadolint && \
sudo chmod +x /usr/local/bin/hadolint sudo chmod +x /usr/local/bin/hadolint
env: env:
HADOLINT_VERSION: 2.7.0 HADOLINT_VERSION: 2.10.0
# End Download hadolint # End Download hadolint
# Test Dockerfiles # Test Dockerfiles

6
.github/workflows/release.yml

@ -31,7 +31,7 @@ jobs:
steps: steps:
- name: Skip Duplicates Actions - name: Skip Duplicates Actions
id: skip_check id: skip_check
uses: fkirc/skip-duplicate-actions@f75dd6564bb646f95277dc8c3b80612e46a4a1ea # v3.4.1 uses: fkirc/skip-duplicate-actions@9d116fa7e55f295019cfab7e3ab72b478bcf7fdd # v4.0.0
with: with:
cancel_others: 'true' cancel_others: 'true'
# Only run this when not creating a tag # Only run this when not creating a tag
@ -60,13 +60,13 @@ jobs:
steps: steps:
# Checkout the repo # Checkout the repo
- name: Checkout - name: Checkout
uses: actions/checkout@5a4ac9002d0be2fb38bd78e4b4dbde5606d7042f # v2.3.4 uses: actions/checkout@2541b1294d2704b0964813337f33b291d3f8596b # v3.0.2
with: with:
fetch-depth: 0 fetch-depth: 0
# Login to Docker Hub # Login to Docker Hub
- name: Login to Docker Hub - name: Login to Docker Hub
uses: docker/login-action@f054a8b539a109f9f41c372932f1ae047eff08c9 # v1.10.0 uses: docker/login-action@49ed152c8eca782a232dede0303416e8f356c37b # v2.0.0
with: with:
username: ${{ secrets.DOCKERHUB_USERNAME }} username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }} password: ${{ secrets.DOCKERHUB_TOKEN }}

12
.pre-commit-config.yaml

@ -1,7 +1,7 @@
--- ---
repos: repos:
- repo: https://github.com/pre-commit/pre-commit-hooks - repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.0.1 rev: v4.3.0
hooks: hooks:
- id: check-yaml - id: check-yaml
- id: check-json - id: check-json
@ -25,14 +25,16 @@ repos:
description: Test the package for errors. description: Test the package for errors.
entry: cargo test entry: cargo test
language: system language: system
args: ["--features", "sqlite,mysql,postgresql", "--"] args: ["--features", "sqlite,mysql,postgresql,enable_mimalloc", "--"]
types: [rust] types_or: [rust, file]
files: (Cargo.toml|Cargo.lock|.*\.rs$)
pass_filenames: false pass_filenames: false
- id: cargo-clippy - id: cargo-clippy
name: cargo clippy name: cargo clippy
description: Lint Rust sources description: Lint Rust sources
entry: cargo clippy entry: cargo clippy
language: system language: system
args: ["--features", "sqlite,mysql,postgresql", "--", "-D", "warnings"] args: ["--features", "sqlite,mysql,postgresql,enable_mimalloc", "--", "-D", "warnings"]
types: [rust] types_or: [rust, file]
files: (Cargo.toml|Cargo.lock|.*\.rs$)
pass_filenames: false pass_filenames: false

2787
Cargo.lock

File diff suppressed because it is too large

162
Cargo.toml

@ -3,7 +3,7 @@ name = "vaultwarden"
version = "1.0.0" version = "1.0.0"
authors = ["Daniel García <dani-garcia@users.noreply.github.com>"] authors = ["Daniel García <dani-garcia@users.noreply.github.com>"]
edition = "2021" edition = "2021"
rust-version = "1.60" rust-version = "1.60.0"
resolver = "2" resolver = "2"
repository = "https://github.com/dani-garcia/vaultwarden" repository = "https://github.com/dani-garcia/vaultwarden"
@ -13,6 +13,7 @@ publish = false
build = "build.rs" build = "build.rs"
[features] [features]
# default = ["sqlite"]
# Empty to keep compatibility, prefer to set USE_SYSLOG=true # Empty to keep compatibility, prefer to set USE_SYSLOG=true
enable_syslog = [] enable_syslog = []
mysql = ["diesel/mysql", "diesel_migrations/mysql"] mysql = ["diesel/mysql", "diesel_migrations/mysql"]
@ -20,135 +21,138 @@ postgresql = ["diesel/postgres", "diesel_migrations/postgres"]
sqlite = ["diesel/sqlite", "diesel_migrations/sqlite", "libsqlite3-sys"] sqlite = ["diesel/sqlite", "diesel_migrations/sqlite", "libsqlite3-sys"]
# Enable to use a vendored and statically linked openssl # Enable to use a vendored and statically linked openssl
vendored_openssl = ["openssl/vendored"] vendored_openssl = ["openssl/vendored"]
# Enable MiMalloc memory allocator to replace the default malloc
# This can improve performance for Alpine builds
enable_mimalloc = ["mimalloc"]
# Enable unstable features, requires nightly # Enable unstable features, requires nightly
# Currently only used to enable rusts official ip support # Currently only used to enable rusts official ip support
unstable = [] unstable = []
[target."cfg(not(windows))".dependencies] [target."cfg(not(windows))".dependencies]
syslog = "4.0.1" # Logging
syslog = "6.0.1" # Needs to be v4 until fern is updated
[dependencies] [dependencies]
# Web framework for nightly with a focus on ease-of-use, expressibility, and speed. # Logging
rocket = { version = "=0.5.0-dev", features = ["tls"], default-features = false } log = "0.4.17"
rocket_contrib = "=0.5.0-dev" fern = { version = "0.6.1", features = ["syslog-6"] }
tracing = { version = "0.1.36", features = ["log"] } # Needed to have lettre and webauthn-rs trace logging to work
# HTTP client backtrace = "0.3.66" # Logging panics to logfile instead stderr only
reqwest = { version = "0.11.8", features = ["blocking", "json", "gzip", "brotli", "socks", "cookies", "trust-dns"] }
# Used for custom short lived cookie jar # A `dotenv` implementation for Rust
cookie = "0.15.1" dotenvy = { version = "0.15.1", default-features = false }
cookie_store = "0.15.1"
bytes = "1.1.0" # Lazy initialization
url = "2.2.2" once_cell = "1.13.0"
# multipart/form-data support # Numerical libraries
multipart = { version = "0.18.0", features = ["server"], default-features = false } num-traits = "0.2.15"
num-derive = "0.3.3"
# WebSockets library # Web framework
ws = { version = "0.11.1", package = "parity-ws" } rocket = { version = "0.5.0-rc.2", features = ["tls", "json"], default-features = false }
# MessagePack library # WebSockets libraries
rmpv = "1.0.0" tokio-tungstenite = "0.17.2"
rmpv = "1.0.0" # MessagePack library
dashmap = "5.3.4" # Concurrent hashmap implementation
# Concurrent hashmap implementation # Async futures
chashmap = "2.2.2" futures = "0.3.21"
tokio = { version = "1.20.1", features = ["rt-multi-thread", "fs", "io-util", "parking_lot", "time"] }
# A generic serialization/deserialization framework # A generic serialization/deserialization framework
serde = { version = "1.0.132", features = ["derive"] } serde = { version = "1.0.142", features = ["derive"] }
serde_json = "1.0.73" serde_json = "1.0.83"
# Logging
log = "0.4.14"
fern = { version = "0.6.0", features = ["syslog-4"] }
# A safe, extensible ORM and Query builder # A safe, extensible ORM and Query builder
diesel = { version = "1.4.8", features = [ "chrono", "r2d2"] } diesel = { version = "1.4.8", features = ["chrono", "r2d2"] }
diesel_migrations = "1.4.0" diesel_migrations = "1.4.0"
# Bundled SQLite # Bundled SQLite
libsqlite3-sys = { version = "0.22.2", features = ["bundled"], optional = true } libsqlite3-sys = { version = "0.22.2", features = ["bundled"], optional = true }
# Crypto-related libraries # Crypto-related libraries
rand = "0.8.4" rand = { version = "0.8.5", features = ["small_rng"] }
ring = "0.16.20" ring = "0.16.20"
# UUID generation # UUID generation
uuid = { version = "0.8.2", features = ["v4"] } uuid = { version = "1.1.2", features = ["v4"] }
# Date and time libraries # Date and time libraries
chrono = { version = "0.4.19", features = ["serde"] } chrono = { version = "0.4.20", features = ["clock", "serde"], default-features = false }
chrono-tz = "0.6.1" chrono-tz = "0.6.3"
time = "0.2.27" time = "0.3.12"
# Job scheduler # Job scheduler
job_scheduler = "1.2.1" job_scheduler_ng = "2.0.1"
# TOTP library # Data encoding library Hex/Base32/Base64
totp-lite = "1.0.3"
# Data encoding library
data-encoding = "2.3.2" data-encoding = "2.3.2"
# JWT library # JWT library
jsonwebtoken = "7.2.0" jsonwebtoken = "8.1.1"
# U2F library # TOTP library
u2f = "0.2.0" totp-lite = "2.0.0"
webauthn-rs = "0.3.1"
# Yubico Library # Yubico Library
yubico = { version = "0.10.0", features = ["online-tokio"], default-features = false } yubico = { version = "0.11.0", features = ["online-tokio"], default-features = false }
# A `dotenv` implementation for Rust # WebAuthn libraries
dotenv = { version = "0.15.0", default-features = false } webauthn-rs = "0.3.2"
# Lazy initialization
once_cell = "1.9.0"
# Numerical libraries # Handling of URL's for WebAuthn
num-traits = "0.2.14" url = "2.2.2"
num-derive = "0.3.3"
# Email libraries # Email librariese-Base, Update crates and small change.
tracing = { version = "0.1.29", features = ["log"] } # Needed to have lettre trace logging used when SMTP_DEBUG is enabled. lettre = { version = "0.10.1", features = ["smtp-transport", "builder", "serde", "tokio1-native-tls", "hostname", "tracing", "tokio1"], default-features = false }
lettre = { version = "0.10.0-rc.4", features = ["smtp-transport", "builder", "serde", "native-tls", "hostname", "tracing"], default-features = false } percent-encoding = "2.1.0" # URL encoding library used for URL's in the emails
# Template library # Template library
handlebars = { version = "4.1.6", features = ["dir_source"] } handlebars = { version = "4.3.3", features = ["dir_source"] }
# HTTP client
reqwest = { version = "0.11.11", features = ["stream", "json", "gzip", "brotli", "socks", "cookies", "trust-dns"] }
# For favicon extraction from main website # For favicon extraction from main website
html5ever = "0.25.1" html5gum = "0.5.2"
markup5ever_rcdom = "0.1.0" regex = { version = "1.6.0", features = ["std", "perf", "unicode-perl"], default-features = false }
regex = { version = "1.5.4", features = ["std", "perf", "unicode-perl"], default-features = false }
data-url = "0.1.1" data-url = "0.1.1"
bytes = "1.2.1"
cached = "0.38.0"
# Used by U2F, JWT and Postgres # Used for custom short lived cookie jar during favicon extraction
openssl = "0.10.38" cookie = "0.16.0"
cookie_store = "0.16.1"
# URL encoding library # Used by U2F, JWT and Postgres
percent-encoding = "2.1.0" openssl = "0.10.41"
# Punycode conversion
idna = "0.2.3"
# CLI argument parsing # CLI argument parsing
pico-args = "0.4.2" pico-args = "0.5.0"
# Logging panics to logfile instead stderr only
backtrace = "0.3.63"
# Macro ident concatenation # Macro ident concatenation
paste = "1.0.6" paste = "1.0.8"
governor = "0.3.2" governor = "0.4.2"
# Capture CTRL+C
ctrlc = { version = "3.2.2", features = ["termination"] }
# Allow overriding the default memory allocator
# Mainly used for the musl builds, since the default musl malloc is very slow
mimalloc = { version = "0.1.29", features = ["secure"], default-features = false, optional = true }
[patch.crates-io] [patch.crates-io]
# Use newest ring # Using a patched version of multer-rs (Used by Rocket) to fix attachment/send file uploads
rocket = { git = 'https://github.com/SergioBenitez/Rocket', rev = '263e39b5b429de1913ce7e3036575a7b4d88b6d7' } # Issue: https://github.com/dani-garcia/vaultwarden/issues/2644
rocket_contrib = { git = 'https://github.com/SergioBenitez/Rocket', rev = '263e39b5b429de1913ce7e3036575a7b4d88b6d7' } # Patch: https://github.com/BlackDex/multer-rs/commit/73e83fa5eb183646cc56606e5d902acb30a45b3d
multer = { git = "https://github.com/BlackDex/multer-rs", rev = "73e83fa5eb183646cc56606e5d902acb30a45b3d" }
# The maintainer of the `job_scheduler` crate doesn't seem to have responded
# to any issues or PRs for almost a year (as of April 2021). This hopefully # Strip debuginfo from the release builds
# temporary fork updates Cargo.toml to use more up-to-date dependencies. # Also enable thin LTO for some optimizations
# In particular, `cron` has since implemented parsing of some common syntax [profile.release]
# that wasn't previously supported (https://github.com/zslayton/cron/pull/64). strip = "debuginfo"
job_scheduler = { git = 'https://github.com/jjlin/job_scheduler', rev = 'ee023418dbba2bfe1e30a5fd7d937f9e33739806' } lto = "thin"

2
Rocket.toml

@ -1,2 +0,0 @@
[global.limits]
json = 10485760 # 10 MiB

1
docker/Dockerfile.buildx

@ -1,3 +1,4 @@
# syntax=docker/dockerfile:1
# The cross-built images have the build arch (`amd64`) embedded in the image # The cross-built images have the build arch (`amd64`) embedded in the image
# manifest, rather than the target arch. For example: # manifest, rather than the target arch. For example:
# #

79
docker/Dockerfile.j2

@ -3,39 +3,39 @@
# This file was generated using a Jinja2 template. # This file was generated using a Jinja2 template.
# Please make your changes in `Dockerfile.j2` and then `make` the individual Dockerfiles. # Please make your changes in `Dockerfile.j2` and then `make` the individual Dockerfiles.
{% set build_stage_base_image = "rust:1.58-buster" %} {% set build_stage_base_image = "rust:1.61-bullseye" %}
{% if "alpine" in target_file %} {% if "alpine" in target_file %}
{% if "amd64" in target_file %} {% if "amd64" in target_file %}
{% set build_stage_base_image = "blackdex/rust-musl:x86_64-musl-nightly-2022-01-23" %} {% set build_stage_base_image = "blackdex/rust-musl:x86_64-musl-stable-1.61.0" %}
{% set runtime_stage_base_image = "alpine:3.15" %} {% set runtime_stage_base_image = "alpine:3.16" %}
{% set package_arch_target = "x86_64-unknown-linux-musl" %} {% set package_arch_target = "x86_64-unknown-linux-musl" %}
{% elif "armv7" in target_file %} {% elif "armv7" in target_file %}
{% set build_stage_base_image = "blackdex/rust-musl:armv7-musleabihf-nightly-2022-01-23" %} {% set build_stage_base_image = "blackdex/rust-musl:armv7-musleabihf-stable-1.61.0" %}
{% set runtime_stage_base_image = "balenalib/armv7hf-alpine:3.15" %} {% set runtime_stage_base_image = "balenalib/armv7hf-alpine:3.16" %}
{% set package_arch_target = "armv7-unknown-linux-musleabihf" %} {% set package_arch_target = "armv7-unknown-linux-musleabihf" %}
{% elif "armv6" in target_file %} {% elif "armv6" in target_file %}
{% set build_stage_base_image = "blackdex/rust-musl:arm-musleabi-nightly-2022-01-23" %} {% set build_stage_base_image = "blackdex/rust-musl:arm-musleabi-stable-1.61.0" %}
{% set runtime_stage_base_image = "balenalib/rpi-alpine:3.15" %} {% set runtime_stage_base_image = "balenalib/rpi-alpine:3.16" %}
{% set package_arch_target = "arm-unknown-linux-musleabi" %} {% set package_arch_target = "arm-unknown-linux-musleabi" %}
{% elif "arm64" in target_file %} {% elif "arm64" in target_file %}
{% set build_stage_base_image = "blackdex/rust-musl:aarch64-musl-nightly-2022-01-23" %} {% set build_stage_base_image = "blackdex/rust-musl:aarch64-musl-stable-1.61.0" %}
{% set runtime_stage_base_image = "balenalib/aarch64-alpine:3.15" %} {% set runtime_stage_base_image = "balenalib/aarch64-alpine:3.16" %}
{% set package_arch_target = "aarch64-unknown-linux-musl" %} {% set package_arch_target = "aarch64-unknown-linux-musl" %}
{% endif %} {% endif %}
{% elif "amd64" in target_file %} {% elif "amd64" in target_file %}
{% set runtime_stage_base_image = "debian:buster-slim" %} {% set runtime_stage_base_image = "debian:bullseye-slim" %}
{% elif "arm64" in target_file %} {% elif "arm64" in target_file %}
{% set runtime_stage_base_image = "balenalib/aarch64-debian:buster" %} {% set runtime_stage_base_image = "balenalib/aarch64-debian:bullseye" %}
{% set package_arch_name = "arm64" %} {% set package_arch_name = "arm64" %}
{% set package_arch_target = "aarch64-unknown-linux-gnu" %} {% set package_arch_target = "aarch64-unknown-linux-gnu" %}
{% set package_cross_compiler = "aarch64-linux-gnu" %} {% set package_cross_compiler = "aarch64-linux-gnu" %}
{% elif "armv6" in target_file %} {% elif "armv6" in target_file %}
{% set runtime_stage_base_image = "balenalib/rpi-debian:buster" %} {% set runtime_stage_base_image = "balenalib/rpi-debian:bullseye" %}
{% set package_arch_name = "armel" %} {% set package_arch_name = "armel" %}
{% set package_arch_target = "arm-unknown-linux-gnueabi" %} {% set package_arch_target = "arm-unknown-linux-gnueabi" %}
{% set package_cross_compiler = "arm-linux-gnueabi" %} {% set package_cross_compiler = "arm-linux-gnueabi" %}
{% elif "armv7" in target_file %} {% elif "armv7" in target_file %}
{% set runtime_stage_base_image = "balenalib/armv7hf-debian:buster" %} {% set runtime_stage_base_image = "balenalib/armv7hf-debian:bullseye" %}
{% set package_arch_name = "armhf" %} {% set package_arch_name = "armhf" %}
{% set package_arch_target = "armv7-unknown-linux-gnueabihf" %} {% set package_arch_target = "armv7-unknown-linux-gnueabihf" %}
{% set package_cross_compiler = "arm-linux-gnueabihf" %} {% set package_cross_compiler = "arm-linux-gnueabihf" %}
@ -59,8 +59,8 @@
# https://docs.docker.com/develop/develop-images/multistage-build/ # https://docs.docker.com/develop/develop-images/multistage-build/
# https://whitfin.io/speeding-up-rust-docker-builds/ # https://whitfin.io/speeding-up-rust-docker-builds/
####################### VAULT BUILD IMAGE ####################### ####################### VAULT BUILD IMAGE #######################
{% set vault_version = "2.25.1" %} {% set vault_version = "v2022.6.2" %}
{% set vault_image_digest = "sha256:4f9b7a6b0eaceb511cca8c6a5ed5aa92f527960b1b33d86fbbfd4e5795943965" %} {% set vault_image_digest = "sha256:1dfda41cbddeac5bc59540261fff8defcac37170b5ba02d29c12fa1215498f70" %}
# The web-vault digest specifies a particular web-vault build on Docker Hub. # The web-vault digest specifies a particular web-vault build on Docker Hub.
# Using the digest instead of the tag name provides better security, # Using the digest instead of the tag name provides better security,
# as the digest of an image is immutable, whereas a tag name can later # as the digest of an image is immutable, whereas a tag name can later
@ -70,13 +70,13 @@
# - From https://hub.docker.com/r/vaultwarden/web-vault/tags, # - From https://hub.docker.com/r/vaultwarden/web-vault/tags,
# click the tag name to view the digest of the image it currently points to. # click the tag name to view the digest of the image it currently points to.
# - From the command line: # - From the command line:
# $ docker pull vaultwarden/web-vault:v{{ vault_version }} # $ docker pull vaultwarden/web-vault:{{ vault_version }}
# $ docker image inspect --format "{{ '{{' }}.RepoDigests}}" vaultwarden/web-vault:v{{ vault_version }} # $ docker image inspect --format "{{ '{{' }}.RepoDigests}}" vaultwarden/web-vault:{{ vault_version }}
# [vaultwarden/web-vault@{{ vault_image_digest }}] # [vaultwarden/web-vault@{{ vault_image_digest }}]
# #
# - Conversely, to get the tag name from the digest: # - Conversely, to get the tag name from the digest:
# $ docker image inspect --format "{{ '{{' }}.RepoTags}}" vaultwarden/web-vault@{{ vault_image_digest }} # $ docker image inspect --format "{{ '{{' }}.RepoTags}}" vaultwarden/web-vault@{{ vault_image_digest }}
# [vaultwarden/web-vault:v{{ vault_version }}] # [vaultwarden/web-vault:{{ vault_version }}]
# #
FROM vaultwarden/web-vault@{{ vault_image_digest }} as vault FROM vaultwarden/web-vault@{{ vault_image_digest }} as vault
@ -93,22 +93,15 @@ ENV DEBIAN_FRONTEND=noninteractive \
CARGO_HOME="/root/.cargo" \ CARGO_HOME="/root/.cargo" \
USER="root" USER="root"
{# {% if "alpine" not in target_file and "buildx" in target_file %}
# Debian based Buildx builds can use some special apt caching to speedup building.
# By default Debian based images have some rules to keep docker builds clean, we need to remove this.
# See: https://hub.docker.com/r/docker/dockerfile
RUN rm -f /etc/apt/apt.conf.d/docker-clean; echo 'Binary::apt::APT::Keep-Downloaded-Packages "true";' > /etc/apt/apt.conf.d/keep-cache
{% endif %} #}
# Create CARGO_HOME folder and don't download rust docs # Create CARGO_HOME folder and don't download rust docs
RUN {{ mount_rust_cache -}} mkdir -pv "${CARGO_HOME}" \ RUN {{ mount_rust_cache -}} mkdir -pv "${CARGO_HOME}" \
&& rustup set profile minimal && rustup set profile minimal
{% if "alpine" in target_file %} {% if "alpine" in target_file %}
ENV RUSTFLAGS='-C link-arg=-s' {% if "armv6" in target_file %}
{% if "armv7" in target_file %} # To be able to build the armv6 image with mimalloc we need to specifically specify the libatomic.a file location
{#- https://gcc.gnu.org/onlinedocs/gcc/ARM-Options.html -#} ENV RUSTFLAGS='-Clink-arg=/usr/local/musl/{{ package_arch_target }}/lib/libatomic.a'
ENV CFLAGS_armv7_unknown_linux_musleabihf="-mfpu=vfpv3-d16"
{% endif %} {% endif %}
{% elif "arm" in target_file %} {% elif "arm" in target_file %}
# #
@ -163,7 +156,12 @@ RUN {{ mount_rust_cache -}} rustup target add {{ package_arch_target }}
{% endif %} {% endif %}
# Configure the DB ARG as late as possible to not invalidate the cached layers above # Configure the DB ARG as late as possible to not invalidate the cached layers above
{% if "alpine" in target_file %}
# Enable MiMalloc to improve performance on Alpine builds
ARG DB=sqlite,mysql,postgresql,enable_mimalloc
{% else %}
ARG DB=sqlite,mysql,postgresql ARG DB=sqlite,mysql,postgresql
{% endif %}
# Builds your dependencies and removes the # Builds your dependencies and removes the
# dummy project, except the target folder # dummy project, except the target folder
@ -182,21 +180,15 @@ RUN touch src/main.rs
# your actual source files being built # your actual source files being built
# hadolint ignore=DL3059 # hadolint ignore=DL3059
RUN {{ mount_rust_cache -}} cargo build --features ${DB} --release{{ package_arch_target_param }} RUN {{ mount_rust_cache -}} cargo build --features ${DB} --release{{ package_arch_target_param }}
{% if "alpine" in target_file %}
{% if "armv7" in target_file %}
# hadolint ignore=DL3059
RUN musl-strip target/{{ package_arch_target }}/release/vaultwarden
{% endif %}
{% endif %}
######################## RUNTIME IMAGE ######################## ######################## RUNTIME IMAGE ########################
# Create a new stage with a minimal image # Create a new stage with a minimal image
# because we already have a binary built # because we already have a binary built
FROM {{ runtime_stage_base_image }} FROM {{ runtime_stage_base_image }}
ENV ROCKET_ENV="staging" \ ENV ROCKET_PROFILE="release" \
ROCKET_PORT=80 \ ROCKET_ADDRESS=0.0.0.0 \
ROCKET_WORKERS=10 ROCKET_PORT=80
{%- if "alpine" in runtime_stage_base_image %} \ {%- if "alpine" in runtime_stage_base_image %} \
SSL_CERT_DIR=/etc/ssl/certs SSL_CERT_DIR=/etc/ssl/certs
{% endif %} {% endif %}
@ -214,7 +206,6 @@ RUN mkdir /data \
openssl \ openssl \
tzdata \ tzdata \
curl \ curl \
dumb-init \
ca-certificates ca-certificates
{% else %} {% else %}
&& apt-get update && apt-get install -y \ && apt-get update && apt-get install -y \
@ -222,13 +213,20 @@ RUN mkdir /data \
openssl \ openssl \
ca-certificates \ ca-certificates \
curl \ curl \
dumb-init \
libmariadb-dev-compat \ libmariadb-dev-compat \
libpq5 \ libpq5 \
&& apt-get clean \ && apt-get clean \
&& rm -rf /var/lib/apt/lists/* && rm -rf /var/lib/apt/lists/*
{% endif %} {% endif %}
{% if "armv6" in target_file and "alpine" not in target_file %}
# In the Balena Bullseye images for armv6/rpi-debian there is a missing symlink.
# This symlink was there in the buster images, and for some reason this is needed.
# hadolint ignore=DL3059
RUN ln -v -s /lib/ld-linux-armhf.so.3 /lib/ld-linux.so.3
{% endif -%}
{% if "amd64" not in target_file %} {% if "amd64" not in target_file %}
# hadolint ignore=DL3059 # hadolint ignore=DL3059
RUN [ "cross-build-end" ] RUN [ "cross-build-end" ]
@ -241,7 +239,6 @@ EXPOSE 3012
# Copies the files from the context (Rocket.toml file and web-vault) # Copies the files from the context (Rocket.toml file and web-vault)
# and the binary from the "build" stage to the current stage # and the binary from the "build" stage to the current stage
WORKDIR / WORKDIR /
COPY Rocket.toml .
COPY --from=vault /web-vault ./web-vault COPY --from=vault /web-vault ./web-vault
{% if package_arch_target is defined %} {% if package_arch_target is defined %}
COPY --from=build /app/target/{{ package_arch_target }}/release/vaultwarden . COPY --from=build /app/target/{{ package_arch_target }}/release/vaultwarden .
@ -254,6 +251,4 @@ COPY docker/start.sh /start.sh
HEALTHCHECK --interval=60s --timeout=10s CMD ["/healthcheck.sh"] HEALTHCHECK --interval=60s --timeout=10s CMD ["/healthcheck.sh"]
# Configures the startup!
ENTRYPOINT ["/usr/bin/dumb-init", "--"]
CMD ["/start.sh"] CMD ["/start.sh"]

26
docker/amd64/Dockerfile

@ -16,18 +16,18 @@
# - From https://hub.docker.com/r/vaultwarden/web-vault/tags, # - From https://hub.docker.com/r/vaultwarden/web-vault/tags,
# click the tag name to view the digest of the image it currently points to. # click the tag name to view the digest of the image it currently points to.
# - From the command line: # - From the command line:
# $ docker pull vaultwarden/web-vault:v2.25.1 # $ docker pull vaultwarden/web-vault:v2022.6.2
# $ docker image inspect --format "{{.RepoDigests}}" vaultwarden/web-vault:v2.25.1 # $ docker image inspect --format "{{.RepoDigests}}" vaultwarden/web-vault:v2022.6.2
# [vaultwarden/web-vault@sha256:4f9b7a6b0eaceb511cca8c6a5ed5aa92f527960b1b33d86fbbfd4e5795943965] # [vaultwarden/web-vault@sha256:1dfda41cbddeac5bc59540261fff8defcac37170b5ba02d29c12fa1215498f70]
# #
# - Conversely, to get the tag name from the digest: # - Conversely, to get the tag name from the digest:
# $ docker image inspect --format "{{.RepoTags}}" vaultwarden/web-vault@sha256:4f9b7a6b0eaceb511cca8c6a5ed5aa92f527960b1b33d86fbbfd4e5795943965 # $ docker image inspect --format "{{.RepoTags}}" vaultwarden/web-vault@sha256:1dfda41cbddeac5bc59540261fff8defcac37170b5ba02d29c12fa1215498f70
# [vaultwarden/web-vault:v2.25.1] # [vaultwarden/web-vault:v2022.6.2]
# #
FROM vaultwarden/web-vault@sha256:4f9b7a6b0eaceb511cca8c6a5ed5aa92f527960b1b33d86fbbfd4e5795943965 as vault FROM vaultwarden/web-vault@sha256:1dfda41cbddeac5bc59540261fff8defcac37170b5ba02d29c12fa1215498f70 as vault
########################## BUILD IMAGE ########################## ########################## BUILD IMAGE ##########################
FROM rust:1.58-buster as build FROM rust:1.61-bullseye as build
@ -87,11 +87,11 @@ RUN cargo build --features ${DB} --release
######################## RUNTIME IMAGE ######################## ######################## RUNTIME IMAGE ########################
# Create a new stage with a minimal image # Create a new stage with a minimal image
# because we already have a binary built # because we already have a binary built
FROM debian:buster-slim FROM debian:bullseye-slim
ENV ROCKET_ENV="staging" \ ENV ROCKET_PROFILE="release" \
ROCKET_PORT=80 \ ROCKET_ADDRESS=0.0.0.0 \
ROCKET_WORKERS=10 ROCKET_PORT=80
# Create data folder and Install needed libraries # Create data folder and Install needed libraries
@ -101,7 +101,6 @@ RUN mkdir /data \
openssl \ openssl \
ca-certificates \ ca-certificates \
curl \ curl \
dumb-init \
libmariadb-dev-compat \ libmariadb-dev-compat \
libpq5 \ libpq5 \
&& apt-get clean \ && apt-get clean \
@ -115,7 +114,6 @@ EXPOSE 3012
# Copies the files from the context (Rocket.toml file and web-vault) # Copies the files from the context (Rocket.toml file and web-vault)
# and the binary from the "build" stage to the current stage # and the binary from the "build" stage to the current stage
WORKDIR / WORKDIR /
COPY Rocket.toml .
COPY --from=vault /web-vault ./web-vault COPY --from=vault /web-vault ./web-vault
COPY --from=build /app/target/release/vaultwarden . COPY --from=build /app/target/release/vaultwarden .
@ -124,6 +122,4 @@ COPY docker/start.sh /start.sh
HEALTHCHECK --interval=60s --timeout=10s CMD ["/healthcheck.sh"] HEALTHCHECK --interval=60s --timeout=10s CMD ["/healthcheck.sh"]
# Configures the startup!
ENTRYPOINT ["/usr/bin/dumb-init", "--"]
CMD ["/start.sh"] CMD ["/start.sh"]

28
docker/amd64/Dockerfile.alpine

@ -16,18 +16,18 @@
# - From https://hub.docker.com/r/vaultwarden/web-vault/tags, # - From https://hub.docker.com/r/vaultwarden/web-vault/tags,
# click the tag name to view the digest of the image it currently points to. # click the tag name to view the digest of the image it currently points to.
# - From the command line: # - From the command line:
# $ docker pull vaultwarden/web-vault:v2.25.1 # $ docker pull vaultwarden/web-vault:v2022.6.2
# $ docker image inspect --format "{{.RepoDigests}}" vaultwarden/web-vault:v2.25.1 # $ docker image inspect --format "{{.RepoDigests}}" vaultwarden/web-vault:v2022.6.2
# [vaultwarden/web-vault@sha256:4f9b7a6b0eaceb511cca8c6a5ed5aa92f527960b1b33d86fbbfd4e5795943965] # [vaultwarden/web-vault@sha256:1dfda41cbddeac5bc59540261fff8defcac37170b5ba02d29c12fa1215498f70]
# #
# - Conversely, to get the tag name from the digest: # - Conversely, to get the tag name from the digest:
# $ docker image inspect --format "{{.RepoTags}}" vaultwarden/web-vault@sha256:4f9b7a6b0eaceb511cca8c6a5ed5aa92f527960b1b33d86fbbfd4e5795943965 # $ docker image inspect --format "{{.RepoTags}}" vaultwarden/web-vault@sha256:1dfda41cbddeac5bc59540261fff8defcac37170b5ba02d29c12fa1215498f70
# [vaultwarden/web-vault:v2.25.1] # [vaultwarden/web-vault:v2022.6.2]
# #
FROM vaultwarden/web-vault@sha256:4f9b7a6b0eaceb511cca8c6a5ed5aa92f527960b1b33d86fbbfd4e5795943965 as vault FROM vaultwarden/web-vault@sha256:1dfda41cbddeac5bc59540261fff8defcac37170b5ba02d29c12fa1215498f70 as vault
########################## BUILD IMAGE ########################## ########################## BUILD IMAGE ##########################
FROM blackdex/rust-musl:x86_64-musl-nightly-2022-01-23 as build FROM blackdex/rust-musl:x86_64-musl-stable-1.61.0 as build
@ -44,7 +44,6 @@ ENV DEBIAN_FRONTEND=noninteractive \
RUN mkdir -pv "${CARGO_HOME}" \ RUN mkdir -pv "${CARGO_HOME}" \
&& rustup set profile minimal && rustup set profile minimal
ENV RUSTFLAGS='-C link-arg=-s'
# Creates a dummy project used to grab dependencies # Creates a dummy project used to grab dependencies
RUN USER=root cargo new --bin /app RUN USER=root cargo new --bin /app
@ -58,7 +57,8 @@ COPY ./build.rs ./build.rs
RUN rustup target add x86_64-unknown-linux-musl RUN rustup target add x86_64-unknown-linux-musl
# Configure the DB ARG as late as possible to not invalidate the cached layers above # Configure the DB ARG as late as possible to not invalidate the cached layers above
ARG DB=sqlite,mysql,postgresql # Enable MiMalloc to improve performance on Alpine builds
ARG DB=sqlite,mysql,postgresql,enable_mimalloc
# Builds your dependencies and removes the # Builds your dependencies and removes the
# dummy project, except the target folder # dummy project, except the target folder
@ -81,11 +81,11 @@ RUN cargo build --features ${DB} --release --target=x86_64-unknown-linux-musl
######################## RUNTIME IMAGE ######################## ######################## RUNTIME IMAGE ########################
# Create a new stage with a minimal image # Create a new stage with a minimal image
# because we already have a binary built # because we already have a binary built
FROM alpine:3.15 FROM alpine:3.16
ENV ROCKET_ENV="staging" \ ENV ROCKET_PROFILE="release" \
ROCKET_ADDRESS=0.0.0.0 \
ROCKET_PORT=80 \ ROCKET_PORT=80 \
ROCKET_WORKERS=10 \
SSL_CERT_DIR=/etc/ssl/certs SSL_CERT_DIR=/etc/ssl/certs
@ -96,7 +96,6 @@ RUN mkdir /data \
openssl \ openssl \
tzdata \ tzdata \
curl \ curl \
dumb-init \
ca-certificates ca-certificates
@ -107,7 +106,6 @@ EXPOSE 3012
# Copies the files from the context (Rocket.toml file and web-vault) # Copies the files from the context (Rocket.toml file and web-vault)
# and the binary from the "build" stage to the current stage # and the binary from the "build" stage to the current stage
WORKDIR / WORKDIR /
COPY Rocket.toml .
COPY --from=vault /web-vault ./web-vault COPY --from=vault /web-vault ./web-vault
COPY --from=build /app/target/x86_64-unknown-linux-musl/release/vaultwarden . COPY --from=build /app/target/x86_64-unknown-linux-musl/release/vaultwarden .
@ -116,6 +114,4 @@ COPY docker/start.sh /start.sh
HEALTHCHECK --interval=60s --timeout=10s CMD ["/healthcheck.sh"] HEALTHCHECK --interval=60s --timeout=10s CMD ["/healthcheck.sh"]
# Configures the startup!
ENTRYPOINT ["/usr/bin/dumb-init", "--"]
CMD ["/start.sh"] CMD ["/start.sh"]

26
docker/amd64/Dockerfile.buildx

@ -16,18 +16,18 @@
# - From https://hub.docker.com/r/vaultwarden/web-vault/tags, # - From https://hub.docker.com/r/vaultwarden/web-vault/tags,
# click the tag name to view the digest of the image it currently points to. # click the tag name to view the digest of the image it currently points to.
# - From the command line: # - From the command line:
# $ docker pull vaultwarden/web-vault:v2.25.1 # $ docker pull vaultwarden/web-vault:v2022.6.2
# $ docker image inspect --format "{{.RepoDigests}}" vaultwarden/web-vault:v2.25.1 # $ docker image inspect --format "{{.RepoDigests}}" vaultwarden/web-vault:v2022.6.2
# [vaultwarden/web-vault@sha256:4f9b7a6b0eaceb511cca8c6a5ed5aa92f527960b1b33d86fbbfd4e5795943965] # [vaultwarden/web-vault@sha256:1dfda41cbddeac5bc59540261fff8defcac37170b5ba02d29c12fa1215498f70]
# #
# - Conversely, to get the tag name from the digest: # - Conversely, to get the tag name from the digest:
# $ docker image inspect --format "{{.RepoTags}}" vaultwarden/web-vault@sha256:4f9b7a6b0eaceb511cca8c6a5ed5aa92f527960b1b33d86fbbfd4e5795943965 # $ docker image inspect --format "{{.RepoTags}}" vaultwarden/web-vault@sha256:1dfda41cbddeac5bc59540261fff8defcac37170b5ba02d29c12fa1215498f70
# [vaultwarden/web-vault:v2.25.1] # [vaultwarden/web-vault:v2022.6.2]
# #
FROM vaultwarden/web-vault@sha256:4f9b7a6b0eaceb511cca8c6a5ed5aa92f527960b1b33d86fbbfd4e5795943965 as vault FROM vaultwarden/web-vault@sha256:1dfda41cbddeac5bc59540261fff8defcac37170b5ba02d29c12fa1215498f70 as vault
########################## BUILD IMAGE ########################## ########################## BUILD IMAGE ##########################
FROM rust:1.58-buster as build FROM rust:1.61-bullseye as build
@ -87,11 +87,11 @@ RUN --mount=type=cache,target=/root/.cargo/git --mount=type=cache,target=/root/.
######################## RUNTIME IMAGE ######################## ######################## RUNTIME IMAGE ########################
# Create a new stage with a minimal image # Create a new stage with a minimal image
# because we already have a binary built # because we already have a binary built
FROM debian:buster-slim FROM debian:bullseye-slim
ENV ROCKET_ENV="staging" \ ENV ROCKET_PROFILE="release" \
ROCKET_PORT=80 \ ROCKET_ADDRESS=0.0.0.0 \
ROCKET_WORKERS=10 ROCKET_PORT=80
# Create data folder and Install needed libraries # Create data folder and Install needed libraries
@ -101,7 +101,6 @@ RUN mkdir /data \
openssl \ openssl \
ca-certificates \ ca-certificates \
curl \ curl \
dumb-init \
libmariadb-dev-compat \ libmariadb-dev-compat \
libpq5 \ libpq5 \
&& apt-get clean \ && apt-get clean \
@ -115,7 +114,6 @@ EXPOSE 3012
# Copies the files from the context (Rocket.toml file and web-vault) # Copies the files from the context (Rocket.toml file and web-vault)
# and the binary from the "build" stage to the current stage # and the binary from the "build" stage to the current stage
WORKDIR / WORKDIR /
COPY Rocket.toml .
COPY --from=vault /web-vault ./web-vault COPY --from=vault /web-vault ./web-vault
COPY --from=build /app/target/release/vaultwarden . COPY --from=build /app/target/release/vaultwarden .
@ -124,6 +122,4 @@ COPY docker/start.sh /start.sh
HEALTHCHECK --interval=60s --timeout=10s CMD ["/healthcheck.sh"] HEALTHCHECK --interval=60s --timeout=10s CMD ["/healthcheck.sh"]
# Configures the startup!
ENTRYPOINT ["/usr/bin/dumb-init", "--"]
CMD ["/start.sh"] CMD ["/start.sh"]

28
docker/amd64/Dockerfile.buildx.alpine

@ -16,18 +16,18 @@
# - From https://hub.docker.com/r/vaultwarden/web-vault/tags, # - From https://hub.docker.com/r/vaultwarden/web-vault/tags,
# click the tag name to view the digest of the image it currently points to. # click the tag name to view the digest of the image it currently points to.
# - From the command line: # - From the command line:
# $ docker pull vaultwarden/web-vault:v2.25.1 # $ docker pull vaultwarden/web-vault:v2022.6.2
# $ docker image inspect --format "{{.RepoDigests}}" vaultwarden/web-vault:v2.25.1 # $ docker image inspect --format "{{.RepoDigests}}" vaultwarden/web-vault:v2022.6.2
# [vaultwarden/web-vault@sha256:4f9b7a6b0eaceb511cca8c6a5ed5aa92f527960b1b33d86fbbfd4e5795943965] # [vaultwarden/web-vault@sha256:1dfda41cbddeac5bc59540261fff8defcac37170b5ba02d29c12fa1215498f70]
# #
# - Conversely, to get the tag name from the digest: # - Conversely, to get the tag name from the digest:
# $ docker image inspect --format "{{.RepoTags}}" vaultwarden/web-vault@sha256:4f9b7a6b0eaceb511cca8c6a5ed5aa92f527960b1b33d86fbbfd4e5795943965 # $ docker image inspect --format "{{.RepoTags}}" vaultwarden/web-vault@sha256:1dfda41cbddeac5bc59540261fff8defcac37170b5ba02d29c12fa1215498f70
# [vaultwarden/web-vault:v2.25.1] # [vaultwarden/web-vault:v2022.6.2]
# #
FROM vaultwarden/web-vault@sha256:4f9b7a6b0eaceb511cca8c6a5ed5aa92f527960b1b33d86fbbfd4e5795943965 as vault FROM vaultwarden/web-vault@sha256:1dfda41cbddeac5bc59540261fff8defcac37170b5ba02d29c12fa1215498f70 as vault
########################## BUILD IMAGE ########################## ########################## BUILD IMAGE ##########################
FROM blackdex/rust-musl:x86_64-musl-nightly-2022-01-23 as build FROM blackdex/rust-musl:x86_64-musl-stable-1.61.0 as build
@ -44,7 +44,6 @@ ENV DEBIAN_FRONTEND=noninteractive \
RUN --mount=type=cache,target=/root/.cargo/git --mount=type=cache,target=/root/.cargo/registry mkdir -pv "${CARGO_HOME}" \ RUN --mount=type=cache,target=/root/.cargo/git --mount=type=cache,target=/root/.cargo/registry mkdir -pv "${CARGO_HOME}" \
&& rustup set profile minimal && rustup set profile minimal
ENV RUSTFLAGS='-C link-arg=-s'
# Creates a dummy project used to grab dependencies # Creates a dummy project used to grab dependencies
RUN USER=root cargo new --bin /app RUN USER=root cargo new --bin /app
@ -58,7 +57,8 @@ COPY ./build.rs ./build.rs
RUN --mount=type=cache,target=/root/.cargo/git --mount=type=cache,target=/root/.cargo/registry rustup target add x86_64-unknown-linux-musl RUN --mount=type=cache,target=/root/.cargo/git --mount=type=cache,target=/root/.cargo/registry rustup target add x86_64-unknown-linux-musl
# Configure the DB ARG as late as possible to not invalidate the cached layers above # Configure the DB ARG as late as possible to not invalidate the cached layers above
ARG DB=sqlite,mysql,postgresql # Enable MiMalloc to improve performance on Alpine builds
ARG DB=sqlite,mysql,postgresql,enable_mimalloc
# Builds your dependencies and removes the # Builds your dependencies and removes the
# dummy project, except the target folder # dummy project, except the target folder
@ -81,11 +81,11 @@ RUN --mount=type=cache,target=/root/.cargo/git --mount=type=cache,target=/root/.
######################## RUNTIME IMAGE ######################## ######################## RUNTIME IMAGE ########################
# Create a new stage with a minimal image # Create a new stage with a minimal image
# because we already have a binary built # because we already have a binary built
FROM alpine:3.15 FROM alpine:3.16
ENV ROCKET_ENV="staging" \ ENV ROCKET_PROFILE="release" \
ROCKET_ADDRESS=0.0.0.0 \
ROCKET_PORT=80 \ ROCKET_PORT=80 \
ROCKET_WORKERS=10 \
SSL_CERT_DIR=/etc/ssl/certs SSL_CERT_DIR=/etc/ssl/certs
@ -96,7 +96,6 @@ RUN mkdir /data \
openssl \ openssl \
tzdata \ tzdata \
curl \ curl \
dumb-init \
ca-certificates ca-certificates
@ -107,7 +106,6 @@ EXPOSE 3012
# Copies the files from the context (Rocket.toml file and web-vault) # Copies the files from the context (Rocket.toml file and web-vault)
# and the binary from the "build" stage to the current stage # and the binary from the "build" stage to the current stage
WORKDIR / WORKDIR /
COPY Rocket.toml .
COPY --from=vault /web-vault ./web-vault COPY --from=vault /web-vault ./web-vault
COPY --from=build /app/target/x86_64-unknown-linux-musl/release/vaultwarden . COPY --from=build /app/target/x86_64-unknown-linux-musl/release/vaultwarden .
@ -116,6 +114,4 @@ COPY docker/start.sh /start.sh
HEALTHCHECK --interval=60s --timeout=10s CMD ["/healthcheck.sh"] HEALTHCHECK --interval=60s --timeout=10s CMD ["/healthcheck.sh"]
# Configures the startup!
ENTRYPOINT ["/usr/bin/dumb-init", "--"]
CMD ["/start.sh"] CMD ["/start.sh"]

26
docker/arm64/Dockerfile

@ -16,18 +16,18 @@
# - From https://hub.docker.com/r/vaultwarden/web-vault/tags, # - From https://hub.docker.com/r/vaultwarden/web-vault/tags,
# click the tag name to view the digest of the image it currently points to. # click the tag name to view the digest of the image it currently points to.
# - From the command line: # - From the command line:
# $ docker pull vaultwarden/web-vault:v2.25.1 # $ docker pull vaultwarden/web-vault:v2022.6.2
# $ docker image inspect --format "{{.RepoDigests}}" vaultwarden/web-vault:v2.25.1 # $ docker image inspect --format "{{.RepoDigests}}" vaultwarden/web-vault:v2022.6.2
# [vaultwarden/web-vault@sha256:4f9b7a6b0eaceb511cca8c6a5ed5aa92f527960b1b33d86fbbfd4e5795943965] # [vaultwarden/web-vault@sha256:1dfda41cbddeac5bc59540261fff8defcac37170b5ba02d29c12fa1215498f70]
# #
# - Conversely, to get the tag name from the digest: # - Conversely, to get the tag name from the digest:
# $ docker image inspect --format "{{.RepoTags}}" vaultwarden/web-vault@sha256:4f9b7a6b0eaceb511cca8c6a5ed5aa92f527960b1b33d86fbbfd4e5795943965 # $ docker image inspect --format "{{.RepoTags}}" vaultwarden/web-vault@sha256:1dfda41cbddeac5bc59540261fff8defcac37170b5ba02d29c12fa1215498f70
# [vaultwarden/web-vault:v2.25.1] # [vaultwarden/web-vault:v2022.6.2]
# #
FROM vaultwarden/web-vault@sha256:4f9b7a6b0eaceb511cca8c6a5ed5aa92f527960b1b33d86fbbfd4e5795943965 as vault FROM vaultwarden/web-vault@sha256:1dfda41cbddeac5bc59540261fff8defcac37170b5ba02d29c12fa1215498f70 as vault
########################## BUILD IMAGE ########################## ########################## BUILD IMAGE ##########################
FROM rust:1.58-buster as build FROM rust:1.61-bullseye as build
@ -107,11 +107,11 @@ RUN cargo build --features ${DB} --release --target=aarch64-unknown-linux-gnu
######################## RUNTIME IMAGE ######################## ######################## RUNTIME IMAGE ########################
# Create a new stage with a minimal image # Create a new stage with a minimal image
# because we already have a binary built # because we already have a binary built
FROM balenalib/aarch64-debian:buster FROM balenalib/aarch64-debian:bullseye
ENV ROCKET_ENV="staging" \ ENV ROCKET_PROFILE="release" \
ROCKET_PORT=80 \ ROCKET_ADDRESS=0.0.0.0 \
ROCKET_WORKERS=10 ROCKET_PORT=80
# hadolint ignore=DL3059 # hadolint ignore=DL3059
RUN [ "cross-build-start" ] RUN [ "cross-build-start" ]
@ -123,7 +123,6 @@ RUN mkdir /data \
openssl \ openssl \
ca-certificates \ ca-certificates \
curl \ curl \
dumb-init \
libmariadb-dev-compat \ libmariadb-dev-compat \
libpq5 \ libpq5 \
&& apt-get clean \ && apt-get clean \
@ -139,7 +138,6 @@ EXPOSE 3012
# Copies the files from the context (Rocket.toml file and web-vault) # Copies the files from the context (Rocket.toml file and web-vault)
# and the binary from the "build" stage to the current stage # and the binary from the "build" stage to the current stage
WORKDIR / WORKDIR /
COPY Rocket.toml .
COPY --from=vault /web-vault ./web-vault COPY --from=vault /web-vault ./web-vault
COPY --from=build /app/target/aarch64-unknown-linux-gnu/release/vaultwarden . COPY --from=build /app/target/aarch64-unknown-linux-gnu/release/vaultwarden .
@ -148,6 +146,4 @@ COPY docker/start.sh /start.sh
HEALTHCHECK --interval=60s --timeout=10s CMD ["/healthcheck.sh"] HEALTHCHECK --interval=60s --timeout=10s CMD ["/healthcheck.sh"]
# Configures the startup!
ENTRYPOINT ["/usr/bin/dumb-init", "--"]
CMD ["/start.sh"] CMD ["/start.sh"]

28
docker/arm64/Dockerfile.alpine

@ -16,18 +16,18 @@
# - From https://hub.docker.com/r/vaultwarden/web-vault/tags, # - From https://hub.docker.com/r/vaultwarden/web-vault/tags,
# click the tag name to view the digest of the image it currently points to. # click the tag name to view the digest of the image it currently points to.
# - From the command line: # - From the command line:
# $ docker pull vaultwarden/web-vault:v2.25.1 # $ docker pull vaultwarden/web-vault:v2022.6.2
# $ docker image inspect --format "{{.RepoDigests}}" vaultwarden/web-vault:v2.25.1 # $ docker image inspect --format "{{.RepoDigests}}" vaultwarden/web-vault:v2022.6.2
# [vaultwarden/web-vault@sha256:4f9b7a6b0eaceb511cca8c6a5ed5aa92f527960b1b33d86fbbfd4e5795943965] # [vaultwarden/web-vault@sha256:1dfda41cbddeac5bc59540261fff8defcac37170b5ba02d29c12fa1215498f70]
# #
# - Conversely, to get the tag name from the digest: # - Conversely, to get the tag name from the digest:
# $ docker image inspect --format "{{.RepoTags}}" vaultwarden/web-vault@sha256:4f9b7a6b0eaceb511cca8c6a5ed5aa92f527960b1b33d86fbbfd4e5795943965 # $ docker image inspect --format "{{.RepoTags}}" vaultwarden/web-vault@sha256:1dfda41cbddeac5bc59540261fff8defcac37170b5ba02d29c12fa1215498f70
# [vaultwarden/web-vault:v2.25.1] # [vaultwarden/web-vault:v2022.6.2]
# #
FROM vaultwarden/web-vault@sha256:4f9b7a6b0eaceb511cca8c6a5ed5aa92f527960b1b33d86fbbfd4e5795943965 as vault FROM vaultwarden/web-vault@sha256:1dfda41cbddeac5bc59540261fff8defcac37170b5ba02d29c12fa1215498f70 as vault
########################## BUILD IMAGE ########################## ########################## BUILD IMAGE ##########################
FROM blackdex/rust-musl:aarch64-musl-nightly-2022-01-23 as build FROM blackdex/rust-musl:aarch64-musl-stable-1.61.0 as build
@ -44,7 +44,6 @@ ENV DEBIAN_FRONTEND=noninteractive \
RUN mkdir -pv "${CARGO_HOME}" \ RUN mkdir -pv "${CARGO_HOME}" \
&& rustup set profile minimal && rustup set profile minimal
ENV RUSTFLAGS='-C link-arg=-s'
# Creates a dummy project used to grab dependencies # Creates a dummy project used to grab dependencies
RUN USER=root cargo new --bin /app RUN USER=root cargo new --bin /app
@ -58,7 +57,8 @@ COPY ./build.rs ./build.rs
RUN rustup target add aarch64-unknown-linux-musl RUN rustup target add aarch64-unknown-linux-musl
# Configure the DB ARG as late as possible to not invalidate the cached layers above # Configure the DB ARG as late as possible to not invalidate the cached layers above
ARG DB=sqlite,mysql,postgresql # Enable MiMalloc to improve performance on Alpine builds
ARG DB=sqlite,mysql,postgresql,enable_mimalloc
# Builds your dependencies and removes the # Builds your dependencies and removes the
# dummy project, except the target folder # dummy project, except the target folder
@ -81,11 +81,11 @@ RUN cargo build --features ${DB} --release --target=aarch64-unknown-linux-musl
######################## RUNTIME IMAGE ######################## ######################## RUNTIME IMAGE ########################
# Create a new stage with a minimal image # Create a new stage with a minimal image
# because we already have a binary built # because we already have a binary built
FROM balenalib/aarch64-alpine:3.15 FROM balenalib/aarch64-alpine:3.16
ENV ROCKET_ENV="staging" \ ENV ROCKET_PROFILE="release" \
ROCKET_ADDRESS=0.0.0.0 \
ROCKET_PORT=80 \ ROCKET_PORT=80 \
ROCKET_WORKERS=10 \
SSL_CERT_DIR=/etc/ssl/certs SSL_CERT_DIR=/etc/ssl/certs
@ -98,7 +98,6 @@ RUN mkdir /data \
openssl \ openssl \
tzdata \ tzdata \
curl \ curl \
dumb-init \
ca-certificates ca-certificates
# hadolint ignore=DL3059 # hadolint ignore=DL3059
@ -111,7 +110,6 @@ EXPOSE 3012
# Copies the files from the context (Rocket.toml file and web-vault) # Copies the files from the context (Rocket.toml file and web-vault)
# and the binary from the "build" stage to the current stage # and the binary from the "build" stage to the current stage
WORKDIR / WORKDIR /
COPY Rocket.toml .
COPY --from=vault /web-vault ./web-vault COPY --from=vault /web-vault ./web-vault
COPY --from=build /app/target/aarch64-unknown-linux-musl/release/vaultwarden . COPY --from=build /app/target/aarch64-unknown-linux-musl/release/vaultwarden .
@ -120,6 +118,4 @@ COPY docker/start.sh /start.sh
HEALTHCHECK --interval=60s --timeout=10s CMD ["/healthcheck.sh"] HEALTHCHECK --interval=60s --timeout=10s CMD ["/healthcheck.sh"]
# Configures the startup!
ENTRYPOINT ["/usr/bin/dumb-init", "--"]
CMD ["/start.sh"] CMD ["/start.sh"]

26
docker/arm64/Dockerfile.buildx

@ -16,18 +16,18 @@
# - From https://hub.docker.com/r/vaultwarden/web-vault/tags, # - From https://hub.docker.com/r/vaultwarden/web-vault/tags,
# click the tag name to view the digest of the image it currently points to. # click the tag name to view the digest of the image it currently points to.
# - From the command line: # - From the command line:
# $ docker pull vaultwarden/web-vault:v2.25.1 # $ docker pull vaultwarden/web-vault:v2022.6.2
# $ docker image inspect --format "{{.RepoDigests}}" vaultwarden/web-vault:v2.25.1 # $ docker image inspect --format "{{.RepoDigests}}" vaultwarden/web-vault:v2022.6.2
# [vaultwarden/web-vault@sha256:4f9b7a6b0eaceb511cca8c6a5ed5aa92f527960b1b33d86fbbfd4e5795943965] # [vaultwarden/web-vault@sha256:1dfda41cbddeac5bc59540261fff8defcac37170b5ba02d29c12fa1215498f70]
# #
# - Conversely, to get the tag name from the digest: # - Conversely, to get the tag name from the digest:
# $ docker image inspect --format "{{.RepoTags}}" vaultwarden/web-vault@sha256:4f9b7a6b0eaceb511cca8c6a5ed5aa92f527960b1b33d86fbbfd4e5795943965 # $ docker image inspect --format "{{.RepoTags}}" vaultwarden/web-vault@sha256:1dfda41cbddeac5bc59540261fff8defcac37170b5ba02d29c12fa1215498f70
# [vaultwarden/web-vault:v2.25.1] # [vaultwarden/web-vault:v2022.6.2]
# #
FROM vaultwarden/web-vault@sha256:4f9b7a6b0eaceb511cca8c6a5ed5aa92f527960b1b33d86fbbfd4e5795943965 as vault FROM vaultwarden/web-vault@sha256:1dfda41cbddeac5bc59540261fff8defcac37170b5ba02d29c12fa1215498f70 as vault
########################## BUILD IMAGE ########################## ########################## BUILD IMAGE ##########################
FROM rust:1.58-buster as build FROM rust:1.61-bullseye as build
@ -107,11 +107,11 @@ RUN --mount=type=cache,target=/root/.cargo/git --mount=type=cache,target=/root/.
######################## RUNTIME IMAGE ######################## ######################## RUNTIME IMAGE ########################
# Create a new stage with a minimal image # Create a new stage with a minimal image
# because we already have a binary built # because we already have a binary built
FROM balenalib/aarch64-debian:buster FROM balenalib/aarch64-debian:bullseye
ENV ROCKET_ENV="staging" \ ENV ROCKET_PROFILE="release" \
ROCKET_PORT=80 \ ROCKET_ADDRESS=0.0.0.0 \
ROCKET_WORKERS=10 ROCKET_PORT=80
# hadolint ignore=DL3059 # hadolint ignore=DL3059
RUN [ "cross-build-start" ] RUN [ "cross-build-start" ]
@ -123,7 +123,6 @@ RUN mkdir /data \
openssl \ openssl \
ca-certificates \ ca-certificates \
curl \ curl \
dumb-init \
libmariadb-dev-compat \ libmariadb-dev-compat \
libpq5 \ libpq5 \
&& apt-get clean \ && apt-get clean \
@ -139,7 +138,6 @@ EXPOSE 3012
# Copies the files from the context (Rocket.toml file and web-vault) # Copies the files from the context (Rocket.toml file and web-vault)
# and the binary from the "build" stage to the current stage # and the binary from the "build" stage to the current stage
WORKDIR / WORKDIR /
COPY Rocket.toml .
COPY --from=vault /web-vault ./web-vault COPY --from=vault /web-vault ./web-vault
COPY --from=build /app/target/aarch64-unknown-linux-gnu/release/vaultwarden . COPY --from=build /app/target/aarch64-unknown-linux-gnu/release/vaultwarden .
@ -148,6 +146,4 @@ COPY docker/start.sh /start.sh
HEALTHCHECK --interval=60s --timeout=10s CMD ["/healthcheck.sh"] HEALTHCHECK --interval=60s --timeout=10s CMD ["/healthcheck.sh"]
# Configures the startup!
ENTRYPOINT ["/usr/bin/dumb-init", "--"]
CMD ["/start.sh"] CMD ["/start.sh"]

28
docker/arm64/Dockerfile.buildx.alpine

@ -16,18 +16,18 @@
# - From https://hub.docker.com/r/vaultwarden/web-vault/tags, # - From https://hub.docker.com/r/vaultwarden/web-vault/tags,
# click the tag name to view the digest of the image it currently points to. # click the tag name to view the digest of the image it currently points to.
# - From the command line: # - From the command line:
# $ docker pull vaultwarden/web-vault:v2.25.1 # $ docker pull vaultwarden/web-vault:v2022.6.2
# $ docker image inspect --format "{{.RepoDigests}}" vaultwarden/web-vault:v2.25.1 # $ docker image inspect --format "{{.RepoDigests}}" vaultwarden/web-vault:v2022.6.2
# [vaultwarden/web-vault@sha256:4f9b7a6b0eaceb511cca8c6a5ed5aa92f527960b1b33d86fbbfd4e5795943965] # [vaultwarden/web-vault@sha256:1dfda41cbddeac5bc59540261fff8defcac37170b5ba02d29c12fa1215498f70]
# #
# - Conversely, to get the tag name from the digest: # - Conversely, to get the tag name from the digest:
# $ docker image inspect --format "{{.RepoTags}}" vaultwarden/web-vault@sha256:4f9b7a6b0eaceb511cca8c6a5ed5aa92f527960b1b33d86fbbfd4e5795943965 # $ docker image inspect --format "{{.RepoTags}}" vaultwarden/web-vault@sha256:1dfda41cbddeac5bc59540261fff8defcac37170b5ba02d29c12fa1215498f70
# [vaultwarden/web-vault:v2.25.1] # [vaultwarden/web-vault:v2022.6.2]
# #
FROM vaultwarden/web-vault@sha256:4f9b7a6b0eaceb511cca8c6a5ed5aa92f527960b1b33d86fbbfd4e5795943965 as vault FROM vaultwarden/web-vault@sha256:1dfda41cbddeac5bc59540261fff8defcac37170b5ba02d29c12fa1215498f70 as vault
########################## BUILD IMAGE ########################## ########################## BUILD IMAGE ##########################
FROM blackdex/rust-musl:aarch64-musl-nightly-2022-01-23 as build FROM blackdex/rust-musl:aarch64-musl-stable-1.61.0 as build
@ -44,7 +44,6 @@ ENV DEBIAN_FRONTEND=noninteractive \
RUN --mount=type=cache,target=/root/.cargo/git --mount=type=cache,target=/root/.cargo/registry mkdir -pv "${CARGO_HOME}" \ RUN --mount=type=cache,target=/root/.cargo/git --mount=type=cache,target=/root/.cargo/registry mkdir -pv "${CARGO_HOME}" \
&& rustup set profile minimal && rustup set profile minimal
ENV RUSTFLAGS='-C link-arg=-s'
# Creates a dummy project used to grab dependencies # Creates a dummy project used to grab dependencies
RUN USER=root cargo new --bin /app RUN USER=root cargo new --bin /app
@ -58,7 +57,8 @@ COPY ./build.rs ./build.rs
RUN --mount=type=cache,target=/root/.cargo/git --mount=type=cache,target=/root/.cargo/registry rustup target add aarch64-unknown-linux-musl RUN --mount=type=cache,target=/root/.cargo/git --mount=type=cache,target=/root/.cargo/registry rustup target add aarch64-unknown-linux-musl
# Configure the DB ARG as late as possible to not invalidate the cached layers above # Configure the DB ARG as late as possible to not invalidate the cached layers above
ARG DB=sqlite,mysql,postgresql # Enable MiMalloc to improve performance on Alpine builds
ARG DB=sqlite,mysql,postgresql,enable_mimalloc
# Builds your dependencies and removes the # Builds your dependencies and removes the
# dummy project, except the target folder # dummy project, except the target folder
@ -81,11 +81,11 @@ RUN --mount=type=cache,target=/root/.cargo/git --mount=type=cache,target=/root/.
######################## RUNTIME IMAGE ######################## ######################## RUNTIME IMAGE ########################
# Create a new stage with a minimal image # Create a new stage with a minimal image
# because we already have a binary built # because we already have a binary built
FROM balenalib/aarch64-alpine:3.15 FROM balenalib/aarch64-alpine:3.16
ENV ROCKET_ENV="staging" \ ENV ROCKET_PROFILE="release" \
ROCKET_ADDRESS=0.0.0.0 \
ROCKET_PORT=80 \ ROCKET_PORT=80 \
ROCKET_WORKERS=10 \
SSL_CERT_DIR=/etc/ssl/certs SSL_CERT_DIR=/etc/ssl/certs
@ -98,7 +98,6 @@ RUN mkdir /data \
openssl \ openssl \
tzdata \ tzdata \
curl \ curl \
dumb-init \
ca-certificates ca-certificates
# hadolint ignore=DL3059 # hadolint ignore=DL3059
@ -111,7 +110,6 @@ EXPOSE 3012
# Copies the files from the context (Rocket.toml file and web-vault) # Copies the files from the context (Rocket.toml file and web-vault)
# and the binary from the "build" stage to the current stage # and the binary from the "build" stage to the current stage
WORKDIR / WORKDIR /
COPY Rocket.toml .
COPY --from=vault /web-vault ./web-vault COPY --from=vault /web-vault ./web-vault
COPY --from=build /app/target/aarch64-unknown-linux-musl/release/vaultwarden . COPY --from=build /app/target/aarch64-unknown-linux-musl/release/vaultwarden .
@ -120,6 +118,4 @@ COPY docker/start.sh /start.sh
HEALTHCHECK --interval=60s --timeout=10s CMD ["/healthcheck.sh"] HEALTHCHECK --interval=60s --timeout=10s CMD ["/healthcheck.sh"]
# Configures the startup!
ENTRYPOINT ["/usr/bin/dumb-init", "--"]
CMD ["/start.sh"] CMD ["/start.sh"]

31
docker/armv6/Dockerfile

@ -16,18 +16,18 @@
# - From https://hub.docker.com/r/vaultwarden/web-vault/tags, # - From https://hub.docker.com/r/vaultwarden/web-vault/tags,
# click the tag name to view the digest of the image it currently points to. # click the tag name to view the digest of the image it currently points to.
# - From the command line: # - From the command line:
# $ docker pull vaultwarden/web-vault:v2.25.1 # $ docker pull vaultwarden/web-vault:v2022.6.2
# $ docker image inspect --format "{{.RepoDigests}}" vaultwarden/web-vault:v2.25.1 # $ docker image inspect --format "{{.RepoDigests}}" vaultwarden/web-vault:v2022.6.2
# [vaultwarden/web-vault@sha256:4f9b7a6b0eaceb511cca8c6a5ed5aa92f527960b1b33d86fbbfd4e5795943965] # [vaultwarden/web-vault@sha256:1dfda41cbddeac5bc59540261fff8defcac37170b5ba02d29c12fa1215498f70]
# #
# - Conversely, to get the tag name from the digest: # - Conversely, to get the tag name from the digest:
# $ docker image inspect --format "{{.RepoTags}}" vaultwarden/web-vault@sha256:4f9b7a6b0eaceb511cca8c6a5ed5aa92f527960b1b33d86fbbfd4e5795943965 # $ docker image inspect --format "{{.RepoTags}}" vaultwarden/web-vault@sha256:1dfda41cbddeac5bc59540261fff8defcac37170b5ba02d29c12fa1215498f70
# [vaultwarden/web-vault:v2.25.1] # [vaultwarden/web-vault:v2022.6.2]
# #
FROM vaultwarden/web-vault@sha256:4f9b7a6b0eaceb511cca8c6a5ed5aa92f527960b1b33d86fbbfd4e5795943965 as vault FROM vaultwarden/web-vault@sha256:1dfda41cbddeac5bc59540261fff8defcac37170b5ba02d29c12fa1215498f70 as vault
########################## BUILD IMAGE ########################## ########################## BUILD IMAGE ##########################
FROM rust:1.58-buster as build FROM rust:1.61-bullseye as build
@ -107,11 +107,11 @@ RUN cargo build --features ${DB} --release --target=arm-unknown-linux-gnueabi
######################## RUNTIME IMAGE ######################## ######################## RUNTIME IMAGE ########################
# Create a new stage with a minimal image # Create a new stage with a minimal image
# because we already have a binary built # because we already have a binary built
FROM balenalib/rpi-debian:buster FROM balenalib/rpi-debian:bullseye
ENV ROCKET_ENV="staging" \ ENV ROCKET_PROFILE="release" \
ROCKET_PORT=80 \ ROCKET_ADDRESS=0.0.0.0 \
ROCKET_WORKERS=10 ROCKET_PORT=80
# hadolint ignore=DL3059 # hadolint ignore=DL3059
RUN [ "cross-build-start" ] RUN [ "cross-build-start" ]
@ -123,12 +123,16 @@ RUN mkdir /data \
openssl \ openssl \
ca-certificates \ ca-certificates \
curl \ curl \
dumb-init \
libmariadb-dev-compat \ libmariadb-dev-compat \
libpq5 \ libpq5 \
&& apt-get clean \ && apt-get clean \
&& rm -rf /var/lib/apt/lists/* && rm -rf /var/lib/apt/lists/*
# In the Balena Bullseye images for armv6/rpi-debian there is a missing symlink.
# This symlink was there in the buster images, and for some reason this is needed.
# hadolint ignore=DL3059
RUN ln -v -s /lib/ld-linux-armhf.so.3 /lib/ld-linux.so.3
# hadolint ignore=DL3059 # hadolint ignore=DL3059
RUN [ "cross-build-end" ] RUN [ "cross-build-end" ]
@ -139,7 +143,6 @@ EXPOSE 3012
# Copies the files from the context (Rocket.toml file and web-vault) # Copies the files from the context (Rocket.toml file and web-vault)
# and the binary from the "build" stage to the current stage # and the binary from the "build" stage to the current stage
WORKDIR / WORKDIR /
COPY Rocket.toml .
COPY --from=vault /web-vault ./web-vault COPY --from=vault /web-vault ./web-vault
COPY --from=build /app/target/arm-unknown-linux-gnueabi/release/vaultwarden . COPY --from=build /app/target/arm-unknown-linux-gnueabi/release/vaultwarden .
@ -148,6 +151,4 @@ COPY docker/start.sh /start.sh
HEALTHCHECK --interval=60s --timeout=10s CMD ["/healthcheck.sh"] HEALTHCHECK --interval=60s --timeout=10s CMD ["/healthcheck.sh"]
# Configures the startup!
ENTRYPOINT ["/usr/bin/dumb-init", "--"]
CMD ["/start.sh"] CMD ["/start.sh"]

30
docker/armv6/Dockerfile.alpine

@ -16,18 +16,18 @@
# - From https://hub.docker.com/r/vaultwarden/web-vault/tags, # - From https://hub.docker.com/r/vaultwarden/web-vault/tags,
# click the tag name to view the digest of the image it currently points to. # click the tag name to view the digest of the image it currently points to.
# - From the command line: # - From the command line:
# $ docker pull vaultwarden/web-vault:v2.25.1 # $ docker pull vaultwarden/web-vault:v2022.6.2
# $ docker image inspect --format "{{.RepoDigests}}" vaultwarden/web-vault:v2.25.1 # $ docker image inspect --format "{{.RepoDigests}}" vaultwarden/web-vault:v2022.6.2
# [vaultwarden/web-vault@sha256:4f9b7a6b0eaceb511cca8c6a5ed5aa92f527960b1b33d86fbbfd4e5795943965] # [vaultwarden/web-vault@sha256:1dfda41cbddeac5bc59540261fff8defcac37170b5ba02d29c12fa1215498f70]
# #
# - Conversely, to get the tag name from the digest: # - Conversely, to get the tag name from the digest:
# $ docker image inspect --format "{{.RepoTags}}" vaultwarden/web-vault@sha256:4f9b7a6b0eaceb511cca8c6a5ed5aa92f527960b1b33d86fbbfd4e5795943965 # $ docker image inspect --format "{{.RepoTags}}" vaultwarden/web-vault@sha256:1dfda41cbddeac5bc59540261fff8defcac37170b5ba02d29c12fa1215498f70
# [vaultwarden/web-vault:v2.25.1] # [vaultwarden/web-vault:v2022.6.2]
# #
FROM vaultwarden/web-vault@sha256:4f9b7a6b0eaceb511cca8c6a5ed5aa92f527960b1b33d86fbbfd4e5795943965 as vault FROM vaultwarden/web-vault@sha256:1dfda41cbddeac5bc59540261fff8defcac37170b5ba02d29c12fa1215498f70 as vault
########################## BUILD IMAGE ########################## ########################## BUILD IMAGE ##########################
FROM blackdex/rust-musl:arm-musleabi-nightly-2022-01-23 as build FROM blackdex/rust-musl:arm-musleabi-stable-1.61.0 as build
@ -44,7 +44,8 @@ ENV DEBIAN_FRONTEND=noninteractive \
RUN mkdir -pv "${CARGO_HOME}" \ RUN mkdir -pv "${CARGO_HOME}" \
&& rustup set profile minimal && rustup set profile minimal
ENV RUSTFLAGS='-C link-arg=-s' # To be able to build the armv6 image with mimalloc we need to specifically specify the libatomic.a file location
ENV RUSTFLAGS='-Clink-arg=/usr/local/musl/arm-unknown-linux-musleabi/lib/libatomic.a'
# Creates a dummy project used to grab dependencies # Creates a dummy project used to grab dependencies
RUN USER=root cargo new --bin /app RUN USER=root cargo new --bin /app
@ -58,7 +59,8 @@ COPY ./build.rs ./build.rs
RUN rustup target add arm-unknown-linux-musleabi RUN rustup target add arm-unknown-linux-musleabi
# Configure the DB ARG as late as possible to not invalidate the cached layers above # Configure the DB ARG as late as possible to not invalidate the cached layers above
ARG DB=sqlite,mysql,postgresql # Enable MiMalloc to improve performance on Alpine builds
ARG DB=sqlite,mysql,postgresql,enable_mimalloc
# Builds your dependencies and removes the # Builds your dependencies and removes the
# dummy project, except the target folder # dummy project, except the target folder
@ -81,11 +83,11 @@ RUN cargo build --features ${DB} --release --target=arm-unknown-linux-musleabi
######################## RUNTIME IMAGE ######################## ######################## RUNTIME IMAGE ########################
# Create a new stage with a minimal image # Create a new stage with a minimal image
# because we already have a binary built # because we already have a binary built
FROM balenalib/rpi-alpine:3.15 FROM balenalib/rpi-alpine:3.16
ENV ROCKET_ENV="staging" \ ENV ROCKET_PROFILE="release" \
ROCKET_ADDRESS=0.0.0.0 \
ROCKET_PORT=80 \ ROCKET_PORT=80 \
ROCKET_WORKERS=10 \
SSL_CERT_DIR=/etc/ssl/certs SSL_CERT_DIR=/etc/ssl/certs
@ -98,7 +100,6 @@ RUN mkdir /data \
openssl \ openssl \
tzdata \ tzdata \
curl \ curl \
dumb-init \
ca-certificates ca-certificates
# hadolint ignore=DL3059 # hadolint ignore=DL3059
@ -111,7 +112,6 @@ EXPOSE 3012
# Copies the files from the context (Rocket.toml file and web-vault) # Copies the files from the context (Rocket.toml file and web-vault)
# and the binary from the "build" stage to the current stage # and the binary from the "build" stage to the current stage
WORKDIR / WORKDIR /
COPY Rocket.toml .
COPY --from=vault /web-vault ./web-vault COPY --from=vault /web-vault ./web-vault
COPY --from=build /app/target/arm-unknown-linux-musleabi/release/vaultwarden . COPY --from=build /app/target/arm-unknown-linux-musleabi/release/vaultwarden .
@ -120,6 +120,4 @@ COPY docker/start.sh /start.sh
HEALTHCHECK --interval=60s --timeout=10s CMD ["/healthcheck.sh"] HEALTHCHECK --interval=60s --timeout=10s CMD ["/healthcheck.sh"]
# Configures the startup!
ENTRYPOINT ["/usr/bin/dumb-init", "--"]
CMD ["/start.sh"] CMD ["/start.sh"]

31
docker/armv6/Dockerfile.buildx

@ -16,18 +16,18 @@
# - From https://hub.docker.com/r/vaultwarden/web-vault/tags, # - From https://hub.docker.com/r/vaultwarden/web-vault/tags,
# click the tag name to view the digest of the image it currently points to. # click the tag name to view the digest of the image it currently points to.
# - From the command line: # - From the command line:
# $ docker pull vaultwarden/web-vault:v2.25.1 # $ docker pull vaultwarden/web-vault:v2022.6.2
# $ docker image inspect --format "{{.RepoDigests}}" vaultwarden/web-vault:v2.25.1 # $ docker image inspect --format "{{.RepoDigests}}" vaultwarden/web-vault:v2022.6.2
# [vaultwarden/web-vault@sha256:4f9b7a6b0eaceb511cca8c6a5ed5aa92f527960b1b33d86fbbfd4e5795943965] # [vaultwarden/web-vault@sha256:1dfda41cbddeac5bc59540261fff8defcac37170b5ba02d29c12fa1215498f70]
# #
# - Conversely, to get the tag name from the digest: # - Conversely, to get the tag name from the digest:
# $ docker image inspect --format "{{.RepoTags}}" vaultwarden/web-vault@sha256:4f9b7a6b0eaceb511cca8c6a5ed5aa92f527960b1b33d86fbbfd4e5795943965 # $ docker image inspect --format "{{.RepoTags}}" vaultwarden/web-vault@sha256:1dfda41cbddeac5bc59540261fff8defcac37170b5ba02d29c12fa1215498f70
# [vaultwarden/web-vault:v2.25.1] # [vaultwarden/web-vault:v2022.6.2]
# #
FROM vaultwarden/web-vault@sha256:4f9b7a6b0eaceb511cca8c6a5ed5aa92f527960b1b33d86fbbfd4e5795943965 as vault FROM vaultwarden/web-vault@sha256:1dfda41cbddeac5bc59540261fff8defcac37170b5ba02d29c12fa1215498f70 as vault
########################## BUILD IMAGE ########################## ########################## BUILD IMAGE ##########################
FROM rust:1.58-buster as build FROM rust:1.61-bullseye as build
@ -107,11 +107,11 @@ RUN --mount=type=cache,target=/root/.cargo/git --mount=type=cache,target=/root/.
######################## RUNTIME IMAGE ######################## ######################## RUNTIME IMAGE ########################
# Create a new stage with a minimal image # Create a new stage with a minimal image
# because we already have a binary built # because we already have a binary built
FROM balenalib/rpi-debian:buster FROM balenalib/rpi-debian:bullseye
ENV ROCKET_ENV="staging" \ ENV ROCKET_PROFILE="release" \
ROCKET_PORT=80 \ ROCKET_ADDRESS=0.0.0.0 \
ROCKET_WORKERS=10 ROCKET_PORT=80
# hadolint ignore=DL3059 # hadolint ignore=DL3059
RUN [ "cross-build-start" ] RUN [ "cross-build-start" ]
@ -123,12 +123,16 @@ RUN mkdir /data \
openssl \ openssl \
ca-certificates \ ca-certificates \
curl \ curl \
dumb-init \
libmariadb-dev-compat \ libmariadb-dev-compat \
libpq5 \ libpq5 \
&& apt-get clean \ && apt-get clean \
&& rm -rf /var/lib/apt/lists/* && rm -rf /var/lib/apt/lists/*
# In the Balena Bullseye images for armv6/rpi-debian there is a missing symlink.
# This symlink was there in the buster images, and for some reason this is needed.
# hadolint ignore=DL3059
RUN ln -v -s /lib/ld-linux-armhf.so.3 /lib/ld-linux.so.3
# hadolint ignore=DL3059 # hadolint ignore=DL3059
RUN [ "cross-build-end" ] RUN [ "cross-build-end" ]
@ -139,7 +143,6 @@ EXPOSE 3012
# Copies the files from the context (Rocket.toml file and web-vault) # Copies the files from the context (Rocket.toml file and web-vault)
# and the binary from the "build" stage to the current stage # and the binary from the "build" stage to the current stage
WORKDIR / WORKDIR /
COPY Rocket.toml .
COPY --from=vault /web-vault ./web-vault COPY --from=vault /web-vault ./web-vault
COPY --from=build /app/target/arm-unknown-linux-gnueabi/release/vaultwarden . COPY --from=build /app/target/arm-unknown-linux-gnueabi/release/vaultwarden .
@ -148,6 +151,4 @@ COPY docker/start.sh /start.sh
HEALTHCHECK --interval=60s --timeout=10s CMD ["/healthcheck.sh"] HEALTHCHECK --interval=60s --timeout=10s CMD ["/healthcheck.sh"]
# Configures the startup!
ENTRYPOINT ["/usr/bin/dumb-init", "--"]
CMD ["/start.sh"] CMD ["/start.sh"]

30
docker/armv6/Dockerfile.buildx.alpine

@ -16,18 +16,18 @@
# - From https://hub.docker.com/r/vaultwarden/web-vault/tags, # - From https://hub.docker.com/r/vaultwarden/web-vault/tags,
# click the tag name to view the digest of the image it currently points to. # click the tag name to view the digest of the image it currently points to.
# - From the command line: # - From the command line:
# $ docker pull vaultwarden/web-vault:v2.25.1 # $ docker pull vaultwarden/web-vault:v2022.6.2
# $ docker image inspect --format "{{.RepoDigests}}" vaultwarden/web-vault:v2.25.1 # $ docker image inspect --format "{{.RepoDigests}}" vaultwarden/web-vault:v2022.6.2
# [vaultwarden/web-vault@sha256:4f9b7a6b0eaceb511cca8c6a5ed5aa92f527960b1b33d86fbbfd4e5795943965] # [vaultwarden/web-vault@sha256:1dfda41cbddeac5bc59540261fff8defcac37170b5ba02d29c12fa1215498f70]
# #
# - Conversely, to get the tag name from the digest: # - Conversely, to get the tag name from the digest:
# $ docker image inspect --format "{{.RepoTags}}" vaultwarden/web-vault@sha256:4f9b7a6b0eaceb511cca8c6a5ed5aa92f527960b1b33d86fbbfd4e5795943965 # $ docker image inspect --format "{{.RepoTags}}" vaultwarden/web-vault@sha256:1dfda41cbddeac5bc59540261fff8defcac37170b5ba02d29c12fa1215498f70
# [vaultwarden/web-vault:v2.25.1] # [vaultwarden/web-vault:v2022.6.2]
# #
FROM vaultwarden/web-vault@sha256:4f9b7a6b0eaceb511cca8c6a5ed5aa92f527960b1b33d86fbbfd4e5795943965 as vault FROM vaultwarden/web-vault@sha256:1dfda41cbddeac5bc59540261fff8defcac37170b5ba02d29c12fa1215498f70 as vault
########################## BUILD IMAGE ########################## ########################## BUILD IMAGE ##########################
FROM blackdex/rust-musl:arm-musleabi-nightly-2022-01-23 as build FROM blackdex/rust-musl:arm-musleabi-stable-1.61.0 as build
@ -44,7 +44,8 @@ ENV DEBIAN_FRONTEND=noninteractive \
RUN --mount=type=cache,target=/root/.cargo/git --mount=type=cache,target=/root/.cargo/registry mkdir -pv "${CARGO_HOME}" \ RUN --mount=type=cache,target=/root/.cargo/git --mount=type=cache,target=/root/.cargo/registry mkdir -pv "${CARGO_HOME}" \
&& rustup set profile minimal && rustup set profile minimal
ENV RUSTFLAGS='-C link-arg=-s' # To be able to build the armv6 image with mimalloc we need to specifically specify the libatomic.a file location
ENV RUSTFLAGS='-Clink-arg=/usr/local/musl/arm-unknown-linux-musleabi/lib/libatomic.a'
# Creates a dummy project used to grab dependencies # Creates a dummy project used to grab dependencies
RUN USER=root cargo new --bin /app RUN USER=root cargo new --bin /app
@ -58,7 +59,8 @@ COPY ./build.rs ./build.rs
RUN --mount=type=cache,target=/root/.cargo/git --mount=type=cache,target=/root/.cargo/registry rustup target add arm-unknown-linux-musleabi RUN --mount=type=cache,target=/root/.cargo/git --mount=type=cache,target=/root/.cargo/registry rustup target add arm-unknown-linux-musleabi
# Configure the DB ARG as late as possible to not invalidate the cached layers above # Configure the DB ARG as late as possible to not invalidate the cached layers above
ARG DB=sqlite,mysql,postgresql # Enable MiMalloc to improve performance on Alpine builds
ARG DB=sqlite,mysql,postgresql,enable_mimalloc
# Builds your dependencies and removes the # Builds your dependencies and removes the
# dummy project, except the target folder # dummy project, except the target folder
@ -81,11 +83,11 @@ RUN --mount=type=cache,target=/root/.cargo/git --mount=type=cache,target=/root/.
######################## RUNTIME IMAGE ######################## ######################## RUNTIME IMAGE ########################
# Create a new stage with a minimal image # Create a new stage with a minimal image
# because we already have a binary built # because we already have a binary built
FROM balenalib/rpi-alpine:3.15 FROM balenalib/rpi-alpine:3.16
ENV ROCKET_ENV="staging" \ ENV ROCKET_PROFILE="release" \
ROCKET_ADDRESS=0.0.0.0 \
ROCKET_PORT=80 \ ROCKET_PORT=80 \
ROCKET_WORKERS=10 \
SSL_CERT_DIR=/etc/ssl/certs SSL_CERT_DIR=/etc/ssl/certs
@ -98,7 +100,6 @@ RUN mkdir /data \
openssl \ openssl \
tzdata \ tzdata \
curl \ curl \
dumb-init \
ca-certificates ca-certificates
# hadolint ignore=DL3059 # hadolint ignore=DL3059
@ -111,7 +112,6 @@ EXPOSE 3012
# Copies the files from the context (Rocket.toml file and web-vault) # Copies the files from the context (Rocket.toml file and web-vault)
# and the binary from the "build" stage to the current stage # and the binary from the "build" stage to the current stage
WORKDIR / WORKDIR /
COPY Rocket.toml .
COPY --from=vault /web-vault ./web-vault COPY --from=vault /web-vault ./web-vault
COPY --from=build /app/target/arm-unknown-linux-musleabi/release/vaultwarden . COPY --from=build /app/target/arm-unknown-linux-musleabi/release/vaultwarden .
@ -120,6 +120,4 @@ COPY docker/start.sh /start.sh
HEALTHCHECK --interval=60s --timeout=10s CMD ["/healthcheck.sh"] HEALTHCHECK --interval=60s --timeout=10s CMD ["/healthcheck.sh"]
# Configures the startup!
ENTRYPOINT ["/usr/bin/dumb-init", "--"]
CMD ["/start.sh"] CMD ["/start.sh"]

26
docker/armv7/Dockerfile

@ -16,18 +16,18 @@
# - From https://hub.docker.com/r/vaultwarden/web-vault/tags, # - From https://hub.docker.com/r/vaultwarden/web-vault/tags,
# click the tag name to view the digest of the image it currently points to. # click the tag name to view the digest of the image it currently points to.
# - From the command line: # - From the command line:
# $ docker pull vaultwarden/web-vault:v2.25.1 # $ docker pull vaultwarden/web-vault:v2022.6.2
# $ docker image inspect --format "{{.RepoDigests}}" vaultwarden/web-vault:v2.25.1 # $ docker image inspect --format "{{.RepoDigests}}" vaultwarden/web-vault:v2022.6.2
# [vaultwarden/web-vault@sha256:4f9b7a6b0eaceb511cca8c6a5ed5aa92f527960b1b33d86fbbfd4e5795943965] # [vaultwarden/web-vault@sha256:1dfda41cbddeac5bc59540261fff8defcac37170b5ba02d29c12fa1215498f70]
# #
# - Conversely, to get the tag name from the digest: # - Conversely, to get the tag name from the digest:
# $ docker image inspect --format "{{.RepoTags}}" vaultwarden/web-vault@sha256:4f9b7a6b0eaceb511cca8c6a5ed5aa92f527960b1b33d86fbbfd4e5795943965 # $ docker image inspect --format "{{.RepoTags}}" vaultwarden/web-vault@sha256:1dfda41cbddeac5bc59540261fff8defcac37170b5ba02d29c12fa1215498f70
# [vaultwarden/web-vault:v2.25.1] # [vaultwarden/web-vault:v2022.6.2]
# #
FROM vaultwarden/web-vault@sha256:4f9b7a6b0eaceb511cca8c6a5ed5aa92f527960b1b33d86fbbfd4e5795943965 as vault FROM vaultwarden/web-vault@sha256:1dfda41cbddeac5bc59540261fff8defcac37170b5ba02d29c12fa1215498f70 as vault
########################## BUILD IMAGE ########################## ########################## BUILD IMAGE ##########################
FROM rust:1.58-buster as build FROM rust:1.61-bullseye as build
@ -107,11 +107,11 @@ RUN cargo build --features ${DB} --release --target=armv7-unknown-linux-gnueabih
######################## RUNTIME IMAGE ######################## ######################## RUNTIME IMAGE ########################
# Create a new stage with a minimal image # Create a new stage with a minimal image
# because we already have a binary built # because we already have a binary built
FROM balenalib/armv7hf-debian:buster FROM balenalib/armv7hf-debian:bullseye
ENV ROCKET_ENV="staging" \ ENV ROCKET_PROFILE="release" \
ROCKET_PORT=80 \ ROCKET_ADDRESS=0.0.0.0 \
ROCKET_WORKERS=10 ROCKET_PORT=80
# hadolint ignore=DL3059 # hadolint ignore=DL3059
RUN [ "cross-build-start" ] RUN [ "cross-build-start" ]
@ -123,7 +123,6 @@ RUN mkdir /data \
openssl \ openssl \
ca-certificates \ ca-certificates \
curl \ curl \
dumb-init \
libmariadb-dev-compat \ libmariadb-dev-compat \
libpq5 \ libpq5 \
&& apt-get clean \ && apt-get clean \
@ -139,7 +138,6 @@ EXPOSE 3012
# Copies the files from the context (Rocket.toml file and web-vault) # Copies the files from the context (Rocket.toml file and web-vault)
# and the binary from the "build" stage to the current stage # and the binary from the "build" stage to the current stage
WORKDIR / WORKDIR /
COPY Rocket.toml .
COPY --from=vault /web-vault ./web-vault COPY --from=vault /web-vault ./web-vault
COPY --from=build /app/target/armv7-unknown-linux-gnueabihf/release/vaultwarden . COPY --from=build /app/target/armv7-unknown-linux-gnueabihf/release/vaultwarden .
@ -148,6 +146,4 @@ COPY docker/start.sh /start.sh
HEALTHCHECK --interval=60s --timeout=10s CMD ["/healthcheck.sh"] HEALTHCHECK --interval=60s --timeout=10s CMD ["/healthcheck.sh"]
# Configures the startup!
ENTRYPOINT ["/usr/bin/dumb-init", "--"]
CMD ["/start.sh"] CMD ["/start.sh"]

31
docker/armv7/Dockerfile.alpine

@ -16,18 +16,18 @@
# - From https://hub.docker.com/r/vaultwarden/web-vault/tags, # - From https://hub.docker.com/r/vaultwarden/web-vault/tags,
# click the tag name to view the digest of the image it currently points to. # click the tag name to view the digest of the image it currently points to.
# - From the command line: # - From the command line:
# $ docker pull vaultwarden/web-vault:v2.25.1 # $ docker pull vaultwarden/web-vault:v2022.6.2
# $ docker image inspect --format "{{.RepoDigests}}" vaultwarden/web-vault:v2.25.1 # $ docker image inspect --format "{{.RepoDigests}}" vaultwarden/web-vault:v2022.6.2
# [vaultwarden/web-vault@sha256:4f9b7a6b0eaceb511cca8c6a5ed5aa92f527960b1b33d86fbbfd4e5795943965] # [vaultwarden/web-vault@sha256:1dfda41cbddeac5bc59540261fff8defcac37170b5ba02d29c12fa1215498f70]
# #
# - Conversely, to get the tag name from the digest: # - Conversely, to get the tag name from the digest:
# $ docker image inspect --format "{{.RepoTags}}" vaultwarden/web-vault@sha256:4f9b7a6b0eaceb511cca8c6a5ed5aa92f527960b1b33d86fbbfd4e5795943965 # $ docker image inspect --format "{{.RepoTags}}" vaultwarden/web-vault@sha256:1dfda41cbddeac5bc59540261fff8defcac37170b5ba02d29c12fa1215498f70
# [vaultwarden/web-vault:v2.25.1] # [vaultwarden/web-vault:v2022.6.2]
# #
FROM vaultwarden/web-vault@sha256:4f9b7a6b0eaceb511cca8c6a5ed5aa92f527960b1b33d86fbbfd4e5795943965 as vault FROM vaultwarden/web-vault@sha256:1dfda41cbddeac5bc59540261fff8defcac37170b5ba02d29c12fa1215498f70 as vault
########################## BUILD IMAGE ########################## ########################## BUILD IMAGE ##########################
FROM blackdex/rust-musl:armv7-musleabihf-nightly-2022-01-23 as build FROM blackdex/rust-musl:armv7-musleabihf-stable-1.61.0 as build
@ -44,8 +44,6 @@ ENV DEBIAN_FRONTEND=noninteractive \
RUN mkdir -pv "${CARGO_HOME}" \ RUN mkdir -pv "${CARGO_HOME}" \
&& rustup set profile minimal && rustup set profile minimal
ENV RUSTFLAGS='-C link-arg=-s'
ENV CFLAGS_armv7_unknown_linux_musleabihf="-mfpu=vfpv3-d16"
# Creates a dummy project used to grab dependencies # Creates a dummy project used to grab dependencies
RUN USER=root cargo new --bin /app RUN USER=root cargo new --bin /app
@ -59,7 +57,8 @@ COPY ./build.rs ./build.rs
RUN rustup target add armv7-unknown-linux-musleabihf RUN rustup target add armv7-unknown-linux-musleabihf
# Configure the DB ARG as late as possible to not invalidate the cached layers above # Configure the DB ARG as late as possible to not invalidate the cached layers above
ARG DB=sqlite,mysql,postgresql # Enable MiMalloc to improve performance on Alpine builds
ARG DB=sqlite,mysql,postgresql,enable_mimalloc
# Builds your dependencies and removes the # Builds your dependencies and removes the
# dummy project, except the target folder # dummy project, except the target folder
@ -78,17 +77,15 @@ RUN touch src/main.rs
# your actual source files being built # your actual source files being built
# hadolint ignore=DL3059 # hadolint ignore=DL3059
RUN cargo build --features ${DB} --release --target=armv7-unknown-linux-musleabihf RUN cargo build --features ${DB} --release --target=armv7-unknown-linux-musleabihf
# hadolint ignore=DL3059
RUN musl-strip target/armv7-unknown-linux-musleabihf/release/vaultwarden
######################## RUNTIME IMAGE ######################## ######################## RUNTIME IMAGE ########################
# Create a new stage with a minimal image # Create a new stage with a minimal image
# because we already have a binary built # because we already have a binary built
FROM balenalib/armv7hf-alpine:3.15 FROM balenalib/armv7hf-alpine:3.16
ENV ROCKET_ENV="staging" \ ENV ROCKET_PROFILE="release" \
ROCKET_ADDRESS=0.0.0.0 \
ROCKET_PORT=80 \ ROCKET_PORT=80 \
ROCKET_WORKERS=10 \
SSL_CERT_DIR=/etc/ssl/certs SSL_CERT_DIR=/etc/ssl/certs
@ -101,7 +98,6 @@ RUN mkdir /data \
openssl \ openssl \
tzdata \ tzdata \
curl \ curl \
dumb-init \
ca-certificates ca-certificates
# hadolint ignore=DL3059 # hadolint ignore=DL3059
@ -114,7 +110,6 @@ EXPOSE 3012
# Copies the files from the context (Rocket.toml file and web-vault) # Copies the files from the context (Rocket.toml file and web-vault)
# and the binary from the "build" stage to the current stage # and the binary from the "build" stage to the current stage
WORKDIR / WORKDIR /
COPY Rocket.toml .
COPY --from=vault /web-vault ./web-vault COPY --from=vault /web-vault ./web-vault
COPY --from=build /app/target/armv7-unknown-linux-musleabihf/release/vaultwarden . COPY --from=build /app/target/armv7-unknown-linux-musleabihf/release/vaultwarden .
@ -123,6 +118,4 @@ COPY docker/start.sh /start.sh
HEALTHCHECK --interval=60s --timeout=10s CMD ["/healthcheck.sh"] HEALTHCHECK --interval=60s --timeout=10s CMD ["/healthcheck.sh"]
# Configures the startup!
ENTRYPOINT ["/usr/bin/dumb-init", "--"]
CMD ["/start.sh"] CMD ["/start.sh"]

26
docker/armv7/Dockerfile.buildx

@ -16,18 +16,18 @@
# - From https://hub.docker.com/r/vaultwarden/web-vault/tags, # - From https://hub.docker.com/r/vaultwarden/web-vault/tags,
# click the tag name to view the digest of the image it currently points to. # click the tag name to view the digest of the image it currently points to.
# - From the command line: # - From the command line:
# $ docker pull vaultwarden/web-vault:v2.25.1 # $ docker pull vaultwarden/web-vault:v2022.6.2
# $ docker image inspect --format "{{.RepoDigests}}" vaultwarden/web-vault:v2.25.1 # $ docker image inspect --format "{{.RepoDigests}}" vaultwarden/web-vault:v2022.6.2
# [vaultwarden/web-vault@sha256:4f9b7a6b0eaceb511cca8c6a5ed5aa92f527960b1b33d86fbbfd4e5795943965] # [vaultwarden/web-vault@sha256:1dfda41cbddeac5bc59540261fff8defcac37170b5ba02d29c12fa1215498f70]
# #
# - Conversely, to get the tag name from the digest: # - Conversely, to get the tag name from the digest:
# $ docker image inspect --format "{{.RepoTags}}" vaultwarden/web-vault@sha256:4f9b7a6b0eaceb511cca8c6a5ed5aa92f527960b1b33d86fbbfd4e5795943965 # $ docker image inspect --format "{{.RepoTags}}" vaultwarden/web-vault@sha256:1dfda41cbddeac5bc59540261fff8defcac37170b5ba02d29c12fa1215498f70
# [vaultwarden/web-vault:v2.25.1] # [vaultwarden/web-vault:v2022.6.2]
# #
FROM vaultwarden/web-vault@sha256:4f9b7a6b0eaceb511cca8c6a5ed5aa92f527960b1b33d86fbbfd4e5795943965 as vault FROM vaultwarden/web-vault@sha256:1dfda41cbddeac5bc59540261fff8defcac37170b5ba02d29c12fa1215498f70 as vault
########################## BUILD IMAGE ########################## ########################## BUILD IMAGE ##########################
FROM rust:1.58-buster as build FROM rust:1.61-bullseye as build
@ -107,11 +107,11 @@ RUN --mount=type=cache,target=/root/.cargo/git --mount=type=cache,target=/root/.
######################## RUNTIME IMAGE ######################## ######################## RUNTIME IMAGE ########################
# Create a new stage with a minimal image # Create a new stage with a minimal image
# because we already have a binary built # because we already have a binary built
FROM balenalib/armv7hf-debian:buster FROM balenalib/armv7hf-debian:bullseye
ENV ROCKET_ENV="staging" \ ENV ROCKET_PROFILE="release" \
ROCKET_PORT=80 \ ROCKET_ADDRESS=0.0.0.0 \
ROCKET_WORKERS=10 ROCKET_PORT=80
# hadolint ignore=DL3059 # hadolint ignore=DL3059
RUN [ "cross-build-start" ] RUN [ "cross-build-start" ]
@ -123,7 +123,6 @@ RUN mkdir /data \
openssl \ openssl \
ca-certificates \ ca-certificates \
curl \ curl \
dumb-init \
libmariadb-dev-compat \ libmariadb-dev-compat \
libpq5 \ libpq5 \
&& apt-get clean \ && apt-get clean \
@ -139,7 +138,6 @@ EXPOSE 3012
# Copies the files from the context (Rocket.toml file and web-vault) # Copies the files from the context (Rocket.toml file and web-vault)
# and the binary from the "build" stage to the current stage # and the binary from the "build" stage to the current stage
WORKDIR / WORKDIR /
COPY Rocket.toml .
COPY --from=vault /web-vault ./web-vault COPY --from=vault /web-vault ./web-vault
COPY --from=build /app/target/armv7-unknown-linux-gnueabihf/release/vaultwarden . COPY --from=build /app/target/armv7-unknown-linux-gnueabihf/release/vaultwarden .
@ -148,6 +146,4 @@ COPY docker/start.sh /start.sh
HEALTHCHECK --interval=60s --timeout=10s CMD ["/healthcheck.sh"] HEALTHCHECK --interval=60s --timeout=10s CMD ["/healthcheck.sh"]
# Configures the startup!
ENTRYPOINT ["/usr/bin/dumb-init", "--"]
CMD ["/start.sh"] CMD ["/start.sh"]

31
docker/armv7/Dockerfile.buildx.alpine

@ -16,18 +16,18 @@
# - From https://hub.docker.com/r/vaultwarden/web-vault/tags, # - From https://hub.docker.com/r/vaultwarden/web-vault/tags,
# click the tag name to view the digest of the image it currently points to. # click the tag name to view the digest of the image it currently points to.
# - From the command line: # - From the command line:
# $ docker pull vaultwarden/web-vault:v2.25.1 # $ docker pull vaultwarden/web-vault:v2022.6.2
# $ docker image inspect --format "{{.RepoDigests}}" vaultwarden/web-vault:v2.25.1 # $ docker image inspect --format "{{.RepoDigests}}" vaultwarden/web-vault:v2022.6.2
# [vaultwarden/web-vault@sha256:4f9b7a6b0eaceb511cca8c6a5ed5aa92f527960b1b33d86fbbfd4e5795943965] # [vaultwarden/web-vault@sha256:1dfda41cbddeac5bc59540261fff8defcac37170b5ba02d29c12fa1215498f70]
# #
# - Conversely, to get the tag name from the digest: # - Conversely, to get the tag name from the digest:
# $ docker image inspect --format "{{.RepoTags}}" vaultwarden/web-vault@sha256:4f9b7a6b0eaceb511cca8c6a5ed5aa92f527960b1b33d86fbbfd4e5795943965 # $ docker image inspect --format "{{.RepoTags}}" vaultwarden/web-vault@sha256:1dfda41cbddeac5bc59540261fff8defcac37170b5ba02d29c12fa1215498f70
# [vaultwarden/web-vault:v2.25.1] # [vaultwarden/web-vault:v2022.6.2]
# #
FROM vaultwarden/web-vault@sha256:4f9b7a6b0eaceb511cca8c6a5ed5aa92f527960b1b33d86fbbfd4e5795943965 as vault FROM vaultwarden/web-vault@sha256:1dfda41cbddeac5bc59540261fff8defcac37170b5ba02d29c12fa1215498f70 as vault
########################## BUILD IMAGE ########################## ########################## BUILD IMAGE ##########################
FROM blackdex/rust-musl:armv7-musleabihf-nightly-2022-01-23 as build FROM blackdex/rust-musl:armv7-musleabihf-stable-1.61.0 as build
@ -44,8 +44,6 @@ ENV DEBIAN_FRONTEND=noninteractive \
RUN --mount=type=cache,target=/root/.cargo/git --mount=type=cache,target=/root/.cargo/registry mkdir -pv "${CARGO_HOME}" \ RUN --mount=type=cache,target=/root/.cargo/git --mount=type=cache,target=/root/.cargo/registry mkdir -pv "${CARGO_HOME}" \
&& rustup set profile minimal && rustup set profile minimal
ENV RUSTFLAGS='-C link-arg=-s'
ENV CFLAGS_armv7_unknown_linux_musleabihf="-mfpu=vfpv3-d16"
# Creates a dummy project used to grab dependencies # Creates a dummy project used to grab dependencies
RUN USER=root cargo new --bin /app RUN USER=root cargo new --bin /app
@ -59,7 +57,8 @@ COPY ./build.rs ./build.rs
RUN --mount=type=cache,target=/root/.cargo/git --mount=type=cache,target=/root/.cargo/registry rustup target add armv7-unknown-linux-musleabihf RUN --mount=type=cache,target=/root/.cargo/git --mount=type=cache,target=/root/.cargo/registry rustup target add armv7-unknown-linux-musleabihf
# Configure the DB ARG as late as possible to not invalidate the cached layers above # Configure the DB ARG as late as possible to not invalidate the cached layers above
ARG DB=sqlite,mysql,postgresql # Enable MiMalloc to improve performance on Alpine builds
ARG DB=sqlite,mysql,postgresql,enable_mimalloc
# Builds your dependencies and removes the # Builds your dependencies and removes the
# dummy project, except the target folder # dummy project, except the target folder
@ -78,17 +77,15 @@ RUN touch src/main.rs
# your actual source files being built # your actual source files being built
# hadolint ignore=DL3059 # hadolint ignore=DL3059
RUN --mount=type=cache,target=/root/.cargo/git --mount=type=cache,target=/root/.cargo/registry cargo build --features ${DB} --release --target=armv7-unknown-linux-musleabihf RUN --mount=type=cache,target=/root/.cargo/git --mount=type=cache,target=/root/.cargo/registry cargo build --features ${DB} --release --target=armv7-unknown-linux-musleabihf
# hadolint ignore=DL3059
RUN musl-strip target/armv7-unknown-linux-musleabihf/release/vaultwarden
######################## RUNTIME IMAGE ######################## ######################## RUNTIME IMAGE ########################
# Create a new stage with a minimal image # Create a new stage with a minimal image
# because we already have a binary built # because we already have a binary built
FROM balenalib/armv7hf-alpine:3.15 FROM balenalib/armv7hf-alpine:3.16
ENV ROCKET_ENV="staging" \ ENV ROCKET_PROFILE="release" \
ROCKET_ADDRESS=0.0.0.0 \
ROCKET_PORT=80 \ ROCKET_PORT=80 \
ROCKET_WORKERS=10 \
SSL_CERT_DIR=/etc/ssl/certs SSL_CERT_DIR=/etc/ssl/certs
@ -101,7 +98,6 @@ RUN mkdir /data \
openssl \ openssl \
tzdata \ tzdata \
curl \ curl \
dumb-init \
ca-certificates ca-certificates
# hadolint ignore=DL3059 # hadolint ignore=DL3059
@ -114,7 +110,6 @@ EXPOSE 3012
# Copies the files from the context (Rocket.toml file and web-vault) # Copies the files from the context (Rocket.toml file and web-vault)
# and the binary from the "build" stage to the current stage # and the binary from the "build" stage to the current stage
WORKDIR / WORKDIR /
COPY Rocket.toml .
COPY --from=vault /web-vault ./web-vault COPY --from=vault /web-vault ./web-vault
COPY --from=build /app/target/armv7-unknown-linux-musleabihf/release/vaultwarden . COPY --from=build /app/target/armv7-unknown-linux-musleabihf/release/vaultwarden .
@ -123,6 +118,4 @@ COPY docker/start.sh /start.sh
HEALTHCHECK --interval=60s --timeout=10s CMD ["/healthcheck.sh"] HEALTHCHECK --interval=60s --timeout=10s CMD ["/healthcheck.sh"]
# Configures the startup!
ENTRYPOINT ["/usr/bin/dumb-init", "--"]
CMD ["/start.sh"] CMD ["/start.sh"]

4
docker/healthcheck.sh

@ -2,8 +2,8 @@
# Use the value of the corresponding env var (if present), # Use the value of the corresponding env var (if present),
# or a default value otherwise. # or a default value otherwise.
: ${DATA_FOLDER:="data"} : "${DATA_FOLDER:="data"}"
: ${ROCKET_PORT:="80"} : "${ROCKET_PORT:="80"}"
CONFIG_FILE="${DATA_FOLDER}"/config.json CONFIG_FILE="${DATA_FOLDER}"/config.json

8
docker/start.sh

@ -9,15 +9,15 @@ fi
if [ -d /etc/vaultwarden.d ]; then if [ -d /etc/vaultwarden.d ]; then
for f in /etc/vaultwarden.d/*.sh; do for f in /etc/vaultwarden.d/*.sh; do
if [ -r $f ]; then if [ -r "${f}" ]; then
. $f . "${f}"
fi fi
done done
elif [ -d /etc/bitwarden_rs.d ]; then elif [ -d /etc/bitwarden_rs.d ]; then
echo "### You are using the old /etc/bitwarden_rs.d script directory, please migrate to /etc/vaultwarden.d ###" echo "### You are using the old /etc/bitwarden_rs.d script directory, please migrate to /etc/vaultwarden.d ###"
for f in /etc/bitwarden_rs.d/*.sh; do for f in /etc/bitwarden_rs.d/*.sh; do
if [ -r $f ]; then if [ -r "${f}" ]; then
. $f . "${f}"
fi fi
done done
fi fi

0
migrations/mysql/2022-03-02-210038_update_devices_primary_key/down.sql

4
migrations/mysql/2022-03-02-210038_update_devices_primary_key/up.sql

@ -0,0 +1,4 @@
-- First remove the previous primary key
ALTER TABLE devices DROP PRIMARY KEY;
-- Add a new combined one
ALTER TABLE devices ADD PRIMARY KEY (uuid, user_uuid);

0
migrations/postgresql/2022-03-02-210038_update_devices_primary_key/down.sql

4
migrations/postgresql/2022-03-02-210038_update_devices_primary_key/up.sql

@ -0,0 +1,4 @@
-- First remove the previous primary key
ALTER TABLE devices DROP CONSTRAINT devices_pkey;
-- Add a new combined one
ALTER TABLE devices ADD PRIMARY KEY (uuid, user_uuid);

0
migrations/sqlite/2022-03-02-210038_update_devices_primary_key/down.sql

23
migrations/sqlite/2022-03-02-210038_update_devices_primary_key/up.sql

@ -0,0 +1,23 @@
-- Create new devices table with primary keys on both uuid and user_uuid
CREATE TABLE devices_new (
uuid TEXT NOT NULL,
created_at DATETIME NOT NULL,
updated_at DATETIME NOT NULL,
user_uuid TEXT NOT NULL,
name TEXT NOT NULL,
atype INTEGER NOT NULL,
push_token TEXT,
refresh_token TEXT NOT NULL,
twofactor_remember TEXT,
PRIMARY KEY(uuid, user_uuid),
FOREIGN KEY(user_uuid) REFERENCES users(uuid)
);
-- Transfer current data to new table
INSERT INTO devices_new SELECT * FROM devices;
-- Drop the old table
DROP TABLE devices;
-- Rename the new table to the original name
ALTER TABLE devices_new RENAME TO devices;

2
rust-toolchain

@ -1 +1 @@
nightly-2022-01-23 1.61.0

8
rustfmt.toml

@ -1,7 +1,7 @@
version = "Two" # version = "Two"
edition = "2018" edition = "2021"
max_width = 120 max_width = 120
newline_style = "Unix" newline_style = "Unix"
use_small_heuristics = "Off" use_small_heuristics = "Off"
struct_lit_single_line = false # struct_lit_single_line = false
overflow_delimited_expr = true # overflow_delimited_expr = true

300
src/api/admin.rs

@ -3,13 +3,14 @@ use serde::de::DeserializeOwned;
use serde_json::Value; use serde_json::Value;
use std::env; use std::env;
use rocket::serde::json::Json;
use rocket::{ use rocket::{
http::{Cookie, Cookies, SameSite, Status}, form::Form,
request::{self, FlashMessage, Form, FromRequest, Outcome, Request}, http::{Cookie, CookieJar, SameSite, Status},
response::{content::Html, Flash, Redirect}, request::{self, FlashMessage, FromRequest, Outcome, Request},
response::{content::RawHtml as Html, Flash, Redirect},
Route, Route,
}; };
use rocket_contrib::json::Json;
use crate::{ use crate::{
api::{ApiResult, EmptyResult, JsonResult, NumberOrString}, api::{ApiResult, EmptyResult, JsonResult, NumberOrString},
@ -24,6 +25,8 @@ use crate::{
CONFIG, VERSION, CONFIG, VERSION,
}; };
use futures::{stream, stream::StreamExt};
pub fn routes() -> Vec<Route> { pub fn routes() -> Vec<Route> {
if !CONFIG.disable_admin_token() && !CONFIG.is_admin_token_set() { if !CONFIG.disable_admin_token() && !CONFIG.is_admin_token_set() {
return routes![admin_disabled]; return routes![admin_disabled];
@ -76,6 +79,7 @@ fn admin_disabled() -> &'static str {
const COOKIE_NAME: &str = "VW_ADMIN"; const COOKIE_NAME: &str = "VW_ADMIN";
const ADMIN_PATH: &str = "/admin"; const ADMIN_PATH: &str = "/admin";
const DT_FMT: &str = "%Y-%m-%d %H:%M:%S %Z";
const BASE_TEMPLATE: &str = "admin/base"; const BASE_TEMPLATE: &str = "admin/base";
@ -85,10 +89,11 @@ fn admin_path() -> String {
struct Referer(Option<String>); struct Referer(Option<String>);
impl<'a, 'r> FromRequest<'a, 'r> for Referer { #[rocket::async_trait]
impl<'r> FromRequest<'r> for Referer {
type Error = (); type Error = ();
fn from_request(request: &'a Request<'r>) -> request::Outcome<Self, Self::Error> { async fn from_request(request: &'r Request<'_>) -> request::Outcome<Self, Self::Error> {
Outcome::Success(Referer(request.headers().get_one("Referer").map(str::to_string))) Outcome::Success(Referer(request.headers().get_one("Referer").map(str::to_string)))
} }
} }
@ -96,10 +101,11 @@ impl<'a, 'r> FromRequest<'a, 'r> for Referer {
#[derive(Debug)] #[derive(Debug)]
struct IpHeader(Option<String>); struct IpHeader(Option<String>);
impl<'a, 'r> FromRequest<'a, 'r> for IpHeader { #[rocket::async_trait]
impl<'r> FromRequest<'r> for IpHeader {
type Error = (); type Error = ();
fn from_request(req: &'a Request<'r>) -> Outcome<Self, Self::Error> { async fn from_request(req: &'r Request<'_>) -> Outcome<Self, Self::Error> {
if req.headers().get_one(&CONFIG.ip_header()).is_some() { if req.headers().get_one(&CONFIG.ip_header()).is_some() {
Outcome::Success(IpHeader(Some(CONFIG.ip_header()))) Outcome::Success(IpHeader(Some(CONFIG.ip_header())))
} else if req.headers().get_one("X-Client-IP").is_some() { } else if req.headers().get_one("X-Client-IP").is_some() {
@ -136,9 +142,9 @@ fn admin_url(referer: Referer) -> String {
} }
#[get("/", rank = 2)] #[get("/", rank = 2)]
fn admin_login(flash: Option<FlashMessage>) -> ApiResult<Html<String>> { fn admin_login(flash: Option<FlashMessage<'_>>) -> ApiResult<Html<String>> {
// If there is an error, show it // If there is an error, show it
let msg = flash.map(|msg| format!("{}: {}", msg.name(), msg.msg())); let msg = flash.map(|msg| format!("{}: {}", msg.kind(), msg.message()));
let json = json!({ let json = json!({
"page_content": "admin/login", "page_content": "admin/login",
"version": VERSION, "version": VERSION,
@ -159,7 +165,7 @@ struct LoginForm {
#[post("/", data = "<data>")] #[post("/", data = "<data>")]
fn post_admin_login( fn post_admin_login(
data: Form<LoginForm>, data: Form<LoginForm>,
mut cookies: Cookies, cookies: &CookieJar<'_>,
ip: ClientIp, ip: ClientIp,
referer: Referer, referer: Referer,
) -> Result<Redirect, Flash<Redirect>> { ) -> Result<Redirect, Flash<Redirect>> {
@ -180,7 +186,7 @@ fn post_admin_login(
let cookie = Cookie::build(COOKIE_NAME, jwt) let cookie = Cookie::build(COOKIE_NAME, jwt)
.path(admin_path()) .path(admin_path())
.max_age(time::Duration::minutes(20)) .max_age(rocket::time::Duration::minutes(20))
.same_site(SameSite::Strict) .same_site(SameSite::Strict)
.http_only(true) .http_only(true)
.finish(); .finish();
@ -250,8 +256,8 @@ struct InviteData {
email: String, email: String,
} }
fn get_user_or_404(uuid: &str, conn: &DbConn) -> ApiResult<User> { async fn get_user_or_404(uuid: &str, conn: &DbConn) -> ApiResult<User> {
if let Some(user) = User::find_by_uuid(uuid, conn) { if let Some(user) = User::find_by_uuid(uuid, conn).await {
Ok(user) Ok(user)
} else { } else {
err_code!("User doesn't exist", Status::NotFound.code); err_code!("User doesn't exist", Status::NotFound.code);
@ -259,128 +265,135 @@ fn get_user_or_404(uuid: &str, conn: &DbConn) -> ApiResult<User> {
} }
#[post("/invite", data = "<data>")] #[post("/invite", data = "<data>")]
fn invite_user(data: Json<InviteData>, _token: AdminToken, conn: DbConn) -> JsonResult { async fn invite_user(data: Json<InviteData>, _token: AdminToken, conn: DbConn) -> JsonResult {
let data: InviteData = data.into_inner(); let data: InviteData = data.into_inner();
let email = data.email.clone(); let email = data.email.clone();
if User::find_by_mail(&data.email, &conn).is_some() { if User::find_by_mail(&data.email, &conn).await.is_some() {
err_code!("User already exists", Status::Conflict.code) err_code!("User already exists", Status::Conflict.code)
} }
let mut user = User::new(email); let mut user = User::new(email);
// TODO: After try_blocks is stabilized, this can be made more readable async fn _generate_invite(user: &User, conn: &DbConn) -> EmptyResult {
// See: https://github.com/rust-lang/rust/issues/31436
(|| {
if CONFIG.mail_enabled() { if CONFIG.mail_enabled() {
mail::send_invite(&user.email, &user.uuid, None, None, &CONFIG.invitation_org_name(), None)?; mail::send_invite(&user.email, &user.uuid, None, None, &CONFIG.invitation_org_name(), None).await
} else { } else {
let invitation = Invitation::new(user.email.clone()); let invitation = Invitation::new(user.email.clone());
invitation.save(&conn)?; invitation.save(conn).await
} }
}
user.save(&conn) _generate_invite(&user, &conn).await.map_err(|e| e.with_code(Status::InternalServerError.code))?;
})() user.save(&conn).await.map_err(|e| e.with_code(Status::InternalServerError.code))?;
.map_err(|e| e.with_code(Status::InternalServerError.code))?;
Ok(Json(user.to_json(&conn))) Ok(Json(user.to_json(&conn).await))
} }
#[post("/test/smtp", data = "<data>")] #[post("/test/smtp", data = "<data>")]
fn test_smtp(data: Json<InviteData>, _token: AdminToken) -> EmptyResult { async fn test_smtp(data: Json<InviteData>, _token: AdminToken) -> EmptyResult {
let data: InviteData = data.into_inner(); let data: InviteData = data.into_inner();
if CONFIG.mail_enabled() { if CONFIG.mail_enabled() {
mail::send_test(&data.email) mail::send_test(&data.email).await
} else { } else {
err!("Mail is not enabled") err!("Mail is not enabled")
} }
} }
#[get("/logout")] #[get("/logout")]
fn logout(mut cookies: Cookies, referer: Referer) -> Redirect { fn logout(cookies: &CookieJar<'_>, referer: Referer) -> Redirect {
cookies.remove(Cookie::named(COOKIE_NAME)); cookies.remove(Cookie::build(COOKIE_NAME, "").path(admin_path()).finish());
Redirect::to(admin_url(referer)) Redirect::to(admin_url(referer))
} }
#[get("/users")] #[get("/users")]
fn get_users_json(_token: AdminToken, conn: DbConn) -> Json<Value> { async fn get_users_json(_token: AdminToken, conn: DbConn) -> Json<Value> {
let users = User::get_all(&conn); let users_json = stream::iter(User::get_all(&conn).await)
let users_json: Vec<Value> = users.iter().map(|u| u.to_json(&conn)).collect(); .then(|u| async {
let u = u; // Move out this single variable
let mut usr = u.to_json(&conn).await;
usr["UserEnabled"] = json!(u.enabled);
usr["CreatedAt"] = json!(format_naive_datetime_local(&u.created_at, DT_FMT));
usr
})
.collect::<Vec<Value>>()
.await;
Json(Value::Array(users_json)) Json(Value::Array(users_json))
} }
#[get("/users/overview")] #[get("/users/overview")]
fn users_overview(_token: AdminToken, conn: DbConn) -> ApiResult<Html<String>> { async fn users_overview(_token: AdminToken, conn: DbConn) -> ApiResult<Html<String>> {
let users = User::get_all(&conn); let users_json = stream::iter(User::get_all(&conn).await)
let dt_fmt = "%Y-%m-%d %H:%M:%S %Z"; .then(|u| async {
let users_json: Vec<Value> = users let u = u; // Move out this single variable
.iter() let mut usr = u.to_json(&conn).await;
.map(|u| { usr["cipher_count"] = json!(Cipher::count_owned_by_user(&u.uuid, &conn).await);
let mut usr = u.to_json(&conn); usr["attachment_count"] = json!(Attachment::count_by_user(&u.uuid, &conn).await);
usr["cipher_count"] = json!(Cipher::count_owned_by_user(&u.uuid, &conn)); usr["attachment_size"] = json!(get_display_size(Attachment::size_by_user(&u.uuid, &conn).await as i32));
usr["attachment_count"] = json!(Attachment::count_by_user(&u.uuid, &conn));
usr["attachment_size"] = json!(get_display_size(Attachment::size_by_user(&u.uuid, &conn) as i32));
usr["user_enabled"] = json!(u.enabled); usr["user_enabled"] = json!(u.enabled);
usr["created_at"] = json!(format_naive_datetime_local(&u.created_at, dt_fmt)); usr["created_at"] = json!(format_naive_datetime_local(&u.created_at, DT_FMT));
usr["last_active"] = match u.last_active(&conn) { usr["last_active"] = match u.last_active(&conn).await {
Some(dt) => json!(format_naive_datetime_local(&dt, dt_fmt)), Some(dt) => json!(format_naive_datetime_local(&dt, DT_FMT)),
None => json!("Never"), None => json!("Never"),
}; };
usr usr
}) })
.collect(); .collect::<Vec<Value>>()
.await;
let text = AdminTemplateData::with_data("admin/users", json!(users_json)).render()?; let text = AdminTemplateData::with_data("admin/users", json!(users_json)).render()?;
Ok(Html(text)) Ok(Html(text))
} }
#[get("/users/<uuid>")] #[get("/users/<uuid>")]
fn get_user_json(uuid: String, _token: AdminToken, conn: DbConn) -> JsonResult { async fn get_user_json(uuid: String, _token: AdminToken, conn: DbConn) -> JsonResult {
let user = get_user_or_404(&uuid, &conn)?; let u = get_user_or_404(&uuid, &conn).await?;
let mut usr = u.to_json(&conn).await;
Ok(Json(user.to_json(&conn))) usr["UserEnabled"] = json!(u.enabled);
usr["CreatedAt"] = json!(format_naive_datetime_local(&u.created_at, DT_FMT));
Ok(Json(usr))
} }
#[post("/users/<uuid>/delete")] #[post("/users/<uuid>/delete")]
fn delete_user(uuid: String, _token: AdminToken, conn: DbConn) -> EmptyResult { async fn delete_user(uuid: String, _token: AdminToken, conn: DbConn) -> EmptyResult {
let user = get_user_or_404(&uuid, &conn)?; let user = get_user_or_404(&uuid, &conn).await?;
user.delete(&conn) user.delete(&conn).await
} }
#[post("/users/<uuid>/deauth")] #[post("/users/<uuid>/deauth")]
fn deauth_user(uuid: String, _token: AdminToken, conn: DbConn) -> EmptyResult { async fn deauth_user(uuid: String, _token: AdminToken, conn: DbConn) -> EmptyResult {
let mut user = get_user_or_404(&uuid, &conn)?; let mut user = get_user_or_404(&uuid, &conn).await?;
Device::delete_all_by_user(&user.uuid, &conn)?; Device::delete_all_by_user(&user.uuid, &conn).await?;
user.reset_security_stamp(); user.reset_security_stamp();
user.save(&conn) user.save(&conn).await
} }
#[post("/users/<uuid>/disable")] #[post("/users/<uuid>/disable")]
fn disable_user(uuid: String, _token: AdminToken, conn: DbConn) -> EmptyResult { async fn disable_user(uuid: String, _token: AdminToken, conn: DbConn) -> EmptyResult {
let mut user = get_user_or_404(&uuid, &conn)?; let mut user = get_user_or_404(&uuid, &conn).await?;
Device::delete_all_by_user(&user.uuid, &conn)?; Device::delete_all_by_user(&user.uuid, &conn).await?;
user.reset_security_stamp(); user.reset_security_stamp();
user.enabled = false; user.enabled = false;
user.save(&conn) user.save(&conn).await
} }
#[post("/users/<uuid>/enable")] #[post("/users/<uuid>/enable")]
fn enable_user(uuid: String, _token: AdminToken, conn: DbConn) -> EmptyResult { async fn enable_user(uuid: String, _token: AdminToken, conn: DbConn) -> EmptyResult {
let mut user = get_user_or_404(&uuid, &conn)?; let mut user = get_user_or_404(&uuid, &conn).await?;
user.enabled = true; user.enabled = true;
user.save(&conn) user.save(&conn).await
} }
#[post("/users/<uuid>/remove-2fa")] #[post("/users/<uuid>/remove-2fa")]
fn remove_2fa(uuid: String, _token: AdminToken, conn: DbConn) -> EmptyResult { async fn remove_2fa(uuid: String, _token: AdminToken, conn: DbConn) -> EmptyResult {
let mut user = get_user_or_404(&uuid, &conn)?; let mut user = get_user_or_404(&uuid, &conn).await?;
TwoFactor::delete_all_by_user(&user.uuid, &conn)?; TwoFactor::delete_all_by_user(&user.uuid, &conn).await?;
user.totp_recover = None; user.totp_recover = None;
user.save(&conn) user.save(&conn).await
} }
#[derive(Deserialize, Debug)] #[derive(Deserialize, Debug)]
@ -391,10 +404,10 @@ struct UserOrgTypeData {
} }
#[post("/users/org_type", data = "<data>")] #[post("/users/org_type", data = "<data>")]
fn update_user_org_type(data: Json<UserOrgTypeData>, _token: AdminToken, conn: DbConn) -> EmptyResult { async fn update_user_org_type(data: Json<UserOrgTypeData>, _token: AdminToken, conn: DbConn) -> EmptyResult {
let data: UserOrgTypeData = data.into_inner(); let data: UserOrgTypeData = data.into_inner();
let mut user_to_edit = match UserOrganization::find_by_user_and_org(&data.user_uuid, &data.org_uuid, &conn) { let mut user_to_edit = match UserOrganization::find_by_user_and_org(&data.user_uuid, &data.org_uuid, &conn).await {
Some(user) => user, Some(user) => user,
None => err!("The specified user isn't member of the organization"), None => err!("The specified user isn't member of the organization"),
}; };
@ -406,45 +419,46 @@ fn update_user_org_type(data: Json<UserOrgTypeData>, _token: AdminToken, conn: D
if user_to_edit.atype == UserOrgType::Owner && new_type != UserOrgType::Owner { if user_to_edit.atype == UserOrgType::Owner && new_type != UserOrgType::Owner {
// Removing owner permmission, check that there are at least another owner // Removing owner permmission, check that there are at least another owner
let num_owners = UserOrganization::find_by_org_and_type(&data.org_uuid, UserOrgType::Owner as i32, &conn).len(); let num_owners =
UserOrganization::find_by_org_and_type(&data.org_uuid, UserOrgType::Owner as i32, &conn).await.len();
if num_owners <= 1 { if num_owners <= 1 {
err!("Can't change the type of the last owner") err!("Can't change the type of the last owner")
} }
} }
user_to_edit.atype = new_type as i32; user_to_edit.atype = new_type;
user_to_edit.save(&conn) user_to_edit.save(&conn).await
} }
#[post("/users/update_revision")] #[post("/users/update_revision")]
fn update_revision_users(_token: AdminToken, conn: DbConn) -> EmptyResult { async fn update_revision_users(_token: AdminToken, conn: DbConn) -> EmptyResult {
User::update_all_revisions(&conn) User::update_all_revisions(&conn).await
} }
#[get("/organizations/overview")] #[get("/organizations/overview")]
fn organizations_overview(_token: AdminToken, conn: DbConn) -> ApiResult<Html<String>> { async fn organizations_overview(_token: AdminToken, conn: DbConn) -> ApiResult<Html<String>> {
let organizations = Organization::get_all(&conn); let organizations_json = stream::iter(Organization::get_all(&conn).await)
let organizations_json: Vec<Value> = organizations .then(|o| async {
.iter() let o = o; //Move out this single variable
.map(|o| {
let mut org = o.to_json(); let mut org = o.to_json();
org["user_count"] = json!(UserOrganization::count_by_org(&o.uuid, &conn)); org["user_count"] = json!(UserOrganization::count_by_org(&o.uuid, &conn).await);
org["cipher_count"] = json!(Cipher::count_by_org(&o.uuid, &conn)); org["cipher_count"] = json!(Cipher::count_by_org(&o.uuid, &conn).await);
org["attachment_count"] = json!(Attachment::count_by_org(&o.uuid, &conn)); org["attachment_count"] = json!(Attachment::count_by_org(&o.uuid, &conn).await);
org["attachment_size"] = json!(get_display_size(Attachment::size_by_org(&o.uuid, &conn) as i32)); org["attachment_size"] = json!(get_display_size(Attachment::size_by_org(&o.uuid, &conn).await as i32));
org org
}) })
.collect(); .collect::<Vec<Value>>()
.await;
let text = AdminTemplateData::with_data("admin/organizations", json!(organizations_json)).render()?; let text = AdminTemplateData::with_data("admin/organizations", json!(organizations_json)).render()?;
Ok(Html(text)) Ok(Html(text))
} }
#[post("/organizations/<uuid>/delete")] #[post("/organizations/<uuid>/delete")]
fn delete_organization(uuid: String, _token: AdminToken, conn: DbConn) -> EmptyResult { async fn delete_organization(uuid: String, _token: AdminToken, conn: DbConn) -> EmptyResult {
let org = Organization::find_by_uuid(&uuid, &conn).map_res("Organization doesn't exist")?; let org = Organization::find_by_uuid(&uuid, &conn).await.map_res("Organization doesn't exist")?;
org.delete(&conn) org.delete(&conn).await
} }
#[derive(Deserialize)] #[derive(Deserialize)]
@ -462,32 +476,74 @@ struct GitCommit {
sha: String, sha: String,
} }
fn get_github_api<T: DeserializeOwned>(url: &str) -> Result<T, Error> { async fn get_github_api<T: DeserializeOwned>(url: &str) -> Result<T, Error> {
let github_api = get_reqwest_client(); let github_api = get_reqwest_client();
Ok(github_api.get(url).send()?.error_for_status()?.json::<T>()?) Ok(github_api.get(url).send().await?.error_for_status()?.json::<T>().await?)
} }
fn has_http_access() -> bool { async fn has_http_access() -> bool {
let http_access = get_reqwest_client(); let http_access = get_reqwest_client();
match http_access.head("https://github.com/dani-garcia/vaultwarden").send() { match http_access.head("https://github.com/dani-garcia/vaultwarden").send().await {
Ok(r) => r.status().is_success(), Ok(r) => r.status().is_success(),
_ => false, _ => false,
} }
} }
use cached::proc_macro::cached;
/// Cache this function to prevent API call rate limit. Github only allows 60 requests per hour, and we use 3 here already.
/// It will cache this function for 300 seconds (5 minutes) which should prevent the exhaustion of the rate limit.
#[cached(time = 300, sync_writes = true)]
async fn get_release_info(has_http_access: bool, running_within_docker: bool) -> (String, String, String) {
// If the HTTP Check failed, do not even attempt to check for new versions since we were not able to connect with github.com anyway.
if has_http_access {
info!("Running get_release_info!!");
(
match get_github_api::<GitRelease>("https://api.github.com/repos/dani-garcia/vaultwarden/releases/latest")
.await
{
Ok(r) => r.tag_name,
_ => "-".to_string(),
},
match get_github_api::<GitCommit>("https://api.github.com/repos/dani-garcia/vaultwarden/commits/main").await
{
Ok(mut c) => {
c.sha.truncate(8);
c.sha
}
_ => "-".to_string(),
},
// Do not fetch the web-vault version when running within Docker.
// The web-vault version is embedded within the container it self, and should not be updated manually
if running_within_docker {
"-".to_string()
} else {
match get_github_api::<GitRelease>(
"https://api.github.com/repos/dani-garcia/bw_web_builds/releases/latest",
)
.await
{
Ok(r) => r.tag_name.trim_start_matches('v').to_string(),
_ => "-".to_string(),
}
},
)
} else {
("-".to_string(), "-".to_string(), "-".to_string())
}
}
#[get("/diagnostics")] #[get("/diagnostics")]
fn diagnostics(_token: AdminToken, ip_header: IpHeader, conn: DbConn) -> ApiResult<Html<String>> { async fn diagnostics(_token: AdminToken, ip_header: IpHeader, conn: DbConn) -> ApiResult<Html<String>> {
use crate::util::read_file_string;
use chrono::prelude::*; use chrono::prelude::*;
use std::net::ToSocketAddrs; use std::net::ToSocketAddrs;
// Get current running versions // Get current running versions
let web_vault_version: WebVaultVersion = let web_vault_version: WebVaultVersion =
match read_file_string(&format!("{}/{}", CONFIG.web_vault_folder(), "vw-version.json")) { match std::fs::read_to_string(&format!("{}/{}", CONFIG.web_vault_folder(), "vw-version.json")) {
Ok(s) => serde_json::from_str(&s)?, Ok(s) => serde_json::from_str(&s)?,
_ => match read_file_string(&format!("{}/{}", CONFIG.web_vault_folder(), "version.json")) { _ => match std::fs::read_to_string(&format!("{}/{}", CONFIG.web_vault_folder(), "version.json")) {
Ok(s) => serde_json::from_str(&s)?, Ok(s) => serde_json::from_str(&s)?,
_ => WebVaultVersion { _ => WebVaultVersion {
version: String::from("Version file missing"), version: String::from("Version file missing"),
@ -497,7 +553,7 @@ fn diagnostics(_token: AdminToken, ip_header: IpHeader, conn: DbConn) -> ApiResu
// Execute some environment checks // Execute some environment checks
let running_within_docker = is_running_in_docker(); let running_within_docker = is_running_in_docker();
let has_http_access = has_http_access(); let has_http_access = has_http_access().await;
let uses_proxy = env::var_os("HTTP_PROXY").is_some() let uses_proxy = env::var_os("HTTP_PROXY").is_some()
|| env::var_os("http_proxy").is_some() || env::var_os("http_proxy").is_some()
|| env::var_os("HTTPS_PROXY").is_some() || env::var_os("HTTPS_PROXY").is_some()
@ -509,37 +565,8 @@ fn diagnostics(_token: AdminToken, ip_header: IpHeader, conn: DbConn) -> ApiResu
_ => "Could not resolve domain name.".to_string(), _ => "Could not resolve domain name.".to_string(),
}; };
// If the HTTP Check failed, do not even attempt to check for new versions since we were not able to connect with github.com anyway. let (latest_release, latest_commit, latest_web_build) =
// TODO: Maybe we need to cache this using a LazyStatic or something. Github only allows 60 requests per hour, and we use 3 here already. get_release_info(has_http_access, running_within_docker).await;
let (latest_release, latest_commit, latest_web_build) = if has_http_access {
(
match get_github_api::<GitRelease>("https://api.github.com/repos/dani-garcia/vaultwarden/releases/latest") {
Ok(r) => r.tag_name,
_ => "-".to_string(),
},
match get_github_api::<GitCommit>("https://api.github.com/repos/dani-garcia/vaultwarden/commits/main") {
Ok(mut c) => {
c.sha.truncate(8);
c.sha
}
_ => "-".to_string(),
},
// Do not fetch the web-vault version when running within Docker.
// The web-vault version is embedded within the container it self, and should not be updated manually
if running_within_docker {
"-".to_string()
} else {
match get_github_api::<GitRelease>(
"https://api.github.com/repos/dani-garcia/bw_web_builds/releases/latest",
) {
Ok(r) => r.tag_name.trim_start_matches('v').to_string(),
_ => "-".to_string(),
}
},
)
} else {
("-".to_string(), "-".to_string(), "-".to_string())
};
let ip_header_name = match &ip_header.0 { let ip_header_name = match &ip_header.0 {
Some(h) => h, Some(h) => h,
@ -562,7 +589,7 @@ fn diagnostics(_token: AdminToken, ip_header: IpHeader, conn: DbConn) -> ApiResu
"ip_header_config": &CONFIG.ip_header(), "ip_header_config": &CONFIG.ip_header(),
"uses_proxy": uses_proxy, "uses_proxy": uses_proxy,
"db_type": *DB_TYPE, "db_type": *DB_TYPE,
"db_version": get_sql_server_version(&conn), "db_version": get_sql_server_version(&conn).await,
"admin_url": format!("{}/diagnostics", admin_url(Referer(None))), "admin_url": format!("{}/diagnostics", admin_url(Referer(None))),
"overrides": &CONFIG.get_overrides().join(", "), "overrides": &CONFIG.get_overrides().join(", "),
"server_time_local": Local::now().format("%Y-%m-%d %H:%M:%S %Z").to_string(), "server_time_local": Local::now().format("%Y-%m-%d %H:%M:%S %Z").to_string(),
@ -591,9 +618,9 @@ fn delete_config(_token: AdminToken) -> EmptyResult {
} }
#[post("/config/backup_db")] #[post("/config/backup_db")]
fn backup_db(_token: AdminToken, conn: DbConn) -> EmptyResult { async fn backup_db(_token: AdminToken, conn: DbConn) -> EmptyResult {
if *CAN_BACKUP { if *CAN_BACKUP {
backup_database(&conn) backup_database(&conn).await
} else { } else {
err!("Can't back up current DB (Only SQLite supports this feature)"); err!("Can't back up current DB (Only SQLite supports this feature)");
} }
@ -601,28 +628,29 @@ fn backup_db(_token: AdminToken, conn: DbConn) -> EmptyResult {
pub struct AdminToken {} pub struct AdminToken {}
impl<'a, 'r> FromRequest<'a, 'r> for AdminToken { #[rocket::async_trait]
impl<'r> FromRequest<'r> for AdminToken {
type Error = &'static str; type Error = &'static str;
fn from_request(request: &'a Request<'r>) -> request::Outcome<Self, Self::Error> { async fn from_request(request: &'r Request<'_>) -> request::Outcome<Self, Self::Error> {
if CONFIG.disable_admin_token() { if CONFIG.disable_admin_token() {
Outcome::Success(AdminToken {}) Outcome::Success(AdminToken {})
} else { } else {
let mut cookies = request.cookies(); let cookies = request.cookies();
let access_token = match cookies.get(COOKIE_NAME) { let access_token = match cookies.get(COOKIE_NAME) {
Some(cookie) => cookie.value(), Some(cookie) => cookie.value(),
None => return Outcome::Forward(()), // If there is no cookie, redirect to login None => return Outcome::Forward(()), // If there is no cookie, redirect to login
}; };
let ip = match request.guard::<ClientIp>() { let ip = match ClientIp::from_request(request).await {
Outcome::Success(ip) => ip.ip, Outcome::Success(ip) => ip.ip,
_ => err_handler!("Error getting Client IP"), _ => err_handler!("Error getting Client IP"),
}; };
if decode_admin(access_token).is_err() { if decode_admin(access_token).is_err() {
// Remove admin cookie // Remove admin cookie
cookies.remove(Cookie::named(COOKIE_NAME)); cookies.remove(Cookie::build(COOKIE_NAME, "").path(admin_path()).finish());
error!("Invalid or expired admin JWT. IP: {}.", ip); error!("Invalid or expired admin JWT. IP: {}.", ip);
return Outcome::Forward(()); return Outcome::Forward(());
} }

212
src/api/core/accounts.rs

@ -1,5 +1,5 @@
use chrono::Utc; use chrono::Utc;
use rocket_contrib::json::Json; use rocket::serde::json::Json;
use serde_json::Value; use serde_json::Value;
use crate::{ use crate::{
@ -62,12 +62,43 @@ struct KeysData {
PublicKey: String, PublicKey: String,
} }
/// Trims whitespace from password hints, and converts blank password hints to `None`.
fn clean_password_hint(password_hint: &Option<String>) -> Option<String> {
match password_hint {
None => None,
Some(h) => match h.trim() {
"" => None,
ht => Some(ht.to_string()),
},
}
}
fn enforce_password_hint_setting(password_hint: &Option<String>) -> EmptyResult {
if password_hint.is_some() && !CONFIG.password_hints_allowed() {
err!("Password hints have been disabled by the administrator. Remove the hint and try again.");
}
Ok(())
}
#[post("/accounts/register", data = "<data>")] #[post("/accounts/register", data = "<data>")]
fn register(data: JsonUpcase<RegisterData>, conn: DbConn) -> EmptyResult { async fn register(data: JsonUpcase<RegisterData>, conn: DbConn) -> EmptyResult {
let data: RegisterData = data.into_inner().data; let data: RegisterData = data.into_inner().data;
let email = data.Email.to_lowercase(); let email = data.Email.to_lowercase();
let mut user = match User::find_by_mail(&email, &conn) { // Check if the length of the username exceeds 50 characters (Same is Upstream Bitwarden)
// This also prevents issues with very long usernames causing to large JWT's. See #2419
if let Some(ref name) = data.Name {
if name.len() > 50 {
err!("The field Name must be a string with a maximum length of 50.");
}
}
// Check against the password hint setting here so if it fails, the user
// can retry without losing their invitation below.
let password_hint = clean_password_hint(&data.MasterPasswordHint);
enforce_password_hint_setting(&password_hint)?;
let mut user = match User::find_by_mail(&email, &conn).await {
Some(user) => { Some(user) => {
if !user.password_hash.is_empty() { if !user.password_hash.is_empty() {
if CONFIG.is_signup_allowed(&email) { if CONFIG.is_signup_allowed(&email) {
@ -84,13 +115,13 @@ fn register(data: JsonUpcase<RegisterData>, conn: DbConn) -> EmptyResult {
} else { } else {
err!("Registration email does not match invite email") err!("Registration email does not match invite email")
} }
} else if Invitation::take(&email, &conn) { } else if Invitation::take(&email, &conn).await {
for mut user_org in UserOrganization::find_invited_by_user(&user.uuid, &conn).iter_mut() { for mut user_org in UserOrganization::find_invited_by_user(&user.uuid, &conn).await.iter_mut() {
user_org.status = UserOrgStatus::Accepted as i32; user_org.status = UserOrgStatus::Accepted as i32;
user_org.save(&conn)?; user_org.save(&conn).await?;
} }
user user
} else if EmergencyAccess::find_invited_by_grantee_email(&email, &conn).is_some() { } else if EmergencyAccess::find_invited_by_grantee_email(&email, &conn).await.is_some() {
user user
} else if CONFIG.is_signup_allowed(&email) { } else if CONFIG.is_signup_allowed(&email) {
err!("Account with this email already exists") err!("Account with this email already exists")
@ -102,7 +133,7 @@ fn register(data: JsonUpcase<RegisterData>, conn: DbConn) -> EmptyResult {
// Order is important here; the invitation check must come first // Order is important here; the invitation check must come first
// because the vaultwarden admin can invite anyone, regardless // because the vaultwarden admin can invite anyone, regardless
// of other signup restrictions. // of other signup restrictions.
if Invitation::take(&email, &conn) || CONFIG.is_signup_allowed(&email) { if Invitation::take(&email, &conn).await || CONFIG.is_signup_allowed(&email) {
User::new(email.clone()) User::new(email.clone())
} else { } else {
err!("Registration not allowed or user already exists") err!("Registration not allowed or user already exists")
@ -111,7 +142,7 @@ fn register(data: JsonUpcase<RegisterData>, conn: DbConn) -> EmptyResult {
}; };
// Make sure we don't leave a lingering invitation. // Make sure we don't leave a lingering invitation.
Invitation::take(&email, &conn); Invitation::take(&email, &conn).await;
if let Some(client_kdf_iter) = data.KdfIterations { if let Some(client_kdf_iter) = data.KdfIterations {
user.client_kdf_iter = client_kdf_iter; user.client_kdf_iter = client_kdf_iter;
@ -123,16 +154,13 @@ fn register(data: JsonUpcase<RegisterData>, conn: DbConn) -> EmptyResult {
user.set_password(&data.MasterPasswordHash, None); user.set_password(&data.MasterPasswordHash, None);
user.akey = data.Key; user.akey = data.Key;
user.password_hint = password_hint;
// Add extra fields if present // Add extra fields if present
if let Some(name) = data.Name { if let Some(name) = data.Name {
user.name = name; user.name = name;
} }
if let Some(hint) = data.MasterPasswordHint {
user.password_hint = Some(hint);
}
if let Some(keys) = data.Keys { if let Some(keys) = data.Keys {
user.private_key = Some(keys.EncryptedPrivateKey); user.private_key = Some(keys.EncryptedPrivateKey);
user.public_key = Some(keys.PublicKey); user.public_key = Some(keys.PublicKey);
@ -140,22 +168,22 @@ fn register(data: JsonUpcase<RegisterData>, conn: DbConn) -> EmptyResult {
if CONFIG.mail_enabled() { if CONFIG.mail_enabled() {
if CONFIG.signups_verify() { if CONFIG.signups_verify() {
if let Err(e) = mail::send_welcome_must_verify(&user.email, &user.uuid) { if let Err(e) = mail::send_welcome_must_verify(&user.email, &user.uuid).await {
error!("Error sending welcome email: {:#?}", e); error!("Error sending welcome email: {:#?}", e);
} }
user.last_verifying_at = Some(user.created_at); user.last_verifying_at = Some(user.created_at);
} else if let Err(e) = mail::send_welcome(&user.email) { } else if let Err(e) = mail::send_welcome(&user.email).await {
error!("Error sending welcome email: {:#?}", e); error!("Error sending welcome email: {:#?}", e);
} }
} }
user.save(&conn) user.save(&conn).await
} }
#[get("/accounts/profile")] #[get("/accounts/profile")]
fn profile(headers: Headers, conn: DbConn) -> Json<Value> { async fn profile(headers: Headers, conn: DbConn) -> Json<Value> {
Json(headers.user.to_json(&conn)) Json(headers.user.to_json(&conn).await)
} }
#[derive(Deserialize, Debug)] #[derive(Deserialize, Debug)]
@ -168,28 +196,32 @@ struct ProfileData {
} }
#[put("/accounts/profile", data = "<data>")] #[put("/accounts/profile", data = "<data>")]
fn put_profile(data: JsonUpcase<ProfileData>, headers: Headers, conn: DbConn) -> JsonResult { async fn put_profile(data: JsonUpcase<ProfileData>, headers: Headers, conn: DbConn) -> JsonResult {
post_profile(data, headers, conn) post_profile(data, headers, conn).await
} }
#[post("/accounts/profile", data = "<data>")] #[post("/accounts/profile", data = "<data>")]
fn post_profile(data: JsonUpcase<ProfileData>, headers: Headers, conn: DbConn) -> JsonResult { async fn post_profile(data: JsonUpcase<ProfileData>, headers: Headers, conn: DbConn) -> JsonResult {
let data: ProfileData = data.into_inner().data; let data: ProfileData = data.into_inner().data;
let mut user = headers.user; // Check if the length of the username exceeds 50 characters (Same is Upstream Bitwarden)
// This also prevents issues with very long usernames causing to large JWT's. See #2419
if data.Name.len() > 50 {
err!("The field Name must be a string with a maximum length of 50.");
}
let mut user = headers.user;
user.name = data.Name; user.name = data.Name;
user.password_hint = match data.MasterPasswordHint { user.password_hint = clean_password_hint(&data.MasterPasswordHint);
Some(ref h) if h.is_empty() => None, enforce_password_hint_setting(&user.password_hint)?;
_ => data.MasterPasswordHint,
}; user.save(&conn).await?;
user.save(&conn)?; Ok(Json(user.to_json(&conn).await))
Ok(Json(user.to_json(&conn)))
} }
#[get("/users/<uuid>/public-key")] #[get("/users/<uuid>/public-key")]
fn get_public_keys(uuid: String, _headers: Headers, conn: DbConn) -> JsonResult { async fn get_public_keys(uuid: String, _headers: Headers, conn: DbConn) -> JsonResult {
let user = match User::find_by_uuid(&uuid, &conn) { let user = match User::find_by_uuid(&uuid, &conn).await {
Some(user) => user, Some(user) => user,
None => err!("User doesn't exist"), None => err!("User doesn't exist"),
}; };
@ -202,7 +234,7 @@ fn get_public_keys(uuid: String, _headers: Headers, conn: DbConn) -> JsonResult
} }
#[post("/accounts/keys", data = "<data>")] #[post("/accounts/keys", data = "<data>")]
fn post_keys(data: JsonUpcase<KeysData>, headers: Headers, conn: DbConn) -> JsonResult { async fn post_keys(data: JsonUpcase<KeysData>, headers: Headers, conn: DbConn) -> JsonResult {
let data: KeysData = data.into_inner().data; let data: KeysData = data.into_inner().data;
let mut user = headers.user; let mut user = headers.user;
@ -210,7 +242,7 @@ fn post_keys(data: JsonUpcase<KeysData>, headers: Headers, conn: DbConn) -> Json
user.private_key = Some(data.EncryptedPrivateKey); user.private_key = Some(data.EncryptedPrivateKey);
user.public_key = Some(data.PublicKey); user.public_key = Some(data.PublicKey);
user.save(&conn)?; user.save(&conn).await?;
Ok(Json(json!({ Ok(Json(json!({
"PrivateKey": user.private_key, "PrivateKey": user.private_key,
@ -228,7 +260,7 @@ struct ChangePassData {
} }
#[post("/accounts/password", data = "<data>")] #[post("/accounts/password", data = "<data>")]
fn post_password(data: JsonUpcase<ChangePassData>, headers: Headers, conn: DbConn) -> EmptyResult { async fn post_password(data: JsonUpcase<ChangePassData>, headers: Headers, conn: DbConn) -> EmptyResult {
let data: ChangePassData = data.into_inner().data; let data: ChangePassData = data.into_inner().data;
let mut user = headers.user; let mut user = headers.user;
@ -241,7 +273,7 @@ fn post_password(data: JsonUpcase<ChangePassData>, headers: Headers, conn: DbCon
Some(vec![String::from("post_rotatekey"), String::from("get_contacts"), String::from("get_public_keys")]), Some(vec![String::from("post_rotatekey"), String::from("get_contacts"), String::from("get_public_keys")]),
); );
user.akey = data.Key; user.akey = data.Key;
user.save(&conn) user.save(&conn).await
} }
#[derive(Deserialize)] #[derive(Deserialize)]
@ -256,7 +288,7 @@ struct ChangeKdfData {
} }
#[post("/accounts/kdf", data = "<data>")] #[post("/accounts/kdf", data = "<data>")]
fn post_kdf(data: JsonUpcase<ChangeKdfData>, headers: Headers, conn: DbConn) -> EmptyResult { async fn post_kdf(data: JsonUpcase<ChangeKdfData>, headers: Headers, conn: DbConn) -> EmptyResult {
let data: ChangeKdfData = data.into_inner().data; let data: ChangeKdfData = data.into_inner().data;
let mut user = headers.user; let mut user = headers.user;
@ -268,7 +300,7 @@ fn post_kdf(data: JsonUpcase<ChangeKdfData>, headers: Headers, conn: DbConn) ->
user.client_kdf_type = data.Kdf; user.client_kdf_type = data.Kdf;
user.set_password(&data.NewMasterPasswordHash, None); user.set_password(&data.NewMasterPasswordHash, None);
user.akey = data.Key; user.akey = data.Key;
user.save(&conn) user.save(&conn).await
} }
#[derive(Deserialize)] #[derive(Deserialize)]
@ -291,7 +323,7 @@ struct KeyData {
} }
#[post("/accounts/key", data = "<data>")] #[post("/accounts/key", data = "<data>")]
fn post_rotatekey(data: JsonUpcase<KeyData>, headers: Headers, conn: DbConn, nt: Notify) -> EmptyResult { async fn post_rotatekey(data: JsonUpcase<KeyData>, headers: Headers, conn: DbConn, nt: Notify<'_>) -> EmptyResult {
let data: KeyData = data.into_inner().data; let data: KeyData = data.into_inner().data;
if !headers.user.check_valid_password(&data.MasterPasswordHash) { if !headers.user.check_valid_password(&data.MasterPasswordHash) {
@ -302,7 +334,7 @@ fn post_rotatekey(data: JsonUpcase<KeyData>, headers: Headers, conn: DbConn, nt:
// Update folder data // Update folder data
for folder_data in data.Folders { for folder_data in data.Folders {
let mut saved_folder = match Folder::find_by_uuid(&folder_data.Id, &conn) { let mut saved_folder = match Folder::find_by_uuid(&folder_data.Id, &conn).await {
Some(folder) => folder, Some(folder) => folder,
None => err!("Folder doesn't exist"), None => err!("Folder doesn't exist"),
}; };
@ -312,14 +344,14 @@ fn post_rotatekey(data: JsonUpcase<KeyData>, headers: Headers, conn: DbConn, nt:
} }
saved_folder.name = folder_data.Name; saved_folder.name = folder_data.Name;
saved_folder.save(&conn)? saved_folder.save(&conn).await?
} }
// Update cipher data // Update cipher data
use super::ciphers::update_cipher_from_data; use super::ciphers::update_cipher_from_data;
for cipher_data in data.Ciphers { for cipher_data in data.Ciphers {
let mut saved_cipher = match Cipher::find_by_uuid(cipher_data.Id.as_ref().unwrap(), &conn) { let mut saved_cipher = match Cipher::find_by_uuid(cipher_data.Id.as_ref().unwrap(), &conn).await {
Some(cipher) => cipher, Some(cipher) => cipher,
None => err!("Cipher doesn't exist"), None => err!("Cipher doesn't exist"),
}; };
@ -330,7 +362,7 @@ fn post_rotatekey(data: JsonUpcase<KeyData>, headers: Headers, conn: DbConn, nt:
// Prevent triggering cipher updates via WebSockets by settings UpdateType::None // Prevent triggering cipher updates via WebSockets by settings UpdateType::None
// The user sessions are invalidated because all the ciphers were re-encrypted and thus triggering an update could cause issues. // The user sessions are invalidated because all the ciphers were re-encrypted and thus triggering an update could cause issues.
update_cipher_from_data(&mut saved_cipher, cipher_data, &headers, false, &conn, &nt, UpdateType::None)? update_cipher_from_data(&mut saved_cipher, cipher_data, &headers, false, &conn, &nt, UpdateType::None).await?
} }
// Update user data // Update user data
@ -340,11 +372,11 @@ fn post_rotatekey(data: JsonUpcase<KeyData>, headers: Headers, conn: DbConn, nt:
user.private_key = Some(data.PrivateKey); user.private_key = Some(data.PrivateKey);
user.reset_security_stamp(); user.reset_security_stamp();
user.save(&conn) user.save(&conn).await
} }
#[post("/accounts/security-stamp", data = "<data>")] #[post("/accounts/security-stamp", data = "<data>")]
fn post_sstamp(data: JsonUpcase<PasswordData>, headers: Headers, conn: DbConn) -> EmptyResult { async fn post_sstamp(data: JsonUpcase<PasswordData>, headers: Headers, conn: DbConn) -> EmptyResult {
let data: PasswordData = data.into_inner().data; let data: PasswordData = data.into_inner().data;
let mut user = headers.user; let mut user = headers.user;
@ -352,9 +384,9 @@ fn post_sstamp(data: JsonUpcase<PasswordData>, headers: Headers, conn: DbConn) -
err!("Invalid password") err!("Invalid password")
} }
Device::delete_all_by_user(&user.uuid, &conn)?; Device::delete_all_by_user(&user.uuid, &conn).await?;
user.reset_security_stamp(); user.reset_security_stamp();
user.save(&conn) user.save(&conn).await
} }
#[derive(Deserialize)] #[derive(Deserialize)]
@ -365,7 +397,7 @@ struct EmailTokenData {
} }
#[post("/accounts/email-token", data = "<data>")] #[post("/accounts/email-token", data = "<data>")]
fn post_email_token(data: JsonUpcase<EmailTokenData>, headers: Headers, conn: DbConn) -> EmptyResult { async fn post_email_token(data: JsonUpcase<EmailTokenData>, headers: Headers, conn: DbConn) -> EmptyResult {
let data: EmailTokenData = data.into_inner().data; let data: EmailTokenData = data.into_inner().data;
let mut user = headers.user; let mut user = headers.user;
@ -373,7 +405,7 @@ fn post_email_token(data: JsonUpcase<EmailTokenData>, headers: Headers, conn: Db
err!("Invalid password") err!("Invalid password")
} }
if User::find_by_mail(&data.NewEmail, &conn).is_some() { if User::find_by_mail(&data.NewEmail, &conn).await.is_some() {
err!("Email already in use"); err!("Email already in use");
} }
@ -381,17 +413,17 @@ fn post_email_token(data: JsonUpcase<EmailTokenData>, headers: Headers, conn: Db
err!("Email domain not allowed"); err!("Email domain not allowed");
} }
let token = crypto::generate_token(6)?; let token = crypto::generate_email_token(6);
if CONFIG.mail_enabled() { if CONFIG.mail_enabled() {
if let Err(e) = mail::send_change_email(&data.NewEmail, &token) { if let Err(e) = mail::send_change_email(&data.NewEmail, &token).await {
error!("Error sending change-email email: {:#?}", e); error!("Error sending change-email email: {:#?}", e);
} }
} }
user.email_new = Some(data.NewEmail); user.email_new = Some(data.NewEmail);
user.email_new_token = Some(token); user.email_new_token = Some(token);
user.save(&conn) user.save(&conn).await
} }
#[derive(Deserialize)] #[derive(Deserialize)]
@ -406,7 +438,7 @@ struct ChangeEmailData {
} }
#[post("/accounts/email", data = "<data>")] #[post("/accounts/email", data = "<data>")]
fn post_email(data: JsonUpcase<ChangeEmailData>, headers: Headers, conn: DbConn) -> EmptyResult { async fn post_email(data: JsonUpcase<ChangeEmailData>, headers: Headers, conn: DbConn) -> EmptyResult {
let data: ChangeEmailData = data.into_inner().data; let data: ChangeEmailData = data.into_inner().data;
let mut user = headers.user; let mut user = headers.user;
@ -414,7 +446,7 @@ fn post_email(data: JsonUpcase<ChangeEmailData>, headers: Headers, conn: DbConn)
err!("Invalid password") err!("Invalid password")
} }
if User::find_by_mail(&data.NewEmail, &conn).is_some() { if User::find_by_mail(&data.NewEmail, &conn).await.is_some() {
err!("Email already in use"); err!("Email already in use");
} }
@ -449,18 +481,18 @@ fn post_email(data: JsonUpcase<ChangeEmailData>, headers: Headers, conn: DbConn)
user.set_password(&data.NewMasterPasswordHash, None); user.set_password(&data.NewMasterPasswordHash, None);
user.akey = data.Key; user.akey = data.Key;
user.save(&conn) user.save(&conn).await
} }
#[post("/accounts/verify-email")] #[post("/accounts/verify-email")]
fn post_verify_email(headers: Headers) -> EmptyResult { async fn post_verify_email(headers: Headers) -> EmptyResult {
let user = headers.user; let user = headers.user;
if !CONFIG.mail_enabled() { if !CONFIG.mail_enabled() {
err!("Cannot verify email address"); err!("Cannot verify email address");
} }
if let Err(e) = mail::send_verify_email(&user.email, &user.uuid) { if let Err(e) = mail::send_verify_email(&user.email, &user.uuid).await {
error!("Error sending verify_email email: {:#?}", e); error!("Error sending verify_email email: {:#?}", e);
} }
@ -475,10 +507,10 @@ struct VerifyEmailTokenData {
} }
#[post("/accounts/verify-email-token", data = "<data>")] #[post("/accounts/verify-email-token", data = "<data>")]
fn post_verify_email_token(data: JsonUpcase<VerifyEmailTokenData>, conn: DbConn) -> EmptyResult { async fn post_verify_email_token(data: JsonUpcase<VerifyEmailTokenData>, conn: DbConn) -> EmptyResult {
let data: VerifyEmailTokenData = data.into_inner().data; let data: VerifyEmailTokenData = data.into_inner().data;
let mut user = match User::find_by_uuid(&data.UserId, &conn) { let mut user = match User::find_by_uuid(&data.UserId, &conn).await {
Some(user) => user, Some(user) => user,
None => err!("User doesn't exist"), None => err!("User doesn't exist"),
}; };
@ -493,7 +525,7 @@ fn post_verify_email_token(data: JsonUpcase<VerifyEmailTokenData>, conn: DbConn)
user.verified_at = Some(Utc::now().naive_utc()); user.verified_at = Some(Utc::now().naive_utc());
user.last_verifying_at = None; user.last_verifying_at = None;
user.login_verify_count = 0; user.login_verify_count = 0;
if let Err(e) = user.save(&conn) { if let Err(e) = user.save(&conn).await {
error!("Error saving email verification: {:#?}", e); error!("Error saving email verification: {:#?}", e);
} }
@ -507,14 +539,12 @@ struct DeleteRecoverData {
} }
#[post("/accounts/delete-recover", data = "<data>")] #[post("/accounts/delete-recover", data = "<data>")]
fn post_delete_recover(data: JsonUpcase<DeleteRecoverData>, conn: DbConn) -> EmptyResult { async fn post_delete_recover(data: JsonUpcase<DeleteRecoverData>, conn: DbConn) -> EmptyResult {
let data: DeleteRecoverData = data.into_inner().data; let data: DeleteRecoverData = data.into_inner().data;
let user = User::find_by_mail(&data.Email, &conn);
if CONFIG.mail_enabled() { if CONFIG.mail_enabled() {
if let Some(user) = user { if let Some(user) = User::find_by_mail(&data.Email, &conn).await {
if let Err(e) = mail::send_delete_account(&user.email, &user.uuid) { if let Err(e) = mail::send_delete_account(&user.email, &user.uuid).await {
error!("Error sending delete account email: {:#?}", e); error!("Error sending delete account email: {:#?}", e);
} }
} }
@ -536,10 +566,10 @@ struct DeleteRecoverTokenData {
} }
#[post("/accounts/delete-recover-token", data = "<data>")] #[post("/accounts/delete-recover-token", data = "<data>")]
fn post_delete_recover_token(data: JsonUpcase<DeleteRecoverTokenData>, conn: DbConn) -> EmptyResult { async fn post_delete_recover_token(data: JsonUpcase<DeleteRecoverTokenData>, conn: DbConn) -> EmptyResult {
let data: DeleteRecoverTokenData = data.into_inner().data; let data: DeleteRecoverTokenData = data.into_inner().data;
let user = match User::find_by_uuid(&data.UserId, &conn) { let user = match User::find_by_uuid(&data.UserId, &conn).await {
Some(user) => user, Some(user) => user,
None => err!("User doesn't exist"), None => err!("User doesn't exist"),
}; };
@ -551,16 +581,16 @@ fn post_delete_recover_token(data: JsonUpcase<DeleteRecoverTokenData>, conn: DbC
if claims.sub != user.uuid { if claims.sub != user.uuid {
err!("Invalid claim"); err!("Invalid claim");
} }
user.delete(&conn) user.delete(&conn).await
} }
#[post("/accounts/delete", data = "<data>")] #[post("/accounts/delete", data = "<data>")]
fn post_delete_account(data: JsonUpcase<PasswordData>, headers: Headers, conn: DbConn) -> EmptyResult { async fn post_delete_account(data: JsonUpcase<PasswordData>, headers: Headers, conn: DbConn) -> EmptyResult {
delete_account(data, headers, conn) delete_account(data, headers, conn).await
} }
#[delete("/accounts", data = "<data>")] #[delete("/accounts", data = "<data>")]
fn delete_account(data: JsonUpcase<PasswordData>, headers: Headers, conn: DbConn) -> EmptyResult { async fn delete_account(data: JsonUpcase<PasswordData>, headers: Headers, conn: DbConn) -> EmptyResult {
let data: PasswordData = data.into_inner().data; let data: PasswordData = data.into_inner().data;
let user = headers.user; let user = headers.user;
@ -568,7 +598,7 @@ fn delete_account(data: JsonUpcase<PasswordData>, headers: Headers, conn: DbConn
err!("Invalid password") err!("Invalid password")
} }
user.delete(&conn) user.delete(&conn).await
} }
#[get("/accounts/revision-date")] #[get("/accounts/revision-date")]
@ -584,7 +614,7 @@ struct PasswordHintData {
} }
#[post("/accounts/password-hint", data = "<data>")] #[post("/accounts/password-hint", data = "<data>")]
fn password_hint(data: JsonUpcase<PasswordHintData>, conn: DbConn) -> EmptyResult { async fn password_hint(data: JsonUpcase<PasswordHintData>, conn: DbConn) -> EmptyResult {
if !CONFIG.mail_enabled() && !CONFIG.show_password_hint() { if !CONFIG.mail_enabled() && !CONFIG.show_password_hint() {
err!("This server is not configured to provide password hints."); err!("This server is not configured to provide password hints.");
} }
@ -594,19 +624,18 @@ fn password_hint(data: JsonUpcase<PasswordHintData>, conn: DbConn) -> EmptyResul
let data: PasswordHintData = data.into_inner().data; let data: PasswordHintData = data.into_inner().data;
let email = &data.Email; let email = &data.Email;
match User::find_by_mail(email, &conn) { match User::find_by_mail(email, &conn).await {
None => { None => {
// To prevent user enumeration, act as if the user exists. // To prevent user enumeration, act as if the user exists.
if CONFIG.mail_enabled() { if CONFIG.mail_enabled() {
// There is still a timing side channel here in that the code // There is still a timing side channel here in that the code
// paths that send mail take noticeably longer than ones that // paths that send mail take noticeably longer than ones that
// don't. Add a randomized sleep to mitigate this somewhat. // don't. Add a randomized sleep to mitigate this somewhat.
use rand::{thread_rng, Rng}; use rand::{rngs::SmallRng, Rng, SeedableRng};
let mut rng = thread_rng(); let mut rng = SmallRng::from_entropy();
let base = 1000;
let delta: i32 = 100; let delta: i32 = 100;
let sleep_ms = (base + rng.gen_range(-delta..=delta)) as u64; let sleep_ms = (1_000 + rng.gen_range(-delta..=delta)) as u64;
std::thread::sleep(std::time::Duration::from_millis(sleep_ms)); tokio::time::sleep(tokio::time::Duration::from_millis(sleep_ms)).await;
Ok(()) Ok(())
} else { } else {
err!(NO_HINT); err!(NO_HINT);
@ -615,7 +644,7 @@ fn password_hint(data: JsonUpcase<PasswordHintData>, conn: DbConn) -> EmptyResul
Some(user) => { Some(user) => {
let hint: Option<String> = user.password_hint; let hint: Option<String> = user.password_hint;
if CONFIG.mail_enabled() { if CONFIG.mail_enabled() {
mail::send_password_hint(email, hint)?; mail::send_password_hint(email, hint).await?;
Ok(()) Ok(())
} else if let Some(hint) = hint { } else if let Some(hint) = hint {
err!(format!("Your password hint is: {}", hint)); err!(format!("Your password hint is: {}", hint));
@ -628,15 +657,19 @@ fn password_hint(data: JsonUpcase<PasswordHintData>, conn: DbConn) -> EmptyResul
#[derive(Deserialize)] #[derive(Deserialize)]
#[allow(non_snake_case)] #[allow(non_snake_case)]
struct PreloginData { pub struct PreloginData {
Email: String, Email: String,
} }
#[post("/accounts/prelogin", data = "<data>")] #[post("/accounts/prelogin", data = "<data>")]
fn prelogin(data: JsonUpcase<PreloginData>, conn: DbConn) -> Json<Value> { async fn prelogin(data: JsonUpcase<PreloginData>, conn: DbConn) -> Json<Value> {
_prelogin(data, conn).await
}
pub async fn _prelogin(data: JsonUpcase<PreloginData>, conn: DbConn) -> Json<Value> {
let data: PreloginData = data.into_inner().data; let data: PreloginData = data.into_inner().data;
let (kdf_type, kdf_iter) = match User::find_by_mail(&data.Email, &conn) { let (kdf_type, kdf_iter) = match User::find_by_mail(&data.Email, &conn).await {
Some(user) => (user.client_kdf_type, user.client_kdf_iter), Some(user) => (user.client_kdf_type, user.client_kdf_iter),
None => (User::CLIENT_KDF_TYPE_DEFAULT, User::CLIENT_KDF_ITER_DEFAULT), None => (User::CLIENT_KDF_TYPE_DEFAULT, User::CLIENT_KDF_ITER_DEFAULT),
}; };
@ -666,7 +699,12 @@ fn verify_password(data: JsonUpcase<SecretVerificationRequest>, headers: Headers
Ok(()) Ok(())
} }
fn _api_key(data: JsonUpcase<SecretVerificationRequest>, rotate: bool, headers: Headers, conn: DbConn) -> JsonResult { async fn _api_key(
data: JsonUpcase<SecretVerificationRequest>,
rotate: bool,
headers: Headers,
conn: DbConn,
) -> JsonResult {
let data: SecretVerificationRequest = data.into_inner().data; let data: SecretVerificationRequest = data.into_inner().data;
let mut user = headers.user; let mut user = headers.user;
@ -676,7 +714,7 @@ fn _api_key(data: JsonUpcase<SecretVerificationRequest>, rotate: bool, headers:
if rotate || user.api_key.is_none() { if rotate || user.api_key.is_none() {
user.api_key = Some(crypto::generate_api_key()); user.api_key = Some(crypto::generate_api_key());
user.save(&conn).expect("Error saving API key"); user.save(&conn).await.expect("Error saving API key");
} }
Ok(Json(json!({ Ok(Json(json!({
@ -686,11 +724,11 @@ fn _api_key(data: JsonUpcase<SecretVerificationRequest>, rotate: bool, headers:
} }
#[post("/accounts/api-key", data = "<data>")] #[post("/accounts/api-key", data = "<data>")]
fn api_key(data: JsonUpcase<SecretVerificationRequest>, headers: Headers, conn: DbConn) -> JsonResult { async fn api_key(data: JsonUpcase<SecretVerificationRequest>, headers: Headers, conn: DbConn) -> JsonResult {
_api_key(data, false, headers, conn) _api_key(data, false, headers, conn).await
} }
#[post("/accounts/rotate-api-key", data = "<data>")] #[post("/accounts/rotate-api-key", data = "<data>")]
fn rotate_api_key(data: JsonUpcase<SecretVerificationRequest>, headers: Headers, conn: DbConn) -> JsonResult { async fn rotate_api_key(data: JsonUpcase<SecretVerificationRequest>, headers: Headers, conn: DbConn) -> JsonResult {
_api_key(data, true, headers, conn) _api_key(data, true, headers, conn).await
} }

845
src/api/core/ciphers.rs

File diff suppressed because it is too large

257
src/api/core/emergency_access.rs

@ -1,16 +1,21 @@
use chrono::{Duration, Utc}; use chrono::{Duration, Utc};
use rocket::serde::json::Json;
use rocket::Route; use rocket::Route;
use rocket_contrib::json::Json;
use serde_json::Value; use serde_json::Value;
use std::borrow::Borrow; use std::borrow::Borrow;
use crate::{ use crate::{
api::{EmptyResult, JsonResult, JsonUpcase, NumberOrString}, api::{
core::{CipherSyncData, CipherSyncType},
EmptyResult, JsonResult, JsonUpcase, NumberOrString,
},
auth::{decode_emergency_access_invite, Headers}, auth::{decode_emergency_access_invite, Headers},
db::{models::*, DbConn, DbPool}, db::{models::*, DbConn, DbPool},
mail, CONFIG, mail, CONFIG,
}; };
use futures::{stream, stream::StreamExt};
pub fn routes() -> Vec<Route> { pub fn routes() -> Vec<Route> {
routes![ routes![
get_contacts, get_contacts,
@ -36,13 +41,17 @@ pub fn routes() -> Vec<Route> {
// region get // region get
#[get("/emergency-access/trusted")] #[get("/emergency-access/trusted")]
fn get_contacts(headers: Headers, conn: DbConn) -> JsonResult { async fn get_contacts(headers: Headers, conn: DbConn) -> JsonResult {
check_emergency_access_allowed()?; check_emergency_access_allowed()?;
let emergency_access_list = EmergencyAccess::find_all_by_grantor_uuid(&headers.user.uuid, &conn); let emergency_access_list_json =
stream::iter(EmergencyAccess::find_all_by_grantor_uuid(&headers.user.uuid, &conn).await)
let emergency_access_list_json: Vec<Value> = .then(|e| async {
emergency_access_list.iter().map(|e| e.to_json_grantee_details(&conn)).collect(); let e = e; // Move out this single variable
e.to_json_grantee_details(&conn).await
})
.collect::<Vec<Value>>()
.await;
Ok(Json(json!({ Ok(Json(json!({
"Data": emergency_access_list_json, "Data": emergency_access_list_json,
@ -52,13 +61,17 @@ fn get_contacts(headers: Headers, conn: DbConn) -> JsonResult {
} }
#[get("/emergency-access/granted")] #[get("/emergency-access/granted")]
fn get_grantees(headers: Headers, conn: DbConn) -> JsonResult { async fn get_grantees(headers: Headers, conn: DbConn) -> JsonResult {
check_emergency_access_allowed()?; check_emergency_access_allowed()?;
let emergency_access_list = EmergencyAccess::find_all_by_grantee_uuid(&headers.user.uuid, &conn); let emergency_access_list_json =
stream::iter(EmergencyAccess::find_all_by_grantee_uuid(&headers.user.uuid, &conn).await)
let emergency_access_list_json: Vec<Value> = .then(|e| async {
emergency_access_list.iter().map(|e| e.to_json_grantor_details(&conn)).collect(); let e = e; // Move out this single variable
e.to_json_grantor_details(&conn).await
})
.collect::<Vec<Value>>()
.await;
Ok(Json(json!({ Ok(Json(json!({
"Data": emergency_access_list_json, "Data": emergency_access_list_json,
@ -68,11 +81,11 @@ fn get_grantees(headers: Headers, conn: DbConn) -> JsonResult {
} }
#[get("/emergency-access/<emer_id>")] #[get("/emergency-access/<emer_id>")]
fn get_emergency_access(emer_id: String, conn: DbConn) -> JsonResult { async fn get_emergency_access(emer_id: String, conn: DbConn) -> JsonResult {
check_emergency_access_allowed()?; check_emergency_access_allowed()?;
match EmergencyAccess::find_by_uuid(&emer_id, &conn) { match EmergencyAccess::find_by_uuid(&emer_id, &conn).await {
Some(emergency_access) => Ok(Json(emergency_access.to_json_grantee_details(&conn))), Some(emergency_access) => Ok(Json(emergency_access.to_json_grantee_details(&conn).await)),
None => err!("Emergency access not valid."), None => err!("Emergency access not valid."),
} }
} }
@ -90,17 +103,25 @@ struct EmergencyAccessUpdateData {
} }
#[put("/emergency-access/<emer_id>", data = "<data>")] #[put("/emergency-access/<emer_id>", data = "<data>")]
fn put_emergency_access(emer_id: String, data: JsonUpcase<EmergencyAccessUpdateData>, conn: DbConn) -> JsonResult { async fn put_emergency_access(
post_emergency_access(emer_id, data, conn) emer_id: String,
data: JsonUpcase<EmergencyAccessUpdateData>,
conn: DbConn,
) -> JsonResult {
post_emergency_access(emer_id, data, conn).await
} }
#[post("/emergency-access/<emer_id>", data = "<data>")] #[post("/emergency-access/<emer_id>", data = "<data>")]
fn post_emergency_access(emer_id: String, data: JsonUpcase<EmergencyAccessUpdateData>, conn: DbConn) -> JsonResult { async fn post_emergency_access(
emer_id: String,
data: JsonUpcase<EmergencyAccessUpdateData>,
conn: DbConn,
) -> JsonResult {
check_emergency_access_allowed()?; check_emergency_access_allowed()?;
let data: EmergencyAccessUpdateData = data.into_inner().data; let data: EmergencyAccessUpdateData = data.into_inner().data;
let mut emergency_access = match EmergencyAccess::find_by_uuid(&emer_id, &conn) { let mut emergency_access = match EmergencyAccess::find_by_uuid(&emer_id, &conn).await {
Some(emergency_access) => emergency_access, Some(emergency_access) => emergency_access,
None => err!("Emergency access not valid."), None => err!("Emergency access not valid."),
}; };
@ -114,7 +135,7 @@ fn post_emergency_access(emer_id: String, data: JsonUpcase<EmergencyAccessUpdate
emergency_access.wait_time_days = data.WaitTimeDays; emergency_access.wait_time_days = data.WaitTimeDays;
emergency_access.key_encrypted = data.KeyEncrypted; emergency_access.key_encrypted = data.KeyEncrypted;
emergency_access.save(&conn)?; emergency_access.save(&conn).await?;
Ok(Json(emergency_access.to_json())) Ok(Json(emergency_access.to_json()))
} }
@ -123,12 +144,12 @@ fn post_emergency_access(emer_id: String, data: JsonUpcase<EmergencyAccessUpdate
// region delete // region delete
#[delete("/emergency-access/<emer_id>")] #[delete("/emergency-access/<emer_id>")]
fn delete_emergency_access(emer_id: String, headers: Headers, conn: DbConn) -> EmptyResult { async fn delete_emergency_access(emer_id: String, headers: Headers, conn: DbConn) -> EmptyResult {
check_emergency_access_allowed()?; check_emergency_access_allowed()?;
let grantor_user = headers.user; let grantor_user = headers.user;
let emergency_access = match EmergencyAccess::find_by_uuid(&emer_id, &conn) { let emergency_access = match EmergencyAccess::find_by_uuid(&emer_id, &conn).await {
Some(emer) => { Some(emer) => {
if emer.grantor_uuid != grantor_user.uuid && emer.grantee_uuid != Some(grantor_user.uuid) { if emer.grantor_uuid != grantor_user.uuid && emer.grantee_uuid != Some(grantor_user.uuid) {
err!("Emergency access not valid.") err!("Emergency access not valid.")
@ -137,13 +158,13 @@ fn delete_emergency_access(emer_id: String, headers: Headers, conn: DbConn) -> E
} }
None => err!("Emergency access not valid."), None => err!("Emergency access not valid."),
}; };
emergency_access.delete(&conn)?; emergency_access.delete(&conn).await?;
Ok(()) Ok(())
} }
#[post("/emergency-access/<emer_id>/delete")] #[post("/emergency-access/<emer_id>/delete")]
fn post_delete_emergency_access(emer_id: String, headers: Headers, conn: DbConn) -> EmptyResult { async fn post_delete_emergency_access(emer_id: String, headers: Headers, conn: DbConn) -> EmptyResult {
delete_emergency_access(emer_id, headers, conn) delete_emergency_access(emer_id, headers, conn).await
} }
// endregion // endregion
@ -159,7 +180,7 @@ struct EmergencyAccessInviteData {
} }
#[post("/emergency-access/invite", data = "<data>")] #[post("/emergency-access/invite", data = "<data>")]
fn send_invite(data: JsonUpcase<EmergencyAccessInviteData>, headers: Headers, conn: DbConn) -> EmptyResult { async fn send_invite(data: JsonUpcase<EmergencyAccessInviteData>, headers: Headers, conn: DbConn) -> EmptyResult {
check_emergency_access_allowed()?; check_emergency_access_allowed()?;
let data: EmergencyAccessInviteData = data.into_inner().data; let data: EmergencyAccessInviteData = data.into_inner().data;
@ -180,7 +201,7 @@ fn send_invite(data: JsonUpcase<EmergencyAccessInviteData>, headers: Headers, co
err!("You can not set yourself as an emergency contact.") err!("You can not set yourself as an emergency contact.")
} }
let grantee_user = match User::find_by_mail(&email, &conn) { let grantee_user = match User::find_by_mail(&email, &conn).await {
None => { None => {
if !CONFIG.invitations_allowed() { if !CONFIG.invitations_allowed() {
err!(format!("Grantee user does not exist: {}", email)) err!(format!("Grantee user does not exist: {}", email))
@ -192,11 +213,11 @@ fn send_invite(data: JsonUpcase<EmergencyAccessInviteData>, headers: Headers, co
if !CONFIG.mail_enabled() { if !CONFIG.mail_enabled() {
let invitation = Invitation::new(email.clone()); let invitation = Invitation::new(email.clone());
invitation.save(&conn)?; invitation.save(&conn).await?;
} }
let mut user = User::new(email.clone()); let mut user = User::new(email.clone());
user.save(&conn)?; user.save(&conn).await?;
user user
} }
Some(user) => user, Some(user) => user,
@ -208,6 +229,7 @@ fn send_invite(data: JsonUpcase<EmergencyAccessInviteData>, headers: Headers, co
&grantee_user.email, &grantee_user.email,
&conn, &conn,
) )
.await
.is_some() .is_some()
{ {
err!(format!("Grantee user already invited: {}", email)) err!(format!("Grantee user already invited: {}", email))
@ -220,7 +242,7 @@ fn send_invite(data: JsonUpcase<EmergencyAccessInviteData>, headers: Headers, co
new_type, new_type,
wait_time_days, wait_time_days,
); );
new_emergency_access.save(&conn)?; new_emergency_access.save(&conn).await?;
if CONFIG.mail_enabled() { if CONFIG.mail_enabled() {
mail::send_emergency_access_invite( mail::send_emergency_access_invite(
@ -229,12 +251,13 @@ fn send_invite(data: JsonUpcase<EmergencyAccessInviteData>, headers: Headers, co
Some(new_emergency_access.uuid), Some(new_emergency_access.uuid),
Some(grantor_user.name.clone()), Some(grantor_user.name.clone()),
Some(grantor_user.email), Some(grantor_user.email),
)?; )
.await?;
} else { } else {
// Automatically mark user as accepted if no email invites // Automatically mark user as accepted if no email invites
match User::find_by_mail(&email, &conn) { match User::find_by_mail(&email, &conn).await {
Some(user) => { Some(user) => {
match accept_invite_process(user.uuid, new_emergency_access.uuid, Some(email), conn.borrow()) { match accept_invite_process(user.uuid, new_emergency_access.uuid, Some(email), conn.borrow()).await {
Ok(v) => (v), Ok(v) => (v),
Err(e) => err!(e.to_string()), Err(e) => err!(e.to_string()),
} }
@ -247,10 +270,10 @@ fn send_invite(data: JsonUpcase<EmergencyAccessInviteData>, headers: Headers, co
} }
#[post("/emergency-access/<emer_id>/reinvite")] #[post("/emergency-access/<emer_id>/reinvite")]
fn resend_invite(emer_id: String, headers: Headers, conn: DbConn) -> EmptyResult { async fn resend_invite(emer_id: String, headers: Headers, conn: DbConn) -> EmptyResult {
check_emergency_access_allowed()?; check_emergency_access_allowed()?;
let emergency_access = match EmergencyAccess::find_by_uuid(&emer_id, &conn) { let emergency_access = match EmergencyAccess::find_by_uuid(&emer_id, &conn).await {
Some(emer) => emer, Some(emer) => emer,
None => err!("Emergency access not valid."), None => err!("Emergency access not valid."),
}; };
@ -268,7 +291,7 @@ fn resend_invite(emer_id: String, headers: Headers, conn: DbConn) -> EmptyResult
None => err!("Email not valid."), None => err!("Email not valid."),
}; };
let grantee_user = match User::find_by_mail(&email, &conn) { let grantee_user = match User::find_by_mail(&email, &conn).await {
Some(user) => user, Some(user) => user,
None => err!("Grantee user not found."), None => err!("Grantee user not found."),
}; };
@ -282,15 +305,18 @@ fn resend_invite(emer_id: String, headers: Headers, conn: DbConn) -> EmptyResult
Some(emergency_access.uuid), Some(emergency_access.uuid),
Some(grantor_user.name.clone()), Some(grantor_user.name.clone()),
Some(grantor_user.email), Some(grantor_user.email),
)?; )
.await?;
} else { } else {
if Invitation::find_by_mail(&email, &conn).is_none() { if Invitation::find_by_mail(&email, &conn).await.is_none() {
let invitation = Invitation::new(email); let invitation = Invitation::new(email);
invitation.save(&conn)?; invitation.save(&conn).await?;
} }
// Automatically mark user as accepted if no email invites // Automatically mark user as accepted if no email invites
match accept_invite_process(grantee_user.uuid, emergency_access.uuid, emergency_access.email, conn.borrow()) { match accept_invite_process(grantee_user.uuid, emergency_access.uuid, emergency_access.email, conn.borrow())
.await
{
Ok(v) => (v), Ok(v) => (v),
Err(e) => err!(e.to_string()), Err(e) => err!(e.to_string()),
} }
@ -306,28 +332,28 @@ struct AcceptData {
} }
#[post("/emergency-access/<emer_id>/accept", data = "<data>")] #[post("/emergency-access/<emer_id>/accept", data = "<data>")]
fn accept_invite(emer_id: String, data: JsonUpcase<AcceptData>, conn: DbConn) -> EmptyResult { async fn accept_invite(emer_id: String, data: JsonUpcase<AcceptData>, conn: DbConn) -> EmptyResult {
check_emergency_access_allowed()?; check_emergency_access_allowed()?;
let data: AcceptData = data.into_inner().data; let data: AcceptData = data.into_inner().data;
let token = &data.Token; let token = &data.Token;
let claims = decode_emergency_access_invite(token)?; let claims = decode_emergency_access_invite(token)?;
let grantee_user = match User::find_by_mail(&claims.email, &conn) { let grantee_user = match User::find_by_mail(&claims.email, &conn).await {
Some(user) => { Some(user) => {
Invitation::take(&claims.email, &conn); Invitation::take(&claims.email, &conn).await;
user user
} }
None => err!("Invited user not found"), None => err!("Invited user not found"),
}; };
let emergency_access = match EmergencyAccess::find_by_uuid(&emer_id, &conn) { let emergency_access = match EmergencyAccess::find_by_uuid(&emer_id, &conn).await {
Some(emer) => emer, Some(emer) => emer,
None => err!("Emergency access not valid."), None => err!("Emergency access not valid."),
}; };
// get grantor user to send Accepted email // get grantor user to send Accepted email
let grantor_user = match User::find_by_uuid(&emergency_access.grantor_uuid, &conn) { let grantor_user = match User::find_by_uuid(&emergency_access.grantor_uuid, &conn).await {
Some(user) => user, Some(user) => user,
None => err!("Grantor user not found."), None => err!("Grantor user not found."),
}; };
@ -336,13 +362,13 @@ fn accept_invite(emer_id: String, data: JsonUpcase<AcceptData>, conn: DbConn) ->
&& (claims.grantor_name.is_some() && grantor_user.name == claims.grantor_name.unwrap()) && (claims.grantor_name.is_some() && grantor_user.name == claims.grantor_name.unwrap())
&& (claims.grantor_email.is_some() && grantor_user.email == claims.grantor_email.unwrap()) && (claims.grantor_email.is_some() && grantor_user.email == claims.grantor_email.unwrap())
{ {
match accept_invite_process(grantee_user.uuid.clone(), emer_id, Some(grantee_user.email.clone()), &conn) { match accept_invite_process(grantee_user.uuid.clone(), emer_id, Some(grantee_user.email.clone()), &conn).await {
Ok(v) => (v), Ok(v) => (v),
Err(e) => err!(e.to_string()), Err(e) => err!(e.to_string()),
} }
if CONFIG.mail_enabled() { if CONFIG.mail_enabled() {
mail::send_emergency_access_invite_accepted(&grantor_user.email, &grantee_user.email)?; mail::send_emergency_access_invite_accepted(&grantor_user.email, &grantee_user.email).await?;
} }
Ok(()) Ok(())
@ -351,8 +377,13 @@ fn accept_invite(emer_id: String, data: JsonUpcase<AcceptData>, conn: DbConn) ->
} }
} }
fn accept_invite_process(grantee_uuid: String, emer_id: String, email: Option<String>, conn: &DbConn) -> EmptyResult { async fn accept_invite_process(
let mut emergency_access = match EmergencyAccess::find_by_uuid(&emer_id, conn) { grantee_uuid: String,
emer_id: String,
email: Option<String>,
conn: &DbConn,
) -> EmptyResult {
let mut emergency_access = match EmergencyAccess::find_by_uuid(&emer_id, conn).await {
Some(emer) => emer, Some(emer) => emer,
None => err!("Emergency access not valid."), None => err!("Emergency access not valid."),
}; };
@ -369,7 +400,7 @@ fn accept_invite_process(grantee_uuid: String, emer_id: String, email: Option<St
emergency_access.status = EmergencyAccessStatus::Accepted as i32; emergency_access.status = EmergencyAccessStatus::Accepted as i32;
emergency_access.grantee_uuid = Some(grantee_uuid); emergency_access.grantee_uuid = Some(grantee_uuid);
emergency_access.email = None; emergency_access.email = None;
emergency_access.save(conn) emergency_access.save(conn).await
} }
#[derive(Deserialize)] #[derive(Deserialize)]
@ -379,7 +410,7 @@ struct ConfirmData {
} }
#[post("/emergency-access/<emer_id>/confirm", data = "<data>")] #[post("/emergency-access/<emer_id>/confirm", data = "<data>")]
fn confirm_emergency_access( async fn confirm_emergency_access(
emer_id: String, emer_id: String,
data: JsonUpcase<ConfirmData>, data: JsonUpcase<ConfirmData>,
headers: Headers, headers: Headers,
@ -391,7 +422,7 @@ fn confirm_emergency_access(
let data: ConfirmData = data.into_inner().data; let data: ConfirmData = data.into_inner().data;
let key = data.Key; let key = data.Key;
let mut emergency_access = match EmergencyAccess::find_by_uuid(&emer_id, &conn) { let mut emergency_access = match EmergencyAccess::find_by_uuid(&emer_id, &conn).await {
Some(emer) => emer, Some(emer) => emer,
None => err!("Emergency access not valid."), None => err!("Emergency access not valid."),
}; };
@ -402,13 +433,13 @@ fn confirm_emergency_access(
err!("Emergency access not valid.") err!("Emergency access not valid.")
} }
let grantor_user = match User::find_by_uuid(&confirming_user.uuid, &conn) { let grantor_user = match User::find_by_uuid(&confirming_user.uuid, &conn).await {
Some(user) => user, Some(user) => user,
None => err!("Grantor user not found."), None => err!("Grantor user not found."),
}; };
if let Some(grantee_uuid) = emergency_access.grantee_uuid.as_ref() { if let Some(grantee_uuid) = emergency_access.grantee_uuid.as_ref() {
let grantee_user = match User::find_by_uuid(grantee_uuid, &conn) { let grantee_user = match User::find_by_uuid(grantee_uuid, &conn).await {
Some(user) => user, Some(user) => user,
None => err!("Grantee user not found."), None => err!("Grantee user not found."),
}; };
@ -417,10 +448,10 @@ fn confirm_emergency_access(
emergency_access.key_encrypted = Some(key); emergency_access.key_encrypted = Some(key);
emergency_access.email = None; emergency_access.email = None;
emergency_access.save(&conn)?; emergency_access.save(&conn).await?;
if CONFIG.mail_enabled() { if CONFIG.mail_enabled() {
mail::send_emergency_access_invite_confirmed(&grantee_user.email, &grantor_user.name)?; mail::send_emergency_access_invite_confirmed(&grantee_user.email, &grantor_user.name).await?;
} }
Ok(Json(emergency_access.to_json())) Ok(Json(emergency_access.to_json()))
} else { } else {
@ -433,11 +464,11 @@ fn confirm_emergency_access(
// region access emergency access // region access emergency access
#[post("/emergency-access/<emer_id>/initiate")] #[post("/emergency-access/<emer_id>/initiate")]
fn initiate_emergency_access(emer_id: String, headers: Headers, conn: DbConn) -> JsonResult { async fn initiate_emergency_access(emer_id: String, headers: Headers, conn: DbConn) -> JsonResult {
check_emergency_access_allowed()?; check_emergency_access_allowed()?;
let initiating_user = headers.user; let initiating_user = headers.user;
let mut emergency_access = match EmergencyAccess::find_by_uuid(&emer_id, &conn) { let mut emergency_access = match EmergencyAccess::find_by_uuid(&emer_id, &conn).await {
Some(emer) => emer, Some(emer) => emer,
None => err!("Emergency access not valid."), None => err!("Emergency access not valid."),
}; };
@ -448,7 +479,7 @@ fn initiate_emergency_access(emer_id: String, headers: Headers, conn: DbConn) ->
err!("Emergency access not valid.") err!("Emergency access not valid.")
} }
let grantor_user = match User::find_by_uuid(&emergency_access.grantor_uuid, &conn) { let grantor_user = match User::find_by_uuid(&emergency_access.grantor_uuid, &conn).await {
Some(user) => user, Some(user) => user,
None => err!("Grantor user not found."), None => err!("Grantor user not found."),
}; };
@ -458,7 +489,7 @@ fn initiate_emergency_access(emer_id: String, headers: Headers, conn: DbConn) ->
emergency_access.updated_at = now; emergency_access.updated_at = now;
emergency_access.recovery_initiated_at = Some(now); emergency_access.recovery_initiated_at = Some(now);
emergency_access.last_notification_at = Some(now); emergency_access.last_notification_at = Some(now);
emergency_access.save(&conn)?; emergency_access.save(&conn).await?;
if CONFIG.mail_enabled() { if CONFIG.mail_enabled() {
mail::send_emergency_access_recovery_initiated( mail::send_emergency_access_recovery_initiated(
@ -466,17 +497,18 @@ fn initiate_emergency_access(emer_id: String, headers: Headers, conn: DbConn) ->
&initiating_user.name, &initiating_user.name,
emergency_access.get_type_as_str(), emergency_access.get_type_as_str(),
&emergency_access.wait_time_days.clone().to_string(), &emergency_access.wait_time_days.clone().to_string(),
)?; )
.await?;
} }
Ok(Json(emergency_access.to_json())) Ok(Json(emergency_access.to_json()))
} }
#[post("/emergency-access/<emer_id>/approve")] #[post("/emergency-access/<emer_id>/approve")]
fn approve_emergency_access(emer_id: String, headers: Headers, conn: DbConn) -> JsonResult { async fn approve_emergency_access(emer_id: String, headers: Headers, conn: DbConn) -> JsonResult {
check_emergency_access_allowed()?; check_emergency_access_allowed()?;
let approving_user = headers.user; let approving_user = headers.user;
let mut emergency_access = match EmergencyAccess::find_by_uuid(&emer_id, &conn) { let mut emergency_access = match EmergencyAccess::find_by_uuid(&emer_id, &conn).await {
Some(emer) => emer, Some(emer) => emer,
None => err!("Emergency access not valid."), None => err!("Emergency access not valid."),
}; };
@ -487,22 +519,22 @@ fn approve_emergency_access(emer_id: String, headers: Headers, conn: DbConn) ->
err!("Emergency access not valid.") err!("Emergency access not valid.")
} }
let grantor_user = match User::find_by_uuid(&approving_user.uuid, &conn) { let grantor_user = match User::find_by_uuid(&approving_user.uuid, &conn).await {
Some(user) => user, Some(user) => user,
None => err!("Grantor user not found."), None => err!("Grantor user not found."),
}; };
if let Some(grantee_uuid) = emergency_access.grantee_uuid.as_ref() { if let Some(grantee_uuid) = emergency_access.grantee_uuid.as_ref() {
let grantee_user = match User::find_by_uuid(grantee_uuid, &conn) { let grantee_user = match User::find_by_uuid(grantee_uuid, &conn).await {
Some(user) => user, Some(user) => user,
None => err!("Grantee user not found."), None => err!("Grantee user not found."),
}; };
emergency_access.status = EmergencyAccessStatus::RecoveryApproved as i32; emergency_access.status = EmergencyAccessStatus::RecoveryApproved as i32;
emergency_access.save(&conn)?; emergency_access.save(&conn).await?;
if CONFIG.mail_enabled() { if CONFIG.mail_enabled() {
mail::send_emergency_access_recovery_approved(&grantee_user.email, &grantor_user.name)?; mail::send_emergency_access_recovery_approved(&grantee_user.email, &grantor_user.name).await?;
} }
Ok(Json(emergency_access.to_json())) Ok(Json(emergency_access.to_json()))
} else { } else {
@ -511,11 +543,11 @@ fn approve_emergency_access(emer_id: String, headers: Headers, conn: DbConn) ->
} }
#[post("/emergency-access/<emer_id>/reject")] #[post("/emergency-access/<emer_id>/reject")]
fn reject_emergency_access(emer_id: String, headers: Headers, conn: DbConn) -> JsonResult { async fn reject_emergency_access(emer_id: String, headers: Headers, conn: DbConn) -> JsonResult {
check_emergency_access_allowed()?; check_emergency_access_allowed()?;
let rejecting_user = headers.user; let rejecting_user = headers.user;
let mut emergency_access = match EmergencyAccess::find_by_uuid(&emer_id, &conn) { let mut emergency_access = match EmergencyAccess::find_by_uuid(&emer_id, &conn).await {
Some(emer) => emer, Some(emer) => emer,
None => err!("Emergency access not valid."), None => err!("Emergency access not valid."),
}; };
@ -527,22 +559,22 @@ fn reject_emergency_access(emer_id: String, headers: Headers, conn: DbConn) -> J
err!("Emergency access not valid.") err!("Emergency access not valid.")
} }
let grantor_user = match User::find_by_uuid(&rejecting_user.uuid, &conn) { let grantor_user = match User::find_by_uuid(&rejecting_user.uuid, &conn).await {
Some(user) => user, Some(user) => user,
None => err!("Grantor user not found."), None => err!("Grantor user not found."),
}; };
if let Some(grantee_uuid) = emergency_access.grantee_uuid.as_ref() { if let Some(grantee_uuid) = emergency_access.grantee_uuid.as_ref() {
let grantee_user = match User::find_by_uuid(grantee_uuid, &conn) { let grantee_user = match User::find_by_uuid(grantee_uuid, &conn).await {
Some(user) => user, Some(user) => user,
None => err!("Grantee user not found."), None => err!("Grantee user not found."),
}; };
emergency_access.status = EmergencyAccessStatus::Confirmed as i32; emergency_access.status = EmergencyAccessStatus::Confirmed as i32;
emergency_access.save(&conn)?; emergency_access.save(&conn).await?;
if CONFIG.mail_enabled() { if CONFIG.mail_enabled() {
mail::send_emergency_access_recovery_rejected(&grantee_user.email, &grantor_user.name)?; mail::send_emergency_access_recovery_rejected(&grantee_user.email, &grantor_user.name).await?;
} }
Ok(Json(emergency_access.to_json())) Ok(Json(emergency_access.to_json()))
} else { } else {
@ -555,12 +587,12 @@ fn reject_emergency_access(emer_id: String, headers: Headers, conn: DbConn) -> J
// region action // region action
#[post("/emergency-access/<emer_id>/view")] #[post("/emergency-access/<emer_id>/view")]
fn view_emergency_access(emer_id: String, headers: Headers, conn: DbConn) -> JsonResult { async fn view_emergency_access(emer_id: String, headers: Headers, conn: DbConn) -> JsonResult {
check_emergency_access_allowed()?; check_emergency_access_allowed()?;
let requesting_user = headers.user; let requesting_user = headers.user;
let host = headers.host; let host = headers.host;
let emergency_access = match EmergencyAccess::find_by_uuid(&emer_id, &conn) { let emergency_access = match EmergencyAccess::find_by_uuid(&emer_id, &conn).await {
Some(emer) => emer, Some(emer) => emer,
None => err!("Emergency access not valid."), None => err!("Emergency access not valid."),
}; };
@ -569,10 +601,17 @@ fn view_emergency_access(emer_id: String, headers: Headers, conn: DbConn) -> Jso
err!("Emergency access not valid.") err!("Emergency access not valid.")
} }
let ciphers = Cipher::find_owned_by_user(&emergency_access.grantor_uuid, &conn); let ciphers = Cipher::find_owned_by_user(&emergency_access.grantor_uuid, &conn).await;
let cipher_sync_data =
CipherSyncData::new(&emergency_access.grantor_uuid, &ciphers, CipherSyncType::User, &conn).await;
let ciphers_json: Vec<Value> = let ciphers_json = stream::iter(ciphers)
ciphers.iter().map(|c| c.to_json(&host, &emergency_access.grantor_uuid, &conn)).collect(); .then(|c| async {
let c = c; // Move out this single variable
c.to_json(&host, &emergency_access.grantor_uuid, Some(&cipher_sync_data), &conn).await
})
.collect::<Vec<Value>>()
.await;
Ok(Json(json!({ Ok(Json(json!({
"Ciphers": ciphers_json, "Ciphers": ciphers_json,
@ -582,11 +621,11 @@ fn view_emergency_access(emer_id: String, headers: Headers, conn: DbConn) -> Jso
} }
#[post("/emergency-access/<emer_id>/takeover")] #[post("/emergency-access/<emer_id>/takeover")]
fn takeover_emergency_access(emer_id: String, headers: Headers, conn: DbConn) -> JsonResult { async fn takeover_emergency_access(emer_id: String, headers: Headers, conn: DbConn) -> JsonResult {
check_emergency_access_allowed()?; check_emergency_access_allowed()?;
let requesting_user = headers.user; let requesting_user = headers.user;
let emergency_access = match EmergencyAccess::find_by_uuid(&emer_id, &conn) { let emergency_access = match EmergencyAccess::find_by_uuid(&emer_id, &conn).await {
Some(emer) => emer, Some(emer) => emer,
None => err!("Emergency access not valid."), None => err!("Emergency access not valid."),
}; };
@ -595,7 +634,7 @@ fn takeover_emergency_access(emer_id: String, headers: Headers, conn: DbConn) ->
err!("Emergency access not valid.") err!("Emergency access not valid.")
} }
let grantor_user = match User::find_by_uuid(&emergency_access.grantor_uuid, &conn) { let grantor_user = match User::find_by_uuid(&emergency_access.grantor_uuid, &conn).await {
Some(user) => user, Some(user) => user,
None => err!("Grantor user not found."), None => err!("Grantor user not found."),
}; };
@ -616,7 +655,7 @@ struct EmergencyAccessPasswordData {
} }
#[post("/emergency-access/<emer_id>/password", data = "<data>")] #[post("/emergency-access/<emer_id>/password", data = "<data>")]
fn password_emergency_access( async fn password_emergency_access(
emer_id: String, emer_id: String,
data: JsonUpcase<EmergencyAccessPasswordData>, data: JsonUpcase<EmergencyAccessPasswordData>,
headers: Headers, headers: Headers,
@ -629,7 +668,7 @@ fn password_emergency_access(
let key = data.Key; let key = data.Key;
let requesting_user = headers.user; let requesting_user = headers.user;
let emergency_access = match EmergencyAccess::find_by_uuid(&emer_id, &conn) { let emergency_access = match EmergencyAccess::find_by_uuid(&emer_id, &conn).await {
Some(emer) => emer, Some(emer) => emer,
None => err!("Emergency access not valid."), None => err!("Emergency access not valid."),
}; };
@ -638,7 +677,7 @@ fn password_emergency_access(
err!("Emergency access not valid.") err!("Emergency access not valid.")
} }
let mut grantor_user = match User::find_by_uuid(&emergency_access.grantor_uuid, &conn) { let mut grantor_user = match User::find_by_uuid(&emergency_access.grantor_uuid, &conn).await {
Some(user) => user, Some(user) => user,
None => err!("Grantor user not found."), None => err!("Grantor user not found."),
}; };
@ -646,18 +685,15 @@ fn password_emergency_access(
// change grantor_user password // change grantor_user password
grantor_user.set_password(new_master_password_hash, None); grantor_user.set_password(new_master_password_hash, None);
grantor_user.akey = key; grantor_user.akey = key;
grantor_user.save(&conn)?; grantor_user.save(&conn).await?;
// Disable TwoFactor providers since they will otherwise block logins // Disable TwoFactor providers since they will otherwise block logins
TwoFactor::delete_all_by_user(&grantor_user.uuid, &conn)?; TwoFactor::delete_all_by_user(&grantor_user.uuid, &conn).await?;
// Removing owner, check that there are at least another owner
let user_org_grantor = UserOrganization::find_any_state_by_user(&grantor_user.uuid, &conn);
// Remove grantor from all organisations unless Owner // Remove grantor from all organisations unless Owner
for user_org in user_org_grantor { for user_org in UserOrganization::find_any_state_by_user(&grantor_user.uuid, &conn).await {
if user_org.atype != UserOrgType::Owner as i32 { if user_org.atype != UserOrgType::Owner as i32 {
user_org.delete(&conn)?; user_org.delete(&conn).await?;
} }
} }
Ok(()) Ok(())
@ -666,9 +702,9 @@ fn password_emergency_access(
// endregion // endregion
#[get("/emergency-access/<emer_id>/policies")] #[get("/emergency-access/<emer_id>/policies")]
fn policies_emergency_access(emer_id: String, headers: Headers, conn: DbConn) -> JsonResult { async fn policies_emergency_access(emer_id: String, headers: Headers, conn: DbConn) -> JsonResult {
let requesting_user = headers.user; let requesting_user = headers.user;
let emergency_access = match EmergencyAccess::find_by_uuid(&emer_id, &conn) { let emergency_access = match EmergencyAccess::find_by_uuid(&emer_id, &conn).await {
Some(emer) => emer, Some(emer) => emer,
None => err!("Emergency access not valid."), None => err!("Emergency access not valid."),
}; };
@ -677,13 +713,13 @@ fn policies_emergency_access(emer_id: String, headers: Headers, conn: DbConn) ->
err!("Emergency access not valid.") err!("Emergency access not valid.")
} }
let grantor_user = match User::find_by_uuid(&emergency_access.grantor_uuid, &conn) { let grantor_user = match User::find_by_uuid(&emergency_access.grantor_uuid, &conn).await {
Some(user) => user, Some(user) => user,
None => err!("Grantor user not found."), None => err!("Grantor user not found."),
}; };
let policies = OrgPolicy::find_confirmed_by_user(&grantor_user.uuid, &conn); let policies = OrgPolicy::find_confirmed_by_user(&grantor_user.uuid, &conn);
let policies_json: Vec<Value> = policies.iter().map(OrgPolicy::to_json).collect(); let policies_json: Vec<Value> = policies.await.iter().map(OrgPolicy::to_json).collect();
Ok(Json(json!({ Ok(Json(json!({
"Data": policies_json, "Data": policies_json,
@ -709,14 +745,14 @@ fn check_emergency_access_allowed() -> EmptyResult {
Ok(()) Ok(())
} }
pub fn emergency_request_timeout_job(pool: DbPool) { pub async fn emergency_request_timeout_job(pool: DbPool) {
debug!("Start emergency_request_timeout_job"); debug!("Start emergency_request_timeout_job");
if !CONFIG.emergency_access_allowed() { if !CONFIG.emergency_access_allowed() {
return; return;
} }
if let Ok(conn) = pool.get() { if let Ok(conn) = pool.get().await {
let emergency_access_list = EmergencyAccess::find_all_recoveries(&conn); let emergency_access_list = EmergencyAccess::find_all_recoveries(&conn).await;
if emergency_access_list.is_empty() { if emergency_access_list.is_empty() {
debug!("No emergency request timeout to approve"); debug!("No emergency request timeout to approve");
@ -725,18 +761,20 @@ pub fn emergency_request_timeout_job(pool: DbPool) {
for mut emer in emergency_access_list { for mut emer in emergency_access_list {
if emer.recovery_initiated_at.is_some() if emer.recovery_initiated_at.is_some()
&& Utc::now().naive_utc() && Utc::now().naive_utc()
>= emer.recovery_initiated_at.unwrap() + Duration::days(emer.wait_time_days as i64) >= emer.recovery_initiated_at.unwrap() + Duration::days(i64::from(emer.wait_time_days))
{ {
emer.status = EmergencyAccessStatus::RecoveryApproved as i32; emer.status = EmergencyAccessStatus::RecoveryApproved as i32;
emer.save(&conn).expect("Cannot save emergency access on job"); emer.save(&conn).await.expect("Cannot save emergency access on job");
if CONFIG.mail_enabled() { if CONFIG.mail_enabled() {
// get grantor user to send Accepted email // get grantor user to send Accepted email
let grantor_user = User::find_by_uuid(&emer.grantor_uuid, &conn).expect("Grantor user not found."); let grantor_user =
User::find_by_uuid(&emer.grantor_uuid, &conn).await.expect("Grantor user not found.");
// get grantee user to send Accepted email // get grantee user to send Accepted email
let grantee_user = let grantee_user =
User::find_by_uuid(&emer.grantee_uuid.clone().expect("Grantee user invalid."), &conn) User::find_by_uuid(&emer.grantee_uuid.clone().expect("Grantee user invalid."), &conn)
.await
.expect("Grantee user not found."); .expect("Grantee user not found.");
mail::send_emergency_access_recovery_timed_out( mail::send_emergency_access_recovery_timed_out(
@ -744,9 +782,11 @@ pub fn emergency_request_timeout_job(pool: DbPool) {
&grantee_user.name.clone(), &grantee_user.name.clone(),
emer.get_type_as_str(), emer.get_type_as_str(),
) )
.await
.expect("Error on sending email"); .expect("Error on sending email");
mail::send_emergency_access_recovery_approved(&grantee_user.email, &grantor_user.name.clone()) mail::send_emergency_access_recovery_approved(&grantee_user.email, &grantor_user.name.clone())
.await
.expect("Error on sending email"); .expect("Error on sending email");
} }
} }
@ -756,14 +796,14 @@ pub fn emergency_request_timeout_job(pool: DbPool) {
} }
} }
pub fn emergency_notification_reminder_job(pool: DbPool) { pub async fn emergency_notification_reminder_job(pool: DbPool) {
debug!("Start emergency_notification_reminder_job"); debug!("Start emergency_notification_reminder_job");
if !CONFIG.emergency_access_allowed() { if !CONFIG.emergency_access_allowed() {
return; return;
} }
if let Ok(conn) = pool.get() { if let Ok(conn) = pool.get().await {
let emergency_access_list = EmergencyAccess::find_all_recoveries(&conn); let emergency_access_list = EmergencyAccess::find_all_recoveries(&conn).await;
if emergency_access_list.is_empty() { if emergency_access_list.is_empty() {
debug!("No emergency request reminder notification to send"); debug!("No emergency request reminder notification to send");
@ -772,20 +812,22 @@ pub fn emergency_notification_reminder_job(pool: DbPool) {
for mut emer in emergency_access_list { for mut emer in emergency_access_list {
if (emer.recovery_initiated_at.is_some() if (emer.recovery_initiated_at.is_some()
&& Utc::now().naive_utc() && Utc::now().naive_utc()
>= emer.recovery_initiated_at.unwrap() + Duration::days((emer.wait_time_days as i64) - 1)) >= emer.recovery_initiated_at.unwrap() + Duration::days((i64::from(emer.wait_time_days)) - 1))
&& (emer.last_notification_at.is_none() && (emer.last_notification_at.is_none()
|| (emer.last_notification_at.is_some() || (emer.last_notification_at.is_some()
&& Utc::now().naive_utc() >= emer.last_notification_at.unwrap() + Duration::days(1))) && Utc::now().naive_utc() >= emer.last_notification_at.unwrap() + Duration::days(1)))
{ {
emer.save(&conn).expect("Cannot save emergency access on job"); emer.save(&conn).await.expect("Cannot save emergency access on job");
if CONFIG.mail_enabled() { if CONFIG.mail_enabled() {
// get grantor user to send Accepted email // get grantor user to send Accepted email
let grantor_user = User::find_by_uuid(&emer.grantor_uuid, &conn).expect("Grantor user not found."); let grantor_user =
User::find_by_uuid(&emer.grantor_uuid, &conn).await.expect("Grantor user not found.");
// get grantee user to send Accepted email // get grantee user to send Accepted email
let grantee_user = let grantee_user =
User::find_by_uuid(&emer.grantee_uuid.clone().expect("Grantee user invalid."), &conn) User::find_by_uuid(&emer.grantee_uuid.clone().expect("Grantee user invalid."), &conn)
.await
.expect("Grantee user not found."); .expect("Grantee user not found.");
mail::send_emergency_access_recovery_reminder( mail::send_emergency_access_recovery_reminder(
@ -794,6 +836,7 @@ pub fn emergency_notification_reminder_job(pool: DbPool) {
emer.get_type_as_str(), emer.get_type_as_str(),
&emer.wait_time_days.to_string(), // TODO(jjlin): This should be the number of days left. &emer.wait_time_days.to_string(), // TODO(jjlin): This should be the number of days left.
) )
.await
.expect("Error on sending email"); .expect("Error on sending email");
} }
} }

53
src/api/core/folders.rs

@ -1,4 +1,4 @@
use rocket_contrib::json::Json; use rocket::serde::json::Json;
use serde_json::Value; use serde_json::Value;
use crate::{ use crate::{
@ -12,9 +12,8 @@ pub fn routes() -> Vec<rocket::Route> {
} }
#[get("/folders")] #[get("/folders")]
fn get_folders(headers: Headers, conn: DbConn) -> Json<Value> { async fn get_folders(headers: Headers, conn: DbConn) -> Json<Value> {
let folders = Folder::find_by_user(&headers.user.uuid, &conn); let folders = Folder::find_by_user(&headers.user.uuid, &conn).await;
let folders_json: Vec<Value> = folders.iter().map(Folder::to_json).collect(); let folders_json: Vec<Value> = folders.iter().map(Folder::to_json).collect();
Json(json!({ Json(json!({
@ -25,8 +24,8 @@ fn get_folders(headers: Headers, conn: DbConn) -> Json<Value> {
} }
#[get("/folders/<uuid>")] #[get("/folders/<uuid>")]
fn get_folder(uuid: String, headers: Headers, conn: DbConn) -> JsonResult { async fn get_folder(uuid: String, headers: Headers, conn: DbConn) -> JsonResult {
let folder = match Folder::find_by_uuid(&uuid, &conn) { let folder = match Folder::find_by_uuid(&uuid, &conn).await {
Some(folder) => folder, Some(folder) => folder,
_ => err!("Invalid folder"), _ => err!("Invalid folder"),
}; };
@ -45,27 +44,39 @@ pub struct FolderData {
} }
#[post("/folders", data = "<data>")] #[post("/folders", data = "<data>")]
fn post_folders(data: JsonUpcase<FolderData>, headers: Headers, conn: DbConn, nt: Notify) -> JsonResult { async fn post_folders(data: JsonUpcase<FolderData>, headers: Headers, conn: DbConn, nt: Notify<'_>) -> JsonResult {
let data: FolderData = data.into_inner().data; let data: FolderData = data.into_inner().data;
let mut folder = Folder::new(headers.user.uuid, data.Name); let mut folder = Folder::new(headers.user.uuid, data.Name);
folder.save(&conn)?; folder.save(&conn).await?;
nt.send_folder_update(UpdateType::FolderCreate, &folder); nt.send_folder_update(UpdateType::FolderCreate, &folder).await;
Ok(Json(folder.to_json())) Ok(Json(folder.to_json()))
} }
#[post("/folders/<uuid>", data = "<data>")] #[post("/folders/<uuid>", data = "<data>")]
fn post_folder(uuid: String, data: JsonUpcase<FolderData>, headers: Headers, conn: DbConn, nt: Notify) -> JsonResult { async fn post_folder(
put_folder(uuid, data, headers, conn, nt) uuid: String,
data: JsonUpcase<FolderData>,
headers: Headers,
conn: DbConn,
nt: Notify<'_>,
) -> JsonResult {
put_folder(uuid, data, headers, conn, nt).await
} }
#[put("/folders/<uuid>", data = "<data>")] #[put("/folders/<uuid>", data = "<data>")]
fn put_folder(uuid: String, data: JsonUpcase<FolderData>, headers: Headers, conn: DbConn, nt: Notify) -> JsonResult { async fn put_folder(
uuid: String,
data: JsonUpcase<FolderData>,
headers: Headers,
conn: DbConn,
nt: Notify<'_>,
) -> JsonResult {
let data: FolderData = data.into_inner().data; let data: FolderData = data.into_inner().data;
let mut folder = match Folder::find_by_uuid(&uuid, &conn) { let mut folder = match Folder::find_by_uuid(&uuid, &conn).await {
Some(folder) => folder, Some(folder) => folder,
_ => err!("Invalid folder"), _ => err!("Invalid folder"),
}; };
@ -76,20 +87,20 @@ fn put_folder(uuid: String, data: JsonUpcase<FolderData>, headers: Headers, conn
folder.name = data.Name; folder.name = data.Name;
folder.save(&conn)?; folder.save(&conn).await?;
nt.send_folder_update(UpdateType::FolderUpdate, &folder); nt.send_folder_update(UpdateType::FolderUpdate, &folder).await;
Ok(Json(folder.to_json())) Ok(Json(folder.to_json()))
} }
#[post("/folders/<uuid>/delete")] #[post("/folders/<uuid>/delete")]
fn delete_folder_post(uuid: String, headers: Headers, conn: DbConn, nt: Notify) -> EmptyResult { async fn delete_folder_post(uuid: String, headers: Headers, conn: DbConn, nt: Notify<'_>) -> EmptyResult {
delete_folder(uuid, headers, conn, nt) delete_folder(uuid, headers, conn, nt).await
} }
#[delete("/folders/<uuid>")] #[delete("/folders/<uuid>")]
fn delete_folder(uuid: String, headers: Headers, conn: DbConn, nt: Notify) -> EmptyResult { async fn delete_folder(uuid: String, headers: Headers, conn: DbConn, nt: Notify<'_>) -> EmptyResult {
let folder = match Folder::find_by_uuid(&uuid, &conn) { let folder = match Folder::find_by_uuid(&uuid, &conn).await {
Some(folder) => folder, Some(folder) => folder,
_ => err!("Invalid folder"), _ => err!("Invalid folder"),
}; };
@ -99,8 +110,8 @@ fn delete_folder(uuid: String, headers: Headers, conn: DbConn, nt: Notify) -> Em
} }
// Delete the actual folder entry // Delete the actual folder entry
folder.delete(&conn)?; folder.delete(&conn).await?;
nt.send_folder_update(UpdateType::FolderDelete, &folder); nt.send_folder_update(UpdateType::FolderDelete, &folder).await;
Ok(()) Ok(())
} }

46
src/api/core/mod.rs

@ -1,4 +1,4 @@
mod accounts; pub mod accounts;
mod ciphers; mod ciphers;
mod emergency_access; mod emergency_access;
mod folders; mod folders;
@ -7,13 +7,16 @@ mod sends;
pub mod two_factor; pub mod two_factor;
pub use ciphers::purge_trashed_ciphers; pub use ciphers::purge_trashed_ciphers;
pub use ciphers::{CipherSyncData, CipherSyncType};
pub use emergency_access::{emergency_notification_reminder_job, emergency_request_timeout_job}; pub use emergency_access::{emergency_notification_reminder_job, emergency_request_timeout_job};
pub use sends::purge_sends; pub use sends::purge_sends;
pub use two_factor::send_incomplete_2fa_notifications; pub use two_factor::send_incomplete_2fa_notifications;
pub fn routes() -> Vec<Route> { pub fn routes() -> Vec<Route> {
let mut mod_routes = let mut device_token_routes = routes![clear_device_token, put_device_token];
routes![clear_device_token, put_device_token, get_eq_domains, post_eq_domains, put_eq_domains, hibp_breach,]; let mut eq_domains_routes = routes![get_eq_domains, post_eq_domains, put_eq_domains];
let mut hibp_routes = routes![hibp_breach];
let mut meta_routes = routes![alive, now, version];
let mut routes = Vec::new(); let mut routes = Vec::new();
routes.append(&mut accounts::routes()); routes.append(&mut accounts::routes());
@ -23,7 +26,10 @@ pub fn routes() -> Vec<Route> {
routes.append(&mut organizations::routes()); routes.append(&mut organizations::routes());
routes.append(&mut two_factor::routes()); routes.append(&mut two_factor::routes());
routes.append(&mut sends::routes()); routes.append(&mut sends::routes());
routes.append(&mut mod_routes); routes.append(&mut device_token_routes);
routes.append(&mut eq_domains_routes);
routes.append(&mut hibp_routes);
routes.append(&mut meta_routes);
routes routes
} }
@ -31,8 +37,8 @@ pub fn routes() -> Vec<Route> {
// //
// Move this somewhere else // Move this somewhere else
// //
use rocket::serde::json::Json;
use rocket::Route; use rocket::Route;
use rocket_contrib::json::Json;
use serde_json::Value; use serde_json::Value;
use crate::{ use crate::{
@ -121,7 +127,7 @@ struct EquivDomainData {
} }
#[post("/settings/domains", data = "<data>")] #[post("/settings/domains", data = "<data>")]
fn post_eq_domains(data: JsonUpcase<EquivDomainData>, headers: Headers, conn: DbConn) -> JsonResult { async fn post_eq_domains(data: JsonUpcase<EquivDomainData>, headers: Headers, conn: DbConn) -> JsonResult {
let data: EquivDomainData = data.into_inner().data; let data: EquivDomainData = data.into_inner().data;
let excluded_globals = data.ExcludedGlobalEquivalentDomains.unwrap_or_default(); let excluded_globals = data.ExcludedGlobalEquivalentDomains.unwrap_or_default();
@ -133,18 +139,18 @@ fn post_eq_domains(data: JsonUpcase<EquivDomainData>, headers: Headers, conn: Db
user.excluded_globals = to_string(&excluded_globals).unwrap_or_else(|_| "[]".to_string()); user.excluded_globals = to_string(&excluded_globals).unwrap_or_else(|_| "[]".to_string());
user.equivalent_domains = to_string(&equivalent_domains).unwrap_or_else(|_| "[]".to_string()); user.equivalent_domains = to_string(&equivalent_domains).unwrap_or_else(|_| "[]".to_string());
user.save(&conn)?; user.save(&conn).await?;
Ok(Json(json!({}))) Ok(Json(json!({})))
} }
#[put("/settings/domains", data = "<data>")] #[put("/settings/domains", data = "<data>")]
fn put_eq_domains(data: JsonUpcase<EquivDomainData>, headers: Headers, conn: DbConn) -> JsonResult { async fn put_eq_domains(data: JsonUpcase<EquivDomainData>, headers: Headers, conn: DbConn) -> JsonResult {
post_eq_domains(data, headers, conn) post_eq_domains(data, headers, conn).await
} }
#[get("/hibp/breach?<username>")] #[get("/hibp/breach?<username>")]
fn hibp_breach(username: String) -> JsonResult { async fn hibp_breach(username: String) -> JsonResult {
let url = format!( let url = format!(
"https://haveibeenpwned.com/api/v3/breachedaccount/{}?truncateResponse=false&includeUnverified=false", "https://haveibeenpwned.com/api/v3/breachedaccount/{}?truncateResponse=false&includeUnverified=false",
username username
@ -153,14 +159,14 @@ fn hibp_breach(username: String) -> JsonResult {
if let Some(api_key) = crate::CONFIG.hibp_api_key() { if let Some(api_key) = crate::CONFIG.hibp_api_key() {
let hibp_client = get_reqwest_client(); let hibp_client = get_reqwest_client();
let res = hibp_client.get(&url).header("hibp-api-key", api_key).send()?; let res = hibp_client.get(&url).header("hibp-api-key", api_key).send().await?;
// If we get a 404, return a 404, it means no breached accounts // If we get a 404, return a 404, it means no breached accounts
if res.status() == 404 { if res.status() == 404 {
return Err(Error::empty().with_code(404)); return Err(Error::empty().with_code(404));
} }
let value: Value = res.error_for_status()?.json()?; let value: Value = res.error_for_status()?.json().await?;
Ok(Json(value)) Ok(Json(value))
} else { } else {
Ok(Json(json!([{ Ok(Json(json!([{
@ -178,3 +184,19 @@ fn hibp_breach(username: String) -> JsonResult {
}]))) }])))
} }
} }
// We use DbConn here to let the alive healthcheck also verify the database connection.
#[get("/alive")]
fn alive(_conn: DbConn) -> Json<String> {
now()
}
#[get("/now")]
pub fn now() -> Json<String> {
Json(crate::util::format_date(&chrono::Utc::now().naive_utc()))
}
#[get("/version")]
fn version() -> Json<&'static str> {
Json(crate::VERSION.unwrap_or_default())
}

458
src/api/core/organizations.rs

File diff suppressed because it is too large

189
src/api/core/sends.rs

@ -1,14 +1,15 @@
use std::{io::Read, path::Path}; use std::path::Path;
use chrono::{DateTime, Duration, Utc}; use chrono::{DateTime, Duration, Utc};
use multipart::server::{save::SavedData, Multipart, SaveResult}; use rocket::form::Form;
use rocket::{http::ContentType, response::NamedFile, Data}; use rocket::fs::NamedFile;
use rocket_contrib::json::Json; use rocket::fs::TempFile;
use rocket::serde::json::Json;
use serde_json::Value; use serde_json::Value;
use crate::{ use crate::{
api::{ApiResult, EmptyResult, JsonResult, JsonUpcase, Notify, NumberOrString, UpdateType}, api::{ApiResult, EmptyResult, JsonResult, JsonUpcase, Notify, NumberOrString, UpdateType},
auth::{Headers, Host}, auth::{ClientIp, Headers, Host},
db::{models::*, DbConn, DbPool}, db::{models::*, DbConn, DbPool},
util::SafeString, util::SafeString,
CONFIG, CONFIG,
@ -31,10 +32,10 @@ pub fn routes() -> Vec<rocket::Route> {
] ]
} }
pub fn purge_sends(pool: DbPool) { pub async fn purge_sends(pool: DbPool) {
debug!("Purging sends"); debug!("Purging sends");
if let Ok(conn) = pool.get() { if let Ok(conn) = pool.get().await {
Send::purge(&conn); Send::purge(&conn).await;
} else { } else {
error!("Failed to get DB connection while purging sends") error!("Failed to get DB connection while purging sends")
} }
@ -67,10 +68,10 @@ struct SendData {
/// ///
/// There is also a Vaultwarden-specific `sends_allowed` config setting that /// There is also a Vaultwarden-specific `sends_allowed` config setting that
/// controls this policy globally. /// controls this policy globally.
fn enforce_disable_send_policy(headers: &Headers, conn: &DbConn) -> EmptyResult { async fn enforce_disable_send_policy(headers: &Headers, conn: &DbConn) -> EmptyResult {
let user_uuid = &headers.user.uuid; let user_uuid = &headers.user.uuid;
let policy_type = OrgPolicyType::DisableSend; let policy_type = OrgPolicyType::DisableSend;
if !CONFIG.sends_allowed() || OrgPolicy::is_applicable_to_user(user_uuid, policy_type, conn) { if !CONFIG.sends_allowed() || OrgPolicy::is_applicable_to_user(user_uuid, policy_type, conn).await {
err!("Due to an Enterprise Policy, you are only able to delete an existing Send.") err!("Due to an Enterprise Policy, you are only able to delete an existing Send.")
} }
Ok(()) Ok(())
@ -82,10 +83,10 @@ fn enforce_disable_send_policy(headers: &Headers, conn: &DbConn) -> EmptyResult
/// but is allowed to remove this option from an existing Send. /// but is allowed to remove this option from an existing Send.
/// ///
/// Ref: https://bitwarden.com/help/article/policies/#send-options /// Ref: https://bitwarden.com/help/article/policies/#send-options
fn enforce_disable_hide_email_policy(data: &SendData, headers: &Headers, conn: &DbConn) -> EmptyResult { async fn enforce_disable_hide_email_policy(data: &SendData, headers: &Headers, conn: &DbConn) -> EmptyResult {
let user_uuid = &headers.user.uuid; let user_uuid = &headers.user.uuid;
let hide_email = data.HideEmail.unwrap_or(false); let hide_email = data.HideEmail.unwrap_or(false);
if hide_email && OrgPolicy::is_hide_email_disabled(user_uuid, conn) { if hide_email && OrgPolicy::is_hide_email_disabled(user_uuid, conn).await {
err!( err!(
"Due to an Enterprise Policy, you are not allowed to hide your email address \ "Due to an Enterprise Policy, you are not allowed to hide your email address \
from recipients when creating or editing a Send." from recipients when creating or editing a Send."
@ -134,9 +135,9 @@ fn create_send(data: SendData, user_uuid: String) -> ApiResult<Send> {
} }
#[get("/sends")] #[get("/sends")]
fn get_sends(headers: Headers, conn: DbConn) -> Json<Value> { async fn get_sends(headers: Headers, conn: DbConn) -> Json<Value> {
let sends = Send::find_by_user(&headers.user.uuid, &conn); let sends = Send::find_by_user(&headers.user.uuid, &conn);
let sends_json: Vec<Value> = sends.iter().map(|s| s.to_json()).collect(); let sends_json: Vec<Value> = sends.await.iter().map(|s| s.to_json()).collect();
Json(json!({ Json(json!({
"Data": sends_json, "Data": sends_json,
@ -146,8 +147,8 @@ fn get_sends(headers: Headers, conn: DbConn) -> Json<Value> {
} }
#[get("/sends/<uuid>")] #[get("/sends/<uuid>")]
fn get_send(uuid: String, headers: Headers, conn: DbConn) -> JsonResult { async fn get_send(uuid: String, headers: Headers, conn: DbConn) -> JsonResult {
let send = match Send::find_by_uuid(&uuid, &conn) { let send = match Send::find_by_uuid(&uuid, &conn).await {
Some(send) => send, Some(send) => send,
None => err!("Send not found"), None => err!("Send not found"),
}; };
@ -160,42 +161,40 @@ fn get_send(uuid: String, headers: Headers, conn: DbConn) -> JsonResult {
} }
#[post("/sends", data = "<data>")] #[post("/sends", data = "<data>")]
fn post_send(data: JsonUpcase<SendData>, headers: Headers, conn: DbConn, nt: Notify) -> JsonResult { async fn post_send(data: JsonUpcase<SendData>, headers: Headers, conn: DbConn, nt: Notify<'_>) -> JsonResult {
enforce_disable_send_policy(&headers, &conn)?; enforce_disable_send_policy(&headers, &conn).await?;
let data: SendData = data.into_inner().data; let data: SendData = data.into_inner().data;
enforce_disable_hide_email_policy(&data, &headers, &conn)?; enforce_disable_hide_email_policy(&data, &headers, &conn).await?;
if data.Type == SendType::File as i32 { if data.Type == SendType::File as i32 {
err!("File sends should use /api/sends/file") err!("File sends should use /api/sends/file")
} }
let mut send = create_send(data, headers.user.uuid)?; let mut send = create_send(data, headers.user.uuid)?;
send.save(&conn)?; send.save(&conn).await?;
nt.send_send_update(UpdateType::SyncSendCreate, &send, &send.update_users_revision(&conn)); nt.send_send_update(UpdateType::SyncSendCreate, &send, &send.update_users_revision(&conn).await).await;
Ok(Json(send.to_json())) Ok(Json(send.to_json()))
} }
#[post("/sends/file", format = "multipart/form-data", data = "<data>")] #[derive(FromForm)]
fn post_send_file(data: Data, content_type: &ContentType, headers: Headers, conn: DbConn, nt: Notify) -> JsonResult { struct UploadData<'f> {
enforce_disable_send_policy(&headers, &conn)?; model: Json<crate::util::UpCase<SendData>>,
data: TempFile<'f>,
let boundary = content_type.params().next().expect("No boundary provided").1; }
let mut mpart = Multipart::with_body(data.open(), boundary); #[post("/sends/file", format = "multipart/form-data", data = "<data>")]
async fn post_send_file(data: Form<UploadData<'_>>, headers: Headers, conn: DbConn, nt: Notify<'_>) -> JsonResult {
enforce_disable_send_policy(&headers, &conn).await?;
// First entry is the SendData JSON let UploadData {
let mut model_entry = match mpart.read_entry()? { model,
Some(e) if &*e.headers.name == "model" => e, mut data,
Some(_) => err!("Invalid entry name"), } = data.into_inner();
None => err!("No model entry present"), let model = model.into_inner().data;
};
let mut buf = String::new(); enforce_disable_hide_email_policy(&model, &headers, &conn).await?;
model_entry.data.read_to_string(&mut buf)?;
let data = serde_json::from_str::<crate::util::UpCase<SendData>>(&buf)?;
enforce_disable_hide_email_policy(&data.data, &headers, &conn)?;
// Get the file length and add an extra 5% to avoid issues // Get the file length and add an extra 5% to avoid issues
const SIZE_525_MB: u64 = 550_502_400; const SIZE_525_MB: u64 = 550_502_400;
@ -203,7 +202,7 @@ fn post_send_file(data: Data, content_type: &ContentType, headers: Headers, conn
let size_limit = match CONFIG.user_attachment_limit() { let size_limit = match CONFIG.user_attachment_limit() {
Some(0) => err!("File uploads are disabled"), Some(0) => err!("File uploads are disabled"),
Some(limit_kb) => { Some(limit_kb) => {
let left = (limit_kb * 1024) - Attachment::size_by_user(&headers.user.uuid, &conn); let left = (limit_kb * 1024) - Attachment::size_by_user(&headers.user.uuid, &conn).await;
if left <= 0 { if left <= 0 {
err!("Attachment storage limit reached! Delete some attachments to free up space") err!("Attachment storage limit reached! Delete some attachments to free up space")
} }
@ -212,51 +211,47 @@ fn post_send_file(data: Data, content_type: &ContentType, headers: Headers, conn
None => SIZE_525_MB, None => SIZE_525_MB,
}; };
// Create the Send let mut send = create_send(model, headers.user.uuid)?;
let mut send = create_send(data.data, headers.user.uuid)?;
let file_id = crate::crypto::generate_send_id();
if send.atype != SendType::File as i32 { if send.atype != SendType::File as i32 {
err!("Send content is not a file"); err!("Send content is not a file");
} }
let file_path = Path::new(&CONFIG.sends_folder()).join(&send.uuid).join(&file_id); // There seems to be a bug somewhere regarding uploading attachments using the Android Client (Maybe iOS too?)
// See: https://github.com/dani-garcia/vaultwarden/issues/2644
// Since all other clients seem to match TempFile::File and not TempFile::Buffered lets catch this and return an error for now.
// We need to figure out how to solve this, but for now it's better to not accept these attachments since they will be broken.
if let TempFile::Buffered {
content: _,
} = &data
{
err!("Error reading send file data. Please try an other client.");
}
// Read the data entry and save the file let size = data.len();
let mut data_entry = match mpart.read_entry()? { if size > size_limit {
Some(e) if &*e.headers.name == "data" => e, err!("Attachment storage limit exceeded with this file");
Some(_) => err!("Invalid entry name"), }
None => err!("No model entry present"),
};
let size = match data_entry.data.save().memory_threshold(0).size_limit(size_limit).with_path(&file_path) { let file_id = crate::crypto::generate_send_id();
SaveResult::Full(SavedData::File(_, size)) => size as i32, let folder_path = tokio::fs::canonicalize(&CONFIG.sends_folder()).await?.join(&send.uuid);
SaveResult::Full(other) => { let file_path = folder_path.join(&file_id);
std::fs::remove_file(&file_path).ok(); tokio::fs::create_dir_all(&folder_path).await?;
err!(format!("Attachment is not a file: {:?}", other));
} if let Err(_err) = data.persist_to(&file_path).await {
SaveResult::Partial(_, reason) => { data.move_copy_to(file_path).await?
std::fs::remove_file(&file_path).ok(); }
err!(format!("Attachment storage limit exceeded with this file: {:?}", reason));
}
SaveResult::Error(e) => {
std::fs::remove_file(&file_path).ok();
err!(format!("Error: {:?}", e));
}
};
// Set ID and sizes
let mut data_value: Value = serde_json::from_str(&send.data)?; let mut data_value: Value = serde_json::from_str(&send.data)?;
if let Some(o) = data_value.as_object_mut() { if let Some(o) = data_value.as_object_mut() {
o.insert(String::from("Id"), Value::String(file_id)); o.insert(String::from("Id"), Value::String(file_id));
o.insert(String::from("Size"), Value::Number(size.into())); o.insert(String::from("Size"), Value::Number(size.into()));
o.insert(String::from("SizeName"), Value::String(crate::util::get_display_size(size))); o.insert(String::from("SizeName"), Value::String(crate::util::get_display_size(size as i32)));
} }
send.data = serde_json::to_string(&data_value)?; send.data = serde_json::to_string(&data_value)?;
// Save the changes in the database // Save the changes in the database
send.save(&conn)?; send.save(&conn).await?;
nt.send_send_update(UpdateType::SyncSendUpdate, &send, &send.update_users_revision(&conn)); nt.send_send_update(UpdateType::SyncSendUpdate, &send, &send.update_users_revision(&conn).await).await;
Ok(Json(send.to_json())) Ok(Json(send.to_json()))
} }
@ -268,8 +263,8 @@ pub struct SendAccessData {
} }
#[post("/sends/access/<access_id>", data = "<data>")] #[post("/sends/access/<access_id>", data = "<data>")]
fn post_access(access_id: String, data: JsonUpcase<SendAccessData>, conn: DbConn) -> JsonResult { async fn post_access(access_id: String, data: JsonUpcase<SendAccessData>, conn: DbConn, ip: ClientIp) -> JsonResult {
let mut send = match Send::find_by_access_id(&access_id, &conn) { let mut send = match Send::find_by_access_id(&access_id, &conn).await {
Some(s) => s, Some(s) => s,
None => err_code!(SEND_INACCESSIBLE_MSG, 404), None => err_code!(SEND_INACCESSIBLE_MSG, 404),
}; };
@ -297,8 +292,8 @@ fn post_access(access_id: String, data: JsonUpcase<SendAccessData>, conn: DbConn
if send.password_hash.is_some() { if send.password_hash.is_some() {
match data.into_inner().data.Password { match data.into_inner().data.Password {
Some(ref p) if send.check_password(p) => { /* Nothing to do here */ } Some(ref p) if send.check_password(p) => { /* Nothing to do here */ }
Some(_) => err!("Invalid password."), Some(_) => err!("Invalid password", format!("IP: {}.", ip.ip)),
None => err_code!("Password not provided", 401), None => err_code!("Password not provided", format!("IP: {}.", ip.ip), 401),
} }
} }
@ -307,20 +302,20 @@ fn post_access(access_id: String, data: JsonUpcase<SendAccessData>, conn: DbConn
send.access_count += 1; send.access_count += 1;
} }
send.save(&conn)?; send.save(&conn).await?;
Ok(Json(send.to_json_access(&conn))) Ok(Json(send.to_json_access(&conn).await))
} }
#[post("/sends/<send_id>/access/file/<file_id>", data = "<data>")] #[post("/sends/<send_id>/access/file/<file_id>", data = "<data>")]
fn post_access_file( async fn post_access_file(
send_id: String, send_id: String,
file_id: String, file_id: String,
data: JsonUpcase<SendAccessData>, data: JsonUpcase<SendAccessData>,
host: Host, host: Host,
conn: DbConn, conn: DbConn,
) -> JsonResult { ) -> JsonResult {
let mut send = match Send::find_by_uuid(&send_id, &conn) { let mut send = match Send::find_by_uuid(&send_id, &conn).await {
Some(s) => s, Some(s) => s,
None => err_code!(SEND_INACCESSIBLE_MSG, 404), None => err_code!(SEND_INACCESSIBLE_MSG, 404),
}; };
@ -355,7 +350,7 @@ fn post_access_file(
send.access_count += 1; send.access_count += 1;
send.save(&conn)?; send.save(&conn).await?;
let token_claims = crate::auth::generate_send_claims(&send_id, &file_id); let token_claims = crate::auth::generate_send_claims(&send_id, &file_id);
let token = crate::auth::encode_jwt(&token_claims); let token = crate::auth::encode_jwt(&token_claims);
@ -367,23 +362,29 @@ fn post_access_file(
} }
#[get("/sends/<send_id>/<file_id>?<t>")] #[get("/sends/<send_id>/<file_id>?<t>")]
fn download_send(send_id: SafeString, file_id: SafeString, t: String) -> Option<NamedFile> { async fn download_send(send_id: SafeString, file_id: SafeString, t: String) -> Option<NamedFile> {
if let Ok(claims) = crate::auth::decode_send(&t) { if let Ok(claims) = crate::auth::decode_send(&t) {
if claims.sub == format!("{}/{}", send_id, file_id) { if claims.sub == format!("{}/{}", send_id, file_id) {
return NamedFile::open(Path::new(&CONFIG.sends_folder()).join(send_id).join(file_id)).ok(); return NamedFile::open(Path::new(&CONFIG.sends_folder()).join(send_id).join(file_id)).await.ok();
} }
} }
None None
} }
#[put("/sends/<id>", data = "<data>")] #[put("/sends/<id>", data = "<data>")]
fn put_send(id: String, data: JsonUpcase<SendData>, headers: Headers, conn: DbConn, nt: Notify) -> JsonResult { async fn put_send(
enforce_disable_send_policy(&headers, &conn)?; id: String,
data: JsonUpcase<SendData>,
headers: Headers,
conn: DbConn,
nt: Notify<'_>,
) -> JsonResult {
enforce_disable_send_policy(&headers, &conn).await?;
let data: SendData = data.into_inner().data; let data: SendData = data.into_inner().data;
enforce_disable_hide_email_policy(&data, &headers, &conn)?; enforce_disable_hide_email_policy(&data, &headers, &conn).await?;
let mut send = match Send::find_by_uuid(&id, &conn) { let mut send = match Send::find_by_uuid(&id, &conn).await {
Some(s) => s, Some(s) => s,
None => err!("Send not found"), None => err!("Send not found"),
}; };
@ -430,15 +431,15 @@ fn put_send(id: String, data: JsonUpcase<SendData>, headers: Headers, conn: DbCo
send.set_password(Some(&password)); send.set_password(Some(&password));
} }
send.save(&conn)?; send.save(&conn).await?;
nt.send_send_update(UpdateType::SyncSendUpdate, &send, &send.update_users_revision(&conn)); nt.send_send_update(UpdateType::SyncSendUpdate, &send, &send.update_users_revision(&conn).await).await;
Ok(Json(send.to_json())) Ok(Json(send.to_json()))
} }
#[delete("/sends/<id>")] #[delete("/sends/<id>")]
fn delete_send(id: String, headers: Headers, conn: DbConn, nt: Notify) -> EmptyResult { async fn delete_send(id: String, headers: Headers, conn: DbConn, nt: Notify<'_>) -> EmptyResult {
let send = match Send::find_by_uuid(&id, &conn) { let send = match Send::find_by_uuid(&id, &conn).await {
Some(s) => s, Some(s) => s,
None => err!("Send not found"), None => err!("Send not found"),
}; };
@ -447,17 +448,17 @@ fn delete_send(id: String, headers: Headers, conn: DbConn, nt: Notify) -> EmptyR
err!("Send is not owned by user") err!("Send is not owned by user")
} }
send.delete(&conn)?; send.delete(&conn).await?;
nt.send_send_update(UpdateType::SyncSendDelete, &send, &send.update_users_revision(&conn)); nt.send_send_update(UpdateType::SyncSendDelete, &send, &send.update_users_revision(&conn).await).await;
Ok(()) Ok(())
} }
#[put("/sends/<id>/remove-password")] #[put("/sends/<id>/remove-password")]
fn put_remove_password(id: String, headers: Headers, conn: DbConn, nt: Notify) -> JsonResult { async fn put_remove_password(id: String, headers: Headers, conn: DbConn, nt: Notify<'_>) -> JsonResult {
enforce_disable_send_policy(&headers, &conn)?; enforce_disable_send_policy(&headers, &conn).await?;
let mut send = match Send::find_by_uuid(&id, &conn) { let mut send = match Send::find_by_uuid(&id, &conn).await {
Some(s) => s, Some(s) => s,
None => err!("Send not found"), None => err!("Send not found"),
}; };
@ -467,8 +468,8 @@ fn put_remove_password(id: String, headers: Headers, conn: DbConn, nt: Notify) -
} }
send.set_password(None); send.set_password(None);
send.save(&conn)?; send.save(&conn).await?;
nt.send_send_update(UpdateType::SyncSendUpdate, &send, &send.update_users_revision(&conn)); nt.send_send_update(UpdateType::SyncSendUpdate, &send, &send.update_users_revision(&conn).await).await;
Ok(Json(send.to_json())) Ok(Json(send.to_json()))
} }

45
src/api/core/two_factor/authenticator.rs

@ -1,6 +1,6 @@
use data_encoding::BASE32; use data_encoding::BASE32;
use rocket::serde::json::Json;
use rocket::Route; use rocket::Route;
use rocket_contrib::json::Json;
use crate::{ use crate::{
api::{ api::{
@ -21,7 +21,7 @@ pub fn routes() -> Vec<Route> {
} }
#[post("/two-factor/get-authenticator", data = "<data>")] #[post("/two-factor/get-authenticator", data = "<data>")]
fn generate_authenticator(data: JsonUpcase<PasswordData>, headers: Headers, conn: DbConn) -> JsonResult { async fn generate_authenticator(data: JsonUpcase<PasswordData>, headers: Headers, conn: DbConn) -> JsonResult {
let data: PasswordData = data.into_inner().data; let data: PasswordData = data.into_inner().data;
let user = headers.user; let user = headers.user;
@ -30,7 +30,7 @@ fn generate_authenticator(data: JsonUpcase<PasswordData>, headers: Headers, conn
} }
let type_ = TwoFactorType::Authenticator as i32; let type_ = TwoFactorType::Authenticator as i32;
let twofactor = TwoFactor::find_by_user_and_type(&user.uuid, type_, &conn); let twofactor = TwoFactor::find_by_user_and_type(&user.uuid, type_, &conn).await;
let (enabled, key) = match twofactor { let (enabled, key) = match twofactor {
Some(tf) => (true, tf.data), Some(tf) => (true, tf.data),
@ -53,7 +53,7 @@ struct EnableAuthenticatorData {
} }
#[post("/two-factor/authenticator", data = "<data>")] #[post("/two-factor/authenticator", data = "<data>")]
fn activate_authenticator( async fn activate_authenticator(
data: JsonUpcase<EnableAuthenticatorData>, data: JsonUpcase<EnableAuthenticatorData>,
headers: Headers, headers: Headers,
ip: ClientIp, ip: ClientIp,
@ -81,9 +81,9 @@ fn activate_authenticator(
} }
// Validate the token provided with the key, and save new twofactor // Validate the token provided with the key, and save new twofactor
validate_totp_code(&user.uuid, &token, &key.to_uppercase(), &ip, &conn)?; validate_totp_code(&user.uuid, &token, &key.to_uppercase(), &ip, &conn).await?;
_generate_recover_code(&mut user, &conn); _generate_recover_code(&mut user, &conn).await;
Ok(Json(json!({ Ok(Json(json!({
"Enabled": true, "Enabled": true,
@ -93,16 +93,16 @@ fn activate_authenticator(
} }
#[put("/two-factor/authenticator", data = "<data>")] #[put("/two-factor/authenticator", data = "<data>")]
fn activate_authenticator_put( async fn activate_authenticator_put(
data: JsonUpcase<EnableAuthenticatorData>, data: JsonUpcase<EnableAuthenticatorData>,
headers: Headers, headers: Headers,
ip: ClientIp, ip: ClientIp,
conn: DbConn, conn: DbConn,
) -> JsonResult { ) -> JsonResult {
activate_authenticator(data, headers, ip, conn) activate_authenticator(data, headers, ip, conn).await
} }
pub fn validate_totp_code_str( pub async fn validate_totp_code_str(
user_uuid: &str, user_uuid: &str,
totp_code: &str, totp_code: &str,
secret: &str, secret: &str,
@ -113,10 +113,16 @@ pub fn validate_totp_code_str(
err!("TOTP code is not a number"); err!("TOTP code is not a number");
} }
validate_totp_code(user_uuid, totp_code, secret, ip, conn) validate_totp_code(user_uuid, totp_code, secret, ip, conn).await
} }
pub fn validate_totp_code(user_uuid: &str, totp_code: &str, secret: &str, ip: &ClientIp, conn: &DbConn) -> EmptyResult { pub async fn validate_totp_code(
user_uuid: &str,
totp_code: &str,
secret: &str,
ip: &ClientIp,
conn: &DbConn,
) -> EmptyResult {
use totp_lite::{totp_custom, Sha1}; use totp_lite::{totp_custom, Sha1};
let decoded_secret = match BASE32.decode(secret.as_bytes()) { let decoded_secret = match BASE32.decode(secret.as_bytes()) {
@ -124,15 +130,16 @@ pub fn validate_totp_code(user_uuid: &str, totp_code: &str, secret: &str, ip: &C
Err(_) => err!("Invalid TOTP secret"), Err(_) => err!("Invalid TOTP secret"),
}; };
let mut twofactor = match TwoFactor::find_by_user_and_type(user_uuid, TwoFactorType::Authenticator as i32, conn) { let mut twofactor =
Some(tf) => tf, match TwoFactor::find_by_user_and_type(user_uuid, TwoFactorType::Authenticator as i32, conn).await {
_ => TwoFactor::new(user_uuid.to_string(), TwoFactorType::Authenticator, secret.to_string()), Some(tf) => tf,
}; _ => TwoFactor::new(user_uuid.to_string(), TwoFactorType::Authenticator, secret.to_string()),
};
// The amount of steps back and forward in time // The amount of steps back and forward in time
// Also check if we need to disable time drifted TOTP codes. // Also check if we need to disable time drifted TOTP codes.
// If that is the case, we set the steps to 0 so only the current TOTP is valid. // If that is the case, we set the steps to 0 so only the current TOTP is valid.
let steps = !CONFIG.authenticator_disable_time_drift() as i64; let steps = i64::from(!CONFIG.authenticator_disable_time_drift());
// Get the current system time in UNIX Epoch (UTC) // Get the current system time in UNIX Epoch (UTC)
let current_time = chrono::Utc::now(); let current_time = chrono::Utc::now();
@ -147,7 +154,7 @@ pub fn validate_totp_code(user_uuid: &str, totp_code: &str, secret: &str, ip: &C
let generated = totp_custom::<Sha1>(30, 6, &decoded_secret, time); let generated = totp_custom::<Sha1>(30, 6, &decoded_secret, time);
// Check the the given code equals the generated and if the time_step is larger then the one last used. // Check the the given code equals the generated and if the time_step is larger then the one last used.
if generated == totp_code && time_step > twofactor.last_used as i64 { if generated == totp_code && time_step > i64::from(twofactor.last_used) {
// If the step does not equals 0 the time is drifted either server or client side. // If the step does not equals 0 the time is drifted either server or client side.
if step != 0 { if step != 0 {
warn!("TOTP Time drift detected. The step offset is {}", step); warn!("TOTP Time drift detected. The step offset is {}", step);
@ -156,9 +163,9 @@ pub fn validate_totp_code(user_uuid: &str, totp_code: &str, secret: &str, ip: &C
// Save the last used time step so only totp time steps higher then this one are allowed. // Save the last used time step so only totp time steps higher then this one are allowed.
// This will also save a newly created twofactor if the code is correct. // This will also save a newly created twofactor if the code is correct.
twofactor.last_used = time_step as i32; twofactor.last_used = time_step as i32;
twofactor.save(conn)?; twofactor.save(conn).await?;
return Ok(()); return Ok(());
} else if generated == totp_code && time_step <= twofactor.last_used as i64 { } else if generated == totp_code && time_step <= i64::from(twofactor.last_used) {
warn!("This TOTP or a TOTP code within {} steps back or forward has already been used!", steps); warn!("This TOTP or a TOTP code within {} steps back or forward has already been used!", steps);
err!(format!("Invalid TOTP code! Server time: {} IP: {}", current_time.format("%F %T UTC"), ip.ip)); err!(format!("Invalid TOTP code! Server time: {} IP: {}", current_time.format("%F %T UTC"), ip.ip));
} }

46
src/api/core/two_factor/duo.rs

@ -1,7 +1,7 @@
use chrono::Utc; use chrono::Utc;
use data_encoding::BASE64; use data_encoding::BASE64;
use rocket::serde::json::Json;
use rocket::Route; use rocket::Route;
use rocket_contrib::json::Json;
use crate::{ use crate::{
api::{core::two_factor::_generate_recover_code, ApiResult, EmptyResult, JsonResult, JsonUpcase, PasswordData}, api::{core::two_factor::_generate_recover_code, ApiResult, EmptyResult, JsonResult, JsonUpcase, PasswordData},
@ -89,14 +89,14 @@ impl DuoStatus {
const DISABLED_MESSAGE_DEFAULT: &str = "<To use the global Duo keys, please leave these fields untouched>"; const DISABLED_MESSAGE_DEFAULT: &str = "<To use the global Duo keys, please leave these fields untouched>";
#[post("/two-factor/get-duo", data = "<data>")] #[post("/two-factor/get-duo", data = "<data>")]
fn get_duo(data: JsonUpcase<PasswordData>, headers: Headers, conn: DbConn) -> JsonResult { async fn get_duo(data: JsonUpcase<PasswordData>, headers: Headers, conn: DbConn) -> JsonResult {
let data: PasswordData = data.into_inner().data; let data: PasswordData = data.into_inner().data;
if !headers.user.check_valid_password(&data.MasterPasswordHash) { if !headers.user.check_valid_password(&data.MasterPasswordHash) {
err!("Invalid password"); err!("Invalid password");
} }
let data = get_user_duo_data(&headers.user.uuid, &conn); let data = get_user_duo_data(&headers.user.uuid, &conn).await;
let (enabled, data) = match data { let (enabled, data) = match data {
DuoStatus::Global(_) => (true, Some(DuoData::secret())), DuoStatus::Global(_) => (true, Some(DuoData::secret())),
@ -152,7 +152,7 @@ fn check_duo_fields_custom(data: &EnableDuoData) -> bool {
} }
#[post("/two-factor/duo", data = "<data>")] #[post("/two-factor/duo", data = "<data>")]
fn activate_duo(data: JsonUpcase<EnableDuoData>, headers: Headers, conn: DbConn) -> JsonResult { async fn activate_duo(data: JsonUpcase<EnableDuoData>, headers: Headers, conn: DbConn) -> JsonResult {
let data: EnableDuoData = data.into_inner().data; let data: EnableDuoData = data.into_inner().data;
let mut user = headers.user; let mut user = headers.user;
@ -163,7 +163,7 @@ fn activate_duo(data: JsonUpcase<EnableDuoData>, headers: Headers, conn: DbConn)
let (data, data_str) = if check_duo_fields_custom(&data) { let (data, data_str) = if check_duo_fields_custom(&data) {
let data_req: DuoData = data.into(); let data_req: DuoData = data.into();
let data_str = serde_json::to_string(&data_req)?; let data_str = serde_json::to_string(&data_req)?;
duo_api_request("GET", "/auth/v2/check", "", &data_req).map_res("Failed to validate Duo credentials")?; duo_api_request("GET", "/auth/v2/check", "", &data_req).await.map_res("Failed to validate Duo credentials")?;
(data_req.obscure(), data_str) (data_req.obscure(), data_str)
} else { } else {
(DuoData::secret(), String::new()) (DuoData::secret(), String::new())
@ -171,9 +171,9 @@ fn activate_duo(data: JsonUpcase<EnableDuoData>, headers: Headers, conn: DbConn)
let type_ = TwoFactorType::Duo; let type_ = TwoFactorType::Duo;
let twofactor = TwoFactor::new(user.uuid.clone(), type_, data_str); let twofactor = TwoFactor::new(user.uuid.clone(), type_, data_str);
twofactor.save(&conn)?; twofactor.save(&conn).await?;
_generate_recover_code(&mut user, &conn); _generate_recover_code(&mut user, &conn).await;
Ok(Json(json!({ Ok(Json(json!({
"Enabled": true, "Enabled": true,
@ -185,11 +185,11 @@ fn activate_duo(data: JsonUpcase<EnableDuoData>, headers: Headers, conn: DbConn)
} }
#[put("/two-factor/duo", data = "<data>")] #[put("/two-factor/duo", data = "<data>")]
fn activate_duo_put(data: JsonUpcase<EnableDuoData>, headers: Headers, conn: DbConn) -> JsonResult { async fn activate_duo_put(data: JsonUpcase<EnableDuoData>, headers: Headers, conn: DbConn) -> JsonResult {
activate_duo(data, headers, conn) activate_duo(data, headers, conn).await
} }
fn duo_api_request(method: &str, path: &str, params: &str, data: &DuoData) -> EmptyResult { async fn duo_api_request(method: &str, path: &str, params: &str, data: &DuoData) -> EmptyResult {
use reqwest::{header, Method}; use reqwest::{header, Method};
use std::str::FromStr; use std::str::FromStr;
@ -209,7 +209,8 @@ fn duo_api_request(method: &str, path: &str, params: &str, data: &DuoData) -> Em
.basic_auth(username, Some(password)) .basic_auth(username, Some(password))
.header(header::USER_AGENT, "vaultwarden:Duo/1.0 (Rust)") .header(header::USER_AGENT, "vaultwarden:Duo/1.0 (Rust)")
.header(header::DATE, date) .header(header::DATE, date)
.send()? .send()
.await?
.error_for_status()?; .error_for_status()?;
Ok(()) Ok(())
@ -222,11 +223,11 @@ const AUTH_PREFIX: &str = "AUTH";
const DUO_PREFIX: &str = "TX"; const DUO_PREFIX: &str = "TX";
const APP_PREFIX: &str = "APP"; const APP_PREFIX: &str = "APP";
fn get_user_duo_data(uuid: &str, conn: &DbConn) -> DuoStatus { async fn get_user_duo_data(uuid: &str, conn: &DbConn) -> DuoStatus {
let type_ = TwoFactorType::Duo as i32; let type_ = TwoFactorType::Duo as i32;
// If the user doesn't have an entry, disabled // If the user doesn't have an entry, disabled
let twofactor = match TwoFactor::find_by_user_and_type(uuid, type_, conn) { let twofactor = match TwoFactor::find_by_user_and_type(uuid, type_, conn).await {
Some(t) => t, Some(t) => t,
None => return DuoStatus::Disabled(DuoData::global().is_some()), None => return DuoStatus::Disabled(DuoData::global().is_some()),
}; };
@ -246,19 +247,20 @@ fn get_user_duo_data(uuid: &str, conn: &DbConn) -> DuoStatus {
} }
// let (ik, sk, ak, host) = get_duo_keys(); // let (ik, sk, ak, host) = get_duo_keys();
fn get_duo_keys_email(email: &str, conn: &DbConn) -> ApiResult<(String, String, String, String)> { async fn get_duo_keys_email(email: &str, conn: &DbConn) -> ApiResult<(String, String, String, String)> {
let data = User::find_by_mail(email, conn) let data = match User::find_by_mail(email, conn).await {
.and_then(|u| get_user_duo_data(&u.uuid, conn).data()) Some(u) => get_user_duo_data(&u.uuid, conn).await.data(),
.or_else(DuoData::global) _ => DuoData::global(),
.map_res("Can't fetch Duo keys")?; }
.map_res("Can't fetch Duo Keys")?;
Ok((data.ik, data.sk, CONFIG.get_duo_akey(), data.host)) Ok((data.ik, data.sk, CONFIG.get_duo_akey(), data.host))
} }
pub fn generate_duo_signature(email: &str, conn: &DbConn) -> ApiResult<(String, String)> { pub async fn generate_duo_signature(email: &str, conn: &DbConn) -> ApiResult<(String, String)> {
let now = Utc::now().timestamp(); let now = Utc::now().timestamp();
let (ik, sk, ak, host) = get_duo_keys_email(email, conn)?; let (ik, sk, ak, host) = get_duo_keys_email(email, conn).await?;
let duo_sign = sign_duo_values(&sk, email, &ik, DUO_PREFIX, now + DUO_EXPIRE); let duo_sign = sign_duo_values(&sk, email, &ik, DUO_PREFIX, now + DUO_EXPIRE);
let app_sign = sign_duo_values(&ak, email, &ik, APP_PREFIX, now + APP_EXPIRE); let app_sign = sign_duo_values(&ak, email, &ik, APP_PREFIX, now + APP_EXPIRE);
@ -273,7 +275,7 @@ fn sign_duo_values(key: &str, email: &str, ikey: &str, prefix: &str, expire: i64
format!("{}|{}", cookie, crypto::hmac_sign(key, &cookie)) format!("{}|{}", cookie, crypto::hmac_sign(key, &cookie))
} }
pub fn validate_duo_login(email: &str, response: &str, conn: &DbConn) -> EmptyResult { pub async fn validate_duo_login(email: &str, response: &str, conn: &DbConn) -> EmptyResult {
// email is as entered by the user, so it needs to be normalized before // email is as entered by the user, so it needs to be normalized before
// comparison with auth_user below. // comparison with auth_user below.
let email = &email.to_lowercase(); let email = &email.to_lowercase();
@ -288,7 +290,7 @@ pub fn validate_duo_login(email: &str, response: &str, conn: &DbConn) -> EmptyRe
let now = Utc::now().timestamp(); let now = Utc::now().timestamp();
let (ik, sk, ak, _host) = get_duo_keys_email(email, conn)?; let (ik, sk, ak, _host) = get_duo_keys_email(email, conn).await?;
let auth_user = parse_duo_values(&sk, auth_sig, &ik, AUTH_PREFIX, now)?; let auth_user = parse_duo_values(&sk, auth_sig, &ik, AUTH_PREFIX, now)?;
let app_user = parse_duo_values(&ak, app_sig, &ik, APP_PREFIX, now)?; let app_user = parse_duo_values(&ak, app_sig, &ik, APP_PREFIX, now)?;

78
src/api/core/two_factor/email.rs

@ -1,6 +1,6 @@
use chrono::{Duration, NaiveDateTime, Utc}; use chrono::{Duration, NaiveDateTime, Utc};
use rocket::serde::json::Json;
use rocket::Route; use rocket::Route;
use rocket_contrib::json::Json;
use crate::{ use crate::{
api::{core::two_factor::_generate_recover_code, EmptyResult, JsonResult, JsonUpcase, PasswordData}, api::{core::two_factor::_generate_recover_code, EmptyResult, JsonResult, JsonUpcase, PasswordData},
@ -28,13 +28,13 @@ struct SendEmailLoginData {
/// User is trying to login and wants to use email 2FA. /// User is trying to login and wants to use email 2FA.
/// Does not require Bearer token /// Does not require Bearer token
#[post("/two-factor/send-email-login", data = "<data>")] // JsonResult #[post("/two-factor/send-email-login", data = "<data>")] // JsonResult
fn send_email_login(data: JsonUpcase<SendEmailLoginData>, conn: DbConn) -> EmptyResult { async fn send_email_login(data: JsonUpcase<SendEmailLoginData>, conn: DbConn) -> EmptyResult {
let data: SendEmailLoginData = data.into_inner().data; let data: SendEmailLoginData = data.into_inner().data;
use crate::db::models::User; use crate::db::models::User;
// Get the user // Get the user
let user = match User::find_by_mail(&data.Email, &conn) { let user = match User::find_by_mail(&data.Email, &conn).await {
Some(user) => user, Some(user) => user,
None => err!("Username or password is incorrect. Try again."), None => err!("Username or password is incorrect. Try again."),
}; };
@ -48,31 +48,32 @@ fn send_email_login(data: JsonUpcase<SendEmailLoginData>, conn: DbConn) -> Empty
err!("Email 2FA is disabled") err!("Email 2FA is disabled")
} }
send_token(&user.uuid, &conn)?; send_token(&user.uuid, &conn).await?;
Ok(()) Ok(())
} }
/// Generate the token, save the data for later verification and send email to user /// Generate the token, save the data for later verification and send email to user
pub fn send_token(user_uuid: &str, conn: &DbConn) -> EmptyResult { pub async fn send_token(user_uuid: &str, conn: &DbConn) -> EmptyResult {
let type_ = TwoFactorType::Email as i32; let type_ = TwoFactorType::Email as i32;
let mut twofactor = TwoFactor::find_by_user_and_type(user_uuid, type_, conn).map_res("Two factor not found")?; let mut twofactor =
TwoFactor::find_by_user_and_type(user_uuid, type_, conn).await.map_res("Two factor not found")?;
let generated_token = crypto::generate_token(CONFIG.email_token_size())?; let generated_token = crypto::generate_email_token(CONFIG.email_token_size());
let mut twofactor_data = EmailTokenData::from_json(&twofactor.data)?; let mut twofactor_data = EmailTokenData::from_json(&twofactor.data)?;
twofactor_data.set_token(generated_token); twofactor_data.set_token(generated_token);
twofactor.data = twofactor_data.to_json(); twofactor.data = twofactor_data.to_json();
twofactor.save(conn)?; twofactor.save(conn).await?;
mail::send_token(&twofactor_data.email, &twofactor_data.last_token.map_res("Token is empty")?)?; mail::send_token(&twofactor_data.email, &twofactor_data.last_token.map_res("Token is empty")?).await?;
Ok(()) Ok(())
} }
/// When user clicks on Manage email 2FA show the user the related information /// When user clicks on Manage email 2FA show the user the related information
#[post("/two-factor/get-email", data = "<data>")] #[post("/two-factor/get-email", data = "<data>")]
fn get_email(data: JsonUpcase<PasswordData>, headers: Headers, conn: DbConn) -> JsonResult { async fn get_email(data: JsonUpcase<PasswordData>, headers: Headers, conn: DbConn) -> JsonResult {
let data: PasswordData = data.into_inner().data; let data: PasswordData = data.into_inner().data;
let user = headers.user; let user = headers.user;
@ -80,13 +81,14 @@ fn get_email(data: JsonUpcase<PasswordData>, headers: Headers, conn: DbConn) ->
err!("Invalid password"); err!("Invalid password");
} }
let (enabled, mfa_email) = match TwoFactor::find_by_user_and_type(&user.uuid, TwoFactorType::Email as i32, &conn) { let (enabled, mfa_email) =
Some(x) => { match TwoFactor::find_by_user_and_type(&user.uuid, TwoFactorType::Email as i32, &conn).await {
let twofactor_data = EmailTokenData::from_json(&x.data)?; Some(x) => {
(true, json!(twofactor_data.email)) let twofactor_data = EmailTokenData::from_json(&x.data)?;
} (true, json!(twofactor_data.email))
_ => (false, json!(null)), }
}; _ => (false, json!(null)),
};
Ok(Json(json!({ Ok(Json(json!({
"Email": mfa_email, "Email": mfa_email,
@ -105,7 +107,7 @@ struct SendEmailData {
/// Send a verification email to the specified email address to check whether it exists/belongs to user. /// Send a verification email to the specified email address to check whether it exists/belongs to user.
#[post("/two-factor/send-email", data = "<data>")] #[post("/two-factor/send-email", data = "<data>")]
fn send_email(data: JsonUpcase<SendEmailData>, headers: Headers, conn: DbConn) -> EmptyResult { async fn send_email(data: JsonUpcase<SendEmailData>, headers: Headers, conn: DbConn) -> EmptyResult {
let data: SendEmailData = data.into_inner().data; let data: SendEmailData = data.into_inner().data;
let user = headers.user; let user = headers.user;
@ -119,18 +121,18 @@ fn send_email(data: JsonUpcase<SendEmailData>, headers: Headers, conn: DbConn) -
let type_ = TwoFactorType::Email as i32; let type_ = TwoFactorType::Email as i32;
if let Some(tf) = TwoFactor::find_by_user_and_type(&user.uuid, type_, &conn) { if let Some(tf) = TwoFactor::find_by_user_and_type(&user.uuid, type_, &conn).await {
tf.delete(&conn)?; tf.delete(&conn).await?;
} }
let generated_token = crypto::generate_token(CONFIG.email_token_size())?; let generated_token = crypto::generate_email_token(CONFIG.email_token_size());
let twofactor_data = EmailTokenData::new(data.Email, generated_token); let twofactor_data = EmailTokenData::new(data.Email, generated_token);
// Uses EmailVerificationChallenge as type to show that it's not verified yet. // Uses EmailVerificationChallenge as type to show that it's not verified yet.
let twofactor = TwoFactor::new(user.uuid, TwoFactorType::EmailVerificationChallenge, twofactor_data.to_json()); let twofactor = TwoFactor::new(user.uuid, TwoFactorType::EmailVerificationChallenge, twofactor_data.to_json());
twofactor.save(&conn)?; twofactor.save(&conn).await?;
mail::send_token(&twofactor_data.email, &twofactor_data.last_token.map_res("Token is empty")?)?; mail::send_token(&twofactor_data.email, &twofactor_data.last_token.map_res("Token is empty")?).await?;
Ok(()) Ok(())
} }
@ -145,7 +147,7 @@ struct EmailData {
/// Verify email belongs to user and can be used for 2FA email codes. /// Verify email belongs to user and can be used for 2FA email codes.
#[put("/two-factor/email", data = "<data>")] #[put("/two-factor/email", data = "<data>")]
fn email(data: JsonUpcase<EmailData>, headers: Headers, conn: DbConn) -> JsonResult { async fn email(data: JsonUpcase<EmailData>, headers: Headers, conn: DbConn) -> JsonResult {
let data: EmailData = data.into_inner().data; let data: EmailData = data.into_inner().data;
let mut user = headers.user; let mut user = headers.user;
@ -154,7 +156,8 @@ fn email(data: JsonUpcase<EmailData>, headers: Headers, conn: DbConn) -> JsonRes
} }
let type_ = TwoFactorType::EmailVerificationChallenge as i32; let type_ = TwoFactorType::EmailVerificationChallenge as i32;
let mut twofactor = TwoFactor::find_by_user_and_type(&user.uuid, type_, &conn).map_res("Two factor not found")?; let mut twofactor =
TwoFactor::find_by_user_and_type(&user.uuid, type_, &conn).await.map_res("Two factor not found")?;
let mut email_data = EmailTokenData::from_json(&twofactor.data)?; let mut email_data = EmailTokenData::from_json(&twofactor.data)?;
@ -170,9 +173,9 @@ fn email(data: JsonUpcase<EmailData>, headers: Headers, conn: DbConn) -> JsonRes
email_data.reset_token(); email_data.reset_token();
twofactor.atype = TwoFactorType::Email as i32; twofactor.atype = TwoFactorType::Email as i32;
twofactor.data = email_data.to_json(); twofactor.data = email_data.to_json();
twofactor.save(&conn)?; twofactor.save(&conn).await?;
_generate_recover_code(&mut user, &conn); _generate_recover_code(&mut user, &conn).await;
Ok(Json(json!({ Ok(Json(json!({
"Email": email_data.email, "Email": email_data.email,
@ -182,9 +185,10 @@ fn email(data: JsonUpcase<EmailData>, headers: Headers, conn: DbConn) -> JsonRes
} }
/// Validate the email code when used as TwoFactor token mechanism /// Validate the email code when used as TwoFactor token mechanism
pub fn validate_email_code_str(user_uuid: &str, token: &str, data: &str, conn: &DbConn) -> EmptyResult { pub async fn validate_email_code_str(user_uuid: &str, token: &str, data: &str, conn: &DbConn) -> EmptyResult {
let mut email_data = EmailTokenData::from_json(data)?; let mut email_data = EmailTokenData::from_json(data)?;
let mut twofactor = TwoFactor::find_by_user_and_type(user_uuid, TwoFactorType::Email as i32, conn) let mut twofactor = TwoFactor::find_by_user_and_type(user_uuid, TwoFactorType::Email as i32, conn)
.await
.map_res("Two factor not found")?; .map_res("Two factor not found")?;
let issued_token = match &email_data.last_token { let issued_token = match &email_data.last_token {
Some(t) => t, Some(t) => t,
@ -197,14 +201,14 @@ pub fn validate_email_code_str(user_uuid: &str, token: &str, data: &str, conn: &
email_data.reset_token(); email_data.reset_token();
} }
twofactor.data = email_data.to_json(); twofactor.data = email_data.to_json();
twofactor.save(conn)?; twofactor.save(conn).await?;
err!("Token is invalid") err!("Token is invalid")
} }
email_data.reset_token(); email_data.reset_token();
twofactor.data = email_data.to_json(); twofactor.data = email_data.to_json();
twofactor.save(conn)?; twofactor.save(conn).await?;
let date = NaiveDateTime::from_timestamp(email_data.token_sent, 0); let date = NaiveDateTime::from_timestamp(email_data.token_sent, 0);
let max_time = CONFIG.email_expiration_time() as i64; let max_time = CONFIG.email_expiration_time() as i64;
@ -309,18 +313,4 @@ mod tests {
// If it's smaller than 3 characters it should only show asterisks. // If it's smaller than 3 characters it should only show asterisks.
assert_eq!(result, "***@example.ext"); assert_eq!(result, "***@example.ext");
} }
#[test]
fn test_token() {
let result = crypto::generate_token(19).unwrap();
assert_eq!(result.chars().count(), 19);
}
#[test]
fn test_token_too_large() {
let result = crypto::generate_token(20);
assert!(result.is_err(), "too large token should give an error");
}
} }

59
src/api/core/two_factor/mod.rs

@ -1,7 +1,7 @@
use chrono::{Duration, Utc}; use chrono::{Duration, Utc};
use data_encoding::BASE32; use data_encoding::BASE32;
use rocket::serde::json::Json;
use rocket::Route; use rocket::Route;
use rocket_contrib::json::Json;
use serde_json::Value; use serde_json::Value;
use crate::{ use crate::{
@ -15,7 +15,6 @@ use crate::{
pub mod authenticator; pub mod authenticator;
pub mod duo; pub mod duo;
pub mod email; pub mod email;
pub mod u2f;
pub mod webauthn; pub mod webauthn;
pub mod yubikey; pub mod yubikey;
@ -25,7 +24,6 @@ pub fn routes() -> Vec<Route> {
routes.append(&mut authenticator::routes()); routes.append(&mut authenticator::routes());
routes.append(&mut duo::routes()); routes.append(&mut duo::routes());
routes.append(&mut email::routes()); routes.append(&mut email::routes());
routes.append(&mut u2f::routes());
routes.append(&mut webauthn::routes()); routes.append(&mut webauthn::routes());
routes.append(&mut yubikey::routes()); routes.append(&mut yubikey::routes());
@ -33,8 +31,8 @@ pub fn routes() -> Vec<Route> {
} }
#[get("/two-factor")] #[get("/two-factor")]
fn get_twofactor(headers: Headers, conn: DbConn) -> Json<Value> { async fn get_twofactor(headers: Headers, conn: DbConn) -> Json<Value> {
let twofactors = TwoFactor::find_by_user(&headers.user.uuid, &conn); let twofactors = TwoFactor::find_by_user(&headers.user.uuid, &conn).await;
let twofactors_json: Vec<Value> = twofactors.iter().map(TwoFactor::to_json_provider).collect(); let twofactors_json: Vec<Value> = twofactors.iter().map(TwoFactor::to_json_provider).collect();
Json(json!({ Json(json!({
@ -68,13 +66,13 @@ struct RecoverTwoFactor {
} }
#[post("/two-factor/recover", data = "<data>")] #[post("/two-factor/recover", data = "<data>")]
fn recover(data: JsonUpcase<RecoverTwoFactor>, conn: DbConn) -> JsonResult { async fn recover(data: JsonUpcase<RecoverTwoFactor>, conn: DbConn) -> JsonResult {
let data: RecoverTwoFactor = data.into_inner().data; let data: RecoverTwoFactor = data.into_inner().data;
use crate::db::models::User; use crate::db::models::User;
// Get the user // Get the user
let mut user = match User::find_by_mail(&data.Email, &conn) { let mut user = match User::find_by_mail(&data.Email, &conn).await {
Some(user) => user, Some(user) => user,
None => err!("Username or password is incorrect. Try again."), None => err!("Username or password is incorrect. Try again."),
}; };
@ -90,19 +88,19 @@ fn recover(data: JsonUpcase<RecoverTwoFactor>, conn: DbConn) -> JsonResult {
} }
// Remove all twofactors from the user // Remove all twofactors from the user
TwoFactor::delete_all_by_user(&user.uuid, &conn)?; TwoFactor::delete_all_by_user(&user.uuid, &conn).await?;
// Remove the recovery code, not needed without twofactors // Remove the recovery code, not needed without twofactors
user.totp_recover = None; user.totp_recover = None;
user.save(&conn)?; user.save(&conn).await?;
Ok(Json(json!({}))) Ok(Json(json!({})))
} }
fn _generate_recover_code(user: &mut User, conn: &DbConn) { async fn _generate_recover_code(user: &mut User, conn: &DbConn) {
if user.totp_recover.is_none() { if user.totp_recover.is_none() {
let totp_recover = BASE32.encode(&crypto::get_random(vec![0u8; 20])); let totp_recover = BASE32.encode(&crypto::get_random(vec![0u8; 20]));
user.totp_recover = Some(totp_recover); user.totp_recover = Some(totp_recover);
user.save(conn).ok(); user.save(conn).await.ok();
} }
} }
@ -114,7 +112,7 @@ struct DisableTwoFactorData {
} }
#[post("/two-factor/disable", data = "<data>")] #[post("/two-factor/disable", data = "<data>")]
fn disable_twofactor(data: JsonUpcase<DisableTwoFactorData>, headers: Headers, conn: DbConn) -> JsonResult { async fn disable_twofactor(data: JsonUpcase<DisableTwoFactorData>, headers: Headers, conn: DbConn) -> JsonResult {
let data: DisableTwoFactorData = data.into_inner().data; let data: DisableTwoFactorData = data.into_inner().data;
let password_hash = data.MasterPasswordHash; let password_hash = data.MasterPasswordHash;
let user = headers.user; let user = headers.user;
@ -125,23 +123,24 @@ fn disable_twofactor(data: JsonUpcase<DisableTwoFactorData>, headers: Headers, c
let type_ = data.Type.into_i32()?; let type_ = data.Type.into_i32()?;
if let Some(twofactor) = TwoFactor::find_by_user_and_type(&user.uuid, type_, &conn) { if let Some(twofactor) = TwoFactor::find_by_user_and_type(&user.uuid, type_, &conn).await {
twofactor.delete(&conn)?; twofactor.delete(&conn).await?;
} }
let twofactor_disabled = TwoFactor::find_by_user(&user.uuid, &conn).is_empty(); let twofactor_disabled = TwoFactor::find_by_user(&user.uuid, &conn).await.is_empty();
if twofactor_disabled { if twofactor_disabled {
let policy_type = OrgPolicyType::TwoFactorAuthentication; for user_org in
let org_list = UserOrganization::find_by_user_and_policy(&user.uuid, policy_type, &conn); UserOrganization::find_by_user_and_policy(&user.uuid, OrgPolicyType::TwoFactorAuthentication, &conn)
.await
for user_org in org_list.into_iter() { .into_iter()
{
if user_org.atype < UserOrgType::Admin { if user_org.atype < UserOrgType::Admin {
if CONFIG.mail_enabled() { if CONFIG.mail_enabled() {
let org = Organization::find_by_uuid(&user_org.org_uuid, &conn).unwrap(); let org = Organization::find_by_uuid(&user_org.org_uuid, &conn).await.unwrap();
mail::send_2fa_removed_from_org(&user.email, &org.name)?; mail::send_2fa_removed_from_org(&user.email, &org.name).await?;
} }
user_org.delete(&conn)?; user_org.delete(&conn).await?;
} }
} }
} }
@ -154,18 +153,18 @@ fn disable_twofactor(data: JsonUpcase<DisableTwoFactorData>, headers: Headers, c
} }
#[put("/two-factor/disable", data = "<data>")] #[put("/two-factor/disable", data = "<data>")]
fn disable_twofactor_put(data: JsonUpcase<DisableTwoFactorData>, headers: Headers, conn: DbConn) -> JsonResult { async fn disable_twofactor_put(data: JsonUpcase<DisableTwoFactorData>, headers: Headers, conn: DbConn) -> JsonResult {
disable_twofactor(data, headers, conn) disable_twofactor(data, headers, conn).await
} }
pub fn send_incomplete_2fa_notifications(pool: DbPool) { pub async fn send_incomplete_2fa_notifications(pool: DbPool) {
debug!("Sending notifications for incomplete 2FA logins"); debug!("Sending notifications for incomplete 2FA logins");
if CONFIG.incomplete_2fa_time_limit() <= 0 || !CONFIG.mail_enabled() { if CONFIG.incomplete_2fa_time_limit() <= 0 || !CONFIG.mail_enabled() {
return; return;
} }
let conn = match pool.get() { let conn = match pool.get().await {
Ok(conn) => conn, Ok(conn) => conn,
_ => { _ => {
error!("Failed to get DB connection in send_incomplete_2fa_notifications()"); error!("Failed to get DB connection in send_incomplete_2fa_notifications()");
@ -175,15 +174,17 @@ pub fn send_incomplete_2fa_notifications(pool: DbPool) {
let now = Utc::now().naive_utc(); let now = Utc::now().naive_utc();
let time_limit = Duration::minutes(CONFIG.incomplete_2fa_time_limit()); let time_limit = Duration::minutes(CONFIG.incomplete_2fa_time_limit());
let incomplete_logins = TwoFactorIncomplete::find_logins_before(&(now - time_limit), &conn); let time_before = now - time_limit;
let incomplete_logins = TwoFactorIncomplete::find_logins_before(&time_before, &conn).await;
for login in incomplete_logins { for login in incomplete_logins {
let user = User::find_by_uuid(&login.user_uuid, &conn).expect("User not found"); let user = User::find_by_uuid(&login.user_uuid, &conn).await.expect("User not found");
info!( info!(
"User {} did not complete a 2FA login within the configured time limit. IP: {}", "User {} did not complete a 2FA login within the configured time limit. IP: {}",
user.email, login.ip_address user.email, login.ip_address
); );
mail::send_incomplete_2fa_login(&user.email, &login.ip_address, &login.login_time, &login.device_name) mail::send_incomplete_2fa_login(&user.email, &login.ip_address, &login.login_time, &login.device_name)
.await
.expect("Error sending incomplete 2FA email"); .expect("Error sending incomplete 2FA email");
login.delete(&conn).expect("Error deleting incomplete 2FA record"); login.delete(&conn).await.expect("Error deleting incomplete 2FA record");
} }
} }

352
src/api/core/two_factor/u2f.rs

@ -1,352 +0,0 @@
use once_cell::sync::Lazy;
use rocket::Route;
use rocket_contrib::json::Json;
use serde_json::Value;
use u2f::{
messages::{RegisterResponse, SignResponse, U2fSignRequest},
protocol::{Challenge, U2f},
register::Registration,
};
use crate::{
api::{
core::two_factor::_generate_recover_code, ApiResult, EmptyResult, JsonResult, JsonUpcase, NumberOrString,
PasswordData,
},
auth::Headers,
db::{
models::{TwoFactor, TwoFactorType},
DbConn,
},
error::Error,
CONFIG,
};
const U2F_VERSION: &str = "U2F_V2";
static APP_ID: Lazy<String> = Lazy::new(|| format!("{}/app-id.json", &CONFIG.domain()));
static U2F: Lazy<U2f> = Lazy::new(|| U2f::new(APP_ID.clone()));
pub fn routes() -> Vec<Route> {
routes![generate_u2f, generate_u2f_challenge, activate_u2f, activate_u2f_put, delete_u2f,]
}
#[post("/two-factor/get-u2f", data = "<data>")]
fn generate_u2f(data: JsonUpcase<PasswordData>, headers: Headers, conn: DbConn) -> JsonResult {
if !CONFIG.domain_set() {
err!("`DOMAIN` environment variable is not set. U2F disabled")
}
let data: PasswordData = data.into_inner().data;
if !headers.user.check_valid_password(&data.MasterPasswordHash) {
err!("Invalid password");
}
let (enabled, keys) = get_u2f_registrations(&headers.user.uuid, &conn)?;
let keys_json: Vec<Value> = keys.iter().map(U2FRegistration::to_json).collect();
Ok(Json(json!({
"Enabled": enabled,
"Keys": keys_json,
"Object": "twoFactorU2f"
})))
}
#[post("/two-factor/get-u2f-challenge", data = "<data>")]
fn generate_u2f_challenge(data: JsonUpcase<PasswordData>, headers: Headers, conn: DbConn) -> JsonResult {
let data: PasswordData = data.into_inner().data;
if !headers.user.check_valid_password(&data.MasterPasswordHash) {
err!("Invalid password");
}
let _type = TwoFactorType::U2fRegisterChallenge;
let challenge = _create_u2f_challenge(&headers.user.uuid, _type, &conn).challenge;
Ok(Json(json!({
"UserId": headers.user.uuid,
"AppId": APP_ID.to_string(),
"Challenge": challenge,
"Version": U2F_VERSION,
})))
}
#[derive(Deserialize, Debug)]
#[allow(non_snake_case)]
struct EnableU2FData {
Id: NumberOrString,
// 1..5
Name: String,
MasterPasswordHash: String,
DeviceResponse: String,
}
// This struct is referenced from the U2F lib
// because it doesn't implement Deserialize
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
#[serde(remote = "Registration")]
struct RegistrationDef {
key_handle: Vec<u8>,
pub_key: Vec<u8>,
attestation_cert: Option<Vec<u8>>,
device_name: Option<String>,
}
#[derive(Serialize, Deserialize)]
pub struct U2FRegistration {
pub id: i32,
pub name: String,
#[serde(with = "RegistrationDef")]
pub reg: Registration,
pub counter: u32,
compromised: bool,
pub migrated: Option<bool>,
}
impl U2FRegistration {
fn to_json(&self) -> Value {
json!({
"Id": self.id,
"Name": self.name,
"Compromised": self.compromised,
})
}
}
// This struct is copied from the U2F lib
// to add an optional error code
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
struct RegisterResponseCopy {
pub registration_data: String,
pub version: String,
pub client_data: String,
pub error_code: Option<NumberOrString>,
}
impl From<RegisterResponseCopy> for RegisterResponse {
fn from(r: RegisterResponseCopy) -> RegisterResponse {
RegisterResponse {
registration_data: r.registration_data,
version: r.version,
client_data: r.client_data,
}
}
}
#[post("/two-factor/u2f", data = "<data>")]
fn activate_u2f(data: JsonUpcase<EnableU2FData>, headers: Headers, conn: DbConn) -> JsonResult {
let data: EnableU2FData = data.into_inner().data;
let mut user = headers.user;
if !user.check_valid_password(&data.MasterPasswordHash) {
err!("Invalid password");
}
let tf_type = TwoFactorType::U2fRegisterChallenge as i32;
let tf_challenge = match TwoFactor::find_by_user_and_type(&user.uuid, tf_type, &conn) {
Some(c) => c,
None => err!("Can't recover challenge"),
};
let challenge: Challenge = serde_json::from_str(&tf_challenge.data)?;
tf_challenge.delete(&conn)?;
let response: RegisterResponseCopy = serde_json::from_str(&data.DeviceResponse)?;
let error_code = response.error_code.clone().map_or("0".into(), NumberOrString::into_string);
if error_code != "0" {
err!("Error registering U2F token")
}
let registration = U2F.register_response(challenge, response.into())?;
let full_registration = U2FRegistration {
id: data.Id.into_i32()?,
name: data.Name,
reg: registration,
compromised: false,
counter: 0,
migrated: None,
};
let mut regs = get_u2f_registrations(&user.uuid, &conn)?.1;
// TODO: Check that there is no repeat Id
regs.push(full_registration);
save_u2f_registrations(&user.uuid, &regs, &conn)?;
_generate_recover_code(&mut user, &conn);
let keys_json: Vec<Value> = regs.iter().map(U2FRegistration::to_json).collect();
Ok(Json(json!({
"Enabled": true,
"Keys": keys_json,
"Object": "twoFactorU2f"
})))
}
#[put("/two-factor/u2f", data = "<data>")]
fn activate_u2f_put(data: JsonUpcase<EnableU2FData>, headers: Headers, conn: DbConn) -> JsonResult {
activate_u2f(data, headers, conn)
}
#[derive(Deserialize, Debug)]
#[allow(non_snake_case)]
struct DeleteU2FData {
Id: NumberOrString,
MasterPasswordHash: String,
}
#[delete("/two-factor/u2f", data = "<data>")]
fn delete_u2f(data: JsonUpcase<DeleteU2FData>, headers: Headers, conn: DbConn) -> JsonResult {
let data: DeleteU2FData = data.into_inner().data;
let id = data.Id.into_i32()?;
if !headers.user.check_valid_password(&data.MasterPasswordHash) {
err!("Invalid password");
}
let type_ = TwoFactorType::U2f as i32;
let mut tf = match TwoFactor::find_by_user_and_type(&headers.user.uuid, type_, &conn) {
Some(tf) => tf,
None => err!("U2F data not found!"),
};
let mut data: Vec<U2FRegistration> = match serde_json::from_str(&tf.data) {
Ok(d) => d,
Err(_) => err!("Error parsing U2F data"),
};
data.retain(|r| r.id != id);
let new_data_str = serde_json::to_string(&data)?;
tf.data = new_data_str;
tf.save(&conn)?;
let keys_json: Vec<Value> = data.iter().map(U2FRegistration::to_json).collect();
Ok(Json(json!({
"Enabled": true,
"Keys": keys_json,
"Object": "twoFactorU2f"
})))
}
fn _create_u2f_challenge(user_uuid: &str, type_: TwoFactorType, conn: &DbConn) -> Challenge {
let challenge = U2F.generate_challenge().unwrap();
TwoFactor::new(user_uuid.into(), type_, serde_json::to_string(&challenge).unwrap())
.save(conn)
.expect("Error saving challenge");
challenge
}
fn save_u2f_registrations(user_uuid: &str, regs: &[U2FRegistration], conn: &DbConn) -> EmptyResult {
TwoFactor::new(user_uuid.into(), TwoFactorType::U2f, serde_json::to_string(regs)?).save(conn)
}
fn get_u2f_registrations(user_uuid: &str, conn: &DbConn) -> Result<(bool, Vec<U2FRegistration>), Error> {
let type_ = TwoFactorType::U2f as i32;
let (enabled, regs) = match TwoFactor::find_by_user_and_type(user_uuid, type_, conn) {
Some(tf) => (tf.enabled, tf.data),
None => return Ok((false, Vec::new())), // If no data, return empty list
};
let data = match serde_json::from_str(&regs) {
Ok(d) => d,
Err(_) => {
// If error, try old format
let mut old_regs = _old_parse_registrations(&regs);
if old_regs.len() != 1 {
err!("The old U2F format only allows one device")
}
// Convert to new format
let new_regs = vec![U2FRegistration {
id: 1,
name: "Unnamed U2F key".into(),
reg: old_regs.remove(0),
compromised: false,
counter: 0,
migrated: None,
}];
// Save new format
save_u2f_registrations(user_uuid, &new_regs, conn)?;
new_regs
}
};
Ok((enabled, data))
}
fn _old_parse_registrations(registations: &str) -> Vec<Registration> {
#[derive(Deserialize)]
struct Helper(#[serde(with = "RegistrationDef")] Registration);
let regs: Vec<Value> = serde_json::from_str(registations).expect("Can't parse Registration data");
regs.into_iter().map(|r| serde_json::from_value(r).unwrap()).map(|Helper(r)| r).collect()
}
pub fn generate_u2f_login(user_uuid: &str, conn: &DbConn) -> ApiResult<U2fSignRequest> {
let challenge = _create_u2f_challenge(user_uuid, TwoFactorType::U2fLoginChallenge, conn);
let registrations: Vec<_> = get_u2f_registrations(user_uuid, conn)?.1.into_iter().map(|r| r.reg).collect();
if registrations.is_empty() {
err!("No U2F devices registered")
}
Ok(U2F.sign_request(challenge, registrations))
}
pub fn validate_u2f_login(user_uuid: &str, response: &str, conn: &DbConn) -> EmptyResult {
let challenge_type = TwoFactorType::U2fLoginChallenge as i32;
let tf_challenge = TwoFactor::find_by_user_and_type(user_uuid, challenge_type, conn);
let challenge = match tf_challenge {
Some(tf_challenge) => {
let challenge: Challenge = serde_json::from_str(&tf_challenge.data)?;
tf_challenge.delete(conn)?;
challenge
}
None => err!("Can't recover login challenge"),
};
let response: SignResponse = serde_json::from_str(response)?;
let mut registrations = get_u2f_registrations(user_uuid, conn)?.1;
if registrations.is_empty() {
err!("No U2F devices registered")
}
for reg in &mut registrations {
let response = U2F.sign_response(challenge.clone(), reg.reg.clone(), response.clone(), reg.counter);
match response {
Ok(new_counter) => {
reg.counter = new_counter;
save_u2f_registrations(user_uuid, &registrations, conn)?;
return Ok(());
}
Err(u2f::u2ferror::U2fError::CounterTooLow) => {
reg.compromised = true;
save_u2f_registrations(user_uuid, &registrations, conn)?;
err!("This device might be compromised!");
}
Err(e) => {
warn!("E {:#}", e);
// break;
}
}
}
err!("error verifying response")
}

91
src/api/core/two_factor/webauthn.rs

@ -1,5 +1,5 @@
use rocket::serde::json::Json;
use rocket::Route; use rocket::Route;
use rocket_contrib::json::Json;
use serde_json::Value; use serde_json::Value;
use url::Url; use url::Url;
use webauthn_rs::{base64_data::Base64UrlSafeData, proto::*, AuthenticationState, RegistrationState, Webauthn}; use webauthn_rs::{base64_data::Base64UrlSafeData, proto::*, AuthenticationState, RegistrationState, Webauthn};
@ -21,6 +21,28 @@ pub fn routes() -> Vec<Route> {
routes![get_webauthn, generate_webauthn_challenge, activate_webauthn, activate_webauthn_put, delete_webauthn,] routes![get_webauthn, generate_webauthn_challenge, activate_webauthn, activate_webauthn_put, delete_webauthn,]
} }
// Some old u2f structs still needed for migrating from u2f to WebAuthn
// Both `struct Registration` and `struct U2FRegistration` can be removed if we remove the u2f to WebAuthn migration
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Registration {
pub key_handle: Vec<u8>,
pub pub_key: Vec<u8>,
pub attestation_cert: Option<Vec<u8>>,
pub device_name: Option<String>,
}
#[derive(Serialize, Deserialize)]
pub struct U2FRegistration {
pub id: i32,
pub name: String,
#[serde(with = "Registration")]
pub reg: Registration,
pub counter: u32,
compromised: bool,
pub migrated: Option<bool>,
}
struct WebauthnConfig { struct WebauthnConfig {
url: String, url: String,
origin: Url, origin: Url,
@ -80,7 +102,7 @@ impl WebauthnRegistration {
} }
#[post("/two-factor/get-webauthn", data = "<data>")] #[post("/two-factor/get-webauthn", data = "<data>")]
fn get_webauthn(data: JsonUpcase<PasswordData>, headers: Headers, conn: DbConn) -> JsonResult { async fn get_webauthn(data: JsonUpcase<PasswordData>, headers: Headers, conn: DbConn) -> JsonResult {
if !CONFIG.domain_set() { if !CONFIG.domain_set() {
err!("`DOMAIN` environment variable is not set. Webauthn disabled") err!("`DOMAIN` environment variable is not set. Webauthn disabled")
} }
@ -89,7 +111,7 @@ fn get_webauthn(data: JsonUpcase<PasswordData>, headers: Headers, conn: DbConn)
err!("Invalid password"); err!("Invalid password");
} }
let (enabled, registrations) = get_webauthn_registrations(&headers.user.uuid, &conn)?; let (enabled, registrations) = get_webauthn_registrations(&headers.user.uuid, &conn).await?;
let registrations_json: Vec<Value> = registrations.iter().map(WebauthnRegistration::to_json).collect(); let registrations_json: Vec<Value> = registrations.iter().map(WebauthnRegistration::to_json).collect();
Ok(Json(json!({ Ok(Json(json!({
@ -100,12 +122,13 @@ fn get_webauthn(data: JsonUpcase<PasswordData>, headers: Headers, conn: DbConn)
} }
#[post("/two-factor/get-webauthn-challenge", data = "<data>")] #[post("/two-factor/get-webauthn-challenge", data = "<data>")]
fn generate_webauthn_challenge(data: JsonUpcase<PasswordData>, headers: Headers, conn: DbConn) -> JsonResult { async fn generate_webauthn_challenge(data: JsonUpcase<PasswordData>, headers: Headers, conn: DbConn) -> JsonResult {
if !headers.user.check_valid_password(&data.data.MasterPasswordHash) { if !headers.user.check_valid_password(&data.data.MasterPasswordHash) {
err!("Invalid password"); err!("Invalid password");
} }
let registrations = get_webauthn_registrations(&headers.user.uuid, &conn)? let registrations = get_webauthn_registrations(&headers.user.uuid, &conn)
.await?
.1 .1
.into_iter() .into_iter()
.map(|r| r.credential.cred_id) // We return the credentialIds to the clients to avoid double registering .map(|r| r.credential.cred_id) // We return the credentialIds to the clients to avoid double registering
@ -121,7 +144,7 @@ fn generate_webauthn_challenge(data: JsonUpcase<PasswordData>, headers: Headers,
)?; )?;
let type_ = TwoFactorType::WebauthnRegisterChallenge; let type_ = TwoFactorType::WebauthnRegisterChallenge;
TwoFactor::new(headers.user.uuid, type_, serde_json::to_string(&state)?).save(&conn)?; TwoFactor::new(headers.user.uuid, type_, serde_json::to_string(&state)?).save(&conn).await?;
let mut challenge_value = serde_json::to_value(challenge.public_key)?; let mut challenge_value = serde_json::to_value(challenge.public_key)?;
challenge_value["status"] = "ok".into(); challenge_value["status"] = "ok".into();
@ -218,7 +241,7 @@ impl From<PublicKeyCredentialCopy> for PublicKeyCredential {
} }
#[post("/two-factor/webauthn", data = "<data>")] #[post("/two-factor/webauthn", data = "<data>")]
fn activate_webauthn(data: JsonUpcase<EnableWebauthnData>, headers: Headers, conn: DbConn) -> JsonResult { async fn activate_webauthn(data: JsonUpcase<EnableWebauthnData>, headers: Headers, conn: DbConn) -> JsonResult {
let data: EnableWebauthnData = data.into_inner().data; let data: EnableWebauthnData = data.into_inner().data;
let mut user = headers.user; let mut user = headers.user;
@ -228,10 +251,10 @@ fn activate_webauthn(data: JsonUpcase<EnableWebauthnData>, headers: Headers, con
// Retrieve and delete the saved challenge state // Retrieve and delete the saved challenge state
let type_ = TwoFactorType::WebauthnRegisterChallenge as i32; let type_ = TwoFactorType::WebauthnRegisterChallenge as i32;
let state = match TwoFactor::find_by_user_and_type(&user.uuid, type_, &conn) { let state = match TwoFactor::find_by_user_and_type(&user.uuid, type_, &conn).await {
Some(tf) => { Some(tf) => {
let state: RegistrationState = serde_json::from_str(&tf.data)?; let state: RegistrationState = serde_json::from_str(&tf.data)?;
tf.delete(&conn)?; tf.delete(&conn).await?;
state state
} }
None => err!("Can't recover challenge"), None => err!("Can't recover challenge"),
@ -241,7 +264,7 @@ fn activate_webauthn(data: JsonUpcase<EnableWebauthnData>, headers: Headers, con
let (credential, _data) = let (credential, _data) =
WebauthnConfig::load().register_credential(&data.DeviceResponse.into(), &state, |_| Ok(false))?; WebauthnConfig::load().register_credential(&data.DeviceResponse.into(), &state, |_| Ok(false))?;
let mut registrations: Vec<_> = get_webauthn_registrations(&user.uuid, &conn)?.1; let mut registrations: Vec<_> = get_webauthn_registrations(&user.uuid, &conn).await?.1;
// TODO: Check for repeated ID's // TODO: Check for repeated ID's
registrations.push(WebauthnRegistration { registrations.push(WebauthnRegistration {
id: data.Id.into_i32()?, id: data.Id.into_i32()?,
@ -252,8 +275,10 @@ fn activate_webauthn(data: JsonUpcase<EnableWebauthnData>, headers: Headers, con
}); });
// Save the registrations and return them // Save the registrations and return them
TwoFactor::new(user.uuid.clone(), TwoFactorType::Webauthn, serde_json::to_string(&registrations)?).save(&conn)?; TwoFactor::new(user.uuid.clone(), TwoFactorType::Webauthn, serde_json::to_string(&registrations)?)
_generate_recover_code(&mut user, &conn); .save(&conn)
.await?;
_generate_recover_code(&mut user, &conn).await;
let keys_json: Vec<Value> = registrations.iter().map(WebauthnRegistration::to_json).collect(); let keys_json: Vec<Value> = registrations.iter().map(WebauthnRegistration::to_json).collect();
Ok(Json(json!({ Ok(Json(json!({
@ -264,8 +289,8 @@ fn activate_webauthn(data: JsonUpcase<EnableWebauthnData>, headers: Headers, con
} }
#[put("/two-factor/webauthn", data = "<data>")] #[put("/two-factor/webauthn", data = "<data>")]
fn activate_webauthn_put(data: JsonUpcase<EnableWebauthnData>, headers: Headers, conn: DbConn) -> JsonResult { async fn activate_webauthn_put(data: JsonUpcase<EnableWebauthnData>, headers: Headers, conn: DbConn) -> JsonResult {
activate_webauthn(data, headers, conn) activate_webauthn(data, headers, conn).await
} }
#[derive(Deserialize, Debug)] #[derive(Deserialize, Debug)]
@ -276,13 +301,14 @@ struct DeleteU2FData {
} }
#[delete("/two-factor/webauthn", data = "<data>")] #[delete("/two-factor/webauthn", data = "<data>")]
fn delete_webauthn(data: JsonUpcase<DeleteU2FData>, headers: Headers, conn: DbConn) -> JsonResult { async fn delete_webauthn(data: JsonUpcase<DeleteU2FData>, headers: Headers, conn: DbConn) -> JsonResult {
let id = data.data.Id.into_i32()?; let id = data.data.Id.into_i32()?;
if !headers.user.check_valid_password(&data.data.MasterPasswordHash) { if !headers.user.check_valid_password(&data.data.MasterPasswordHash) {
err!("Invalid password"); err!("Invalid password");
} }
let mut tf = match TwoFactor::find_by_user_and_type(&headers.user.uuid, TwoFactorType::Webauthn as i32, &conn) { let mut tf = match TwoFactor::find_by_user_and_type(&headers.user.uuid, TwoFactorType::Webauthn as i32, &conn).await
{
Some(tf) => tf, Some(tf) => tf,
None => err!("Webauthn data not found!"), None => err!("Webauthn data not found!"),
}; };
@ -296,12 +322,12 @@ fn delete_webauthn(data: JsonUpcase<DeleteU2FData>, headers: Headers, conn: DbCo
let removed_item = data.remove(item_pos); let removed_item = data.remove(item_pos);
tf.data = serde_json::to_string(&data)?; tf.data = serde_json::to_string(&data)?;
tf.save(&conn)?; tf.save(&conn).await?;
drop(tf); drop(tf);
// If entry is migrated from u2f, delete the u2f entry as well // If entry is migrated from u2f, delete the u2f entry as well
if let Some(mut u2f) = TwoFactor::find_by_user_and_type(&headers.user.uuid, TwoFactorType::U2f as i32, &conn) { if let Some(mut u2f) = TwoFactor::find_by_user_and_type(&headers.user.uuid, TwoFactorType::U2f as i32, &conn).await
use crate::api::core::two_factor::u2f::U2FRegistration; {
let mut data: Vec<U2FRegistration> = match serde_json::from_str(&u2f.data) { let mut data: Vec<U2FRegistration> = match serde_json::from_str(&u2f.data) {
Ok(d) => d, Ok(d) => d,
Err(_) => err!("Error parsing U2F data"), Err(_) => err!("Error parsing U2F data"),
@ -311,7 +337,7 @@ fn delete_webauthn(data: JsonUpcase<DeleteU2FData>, headers: Headers, conn: DbCo
let new_data_str = serde_json::to_string(&data)?; let new_data_str = serde_json::to_string(&data)?;
u2f.data = new_data_str; u2f.data = new_data_str;
u2f.save(&conn)?; u2f.save(&conn).await?;
} }
let keys_json: Vec<Value> = data.iter().map(WebauthnRegistration::to_json).collect(); let keys_json: Vec<Value> = data.iter().map(WebauthnRegistration::to_json).collect();
@ -323,18 +349,21 @@ fn delete_webauthn(data: JsonUpcase<DeleteU2FData>, headers: Headers, conn: DbCo
}))) })))
} }
pub fn get_webauthn_registrations(user_uuid: &str, conn: &DbConn) -> Result<(bool, Vec<WebauthnRegistration>), Error> { pub async fn get_webauthn_registrations(
user_uuid: &str,
conn: &DbConn,
) -> Result<(bool, Vec<WebauthnRegistration>), Error> {
let type_ = TwoFactorType::Webauthn as i32; let type_ = TwoFactorType::Webauthn as i32;
match TwoFactor::find_by_user_and_type(user_uuid, type_, conn) { match TwoFactor::find_by_user_and_type(user_uuid, type_, conn).await {
Some(tf) => Ok((tf.enabled, serde_json::from_str(&tf.data)?)), Some(tf) => Ok((tf.enabled, serde_json::from_str(&tf.data)?)),
None => Ok((false, Vec::new())), // If no data, return empty list None => Ok((false, Vec::new())), // If no data, return empty list
} }
} }
pub fn generate_webauthn_login(user_uuid: &str, conn: &DbConn) -> JsonResult { pub async fn generate_webauthn_login(user_uuid: &str, conn: &DbConn) -> JsonResult {
// Load saved credentials // Load saved credentials
let creds: Vec<Credential> = let creds: Vec<Credential> =
get_webauthn_registrations(user_uuid, conn)?.1.into_iter().map(|r| r.credential).collect(); get_webauthn_registrations(user_uuid, conn).await?.1.into_iter().map(|r| r.credential).collect();
if creds.is_empty() { if creds.is_empty() {
err!("No Webauthn devices registered") err!("No Webauthn devices registered")
@ -346,18 +375,19 @@ pub fn generate_webauthn_login(user_uuid: &str, conn: &DbConn) -> JsonResult {
// Save the challenge state for later validation // Save the challenge state for later validation
TwoFactor::new(user_uuid.into(), TwoFactorType::WebauthnLoginChallenge, serde_json::to_string(&state)?) TwoFactor::new(user_uuid.into(), TwoFactorType::WebauthnLoginChallenge, serde_json::to_string(&state)?)
.save(conn)?; .save(conn)
.await?;
// Return challenge to the clients // Return challenge to the clients
Ok(Json(serde_json::to_value(response.public_key)?)) Ok(Json(serde_json::to_value(response.public_key)?))
} }
pub fn validate_webauthn_login(user_uuid: &str, response: &str, conn: &DbConn) -> EmptyResult { pub async fn validate_webauthn_login(user_uuid: &str, response: &str, conn: &DbConn) -> EmptyResult {
let type_ = TwoFactorType::WebauthnLoginChallenge as i32; let type_ = TwoFactorType::WebauthnLoginChallenge as i32;
let state = match TwoFactor::find_by_user_and_type(user_uuid, type_, conn) { let state = match TwoFactor::find_by_user_and_type(user_uuid, type_, conn).await {
Some(tf) => { Some(tf) => {
let state: AuthenticationState = serde_json::from_str(&tf.data)?; let state: AuthenticationState = serde_json::from_str(&tf.data)?;
tf.delete(conn)?; tf.delete(conn).await?;
state state
} }
None => err!("Can't recover login challenge"), None => err!("Can't recover login challenge"),
@ -366,7 +396,7 @@ pub fn validate_webauthn_login(user_uuid: &str, response: &str, conn: &DbConn) -
let rsp: crate::util::UpCase<PublicKeyCredentialCopy> = serde_json::from_str(response)?; let rsp: crate::util::UpCase<PublicKeyCredentialCopy> = serde_json::from_str(response)?;
let rsp: PublicKeyCredential = rsp.data.into(); let rsp: PublicKeyCredential = rsp.data.into();
let mut registrations = get_webauthn_registrations(user_uuid, conn)?.1; let mut registrations = get_webauthn_registrations(user_uuid, conn).await?.1;
// If the credential we received is migrated from U2F, enable the U2F compatibility // If the credential we received is migrated from U2F, enable the U2F compatibility
//let use_u2f = registrations.iter().any(|r| r.migrated && r.credential.cred_id == rsp.raw_id.0); //let use_u2f = registrations.iter().any(|r| r.migrated && r.credential.cred_id == rsp.raw_id.0);
@ -377,7 +407,8 @@ pub fn validate_webauthn_login(user_uuid: &str, response: &str, conn: &DbConn) -
reg.credential.counter = auth_data.counter; reg.credential.counter = auth_data.counter;
TwoFactor::new(user_uuid.to_string(), TwoFactorType::Webauthn, serde_json::to_string(&registrations)?) TwoFactor::new(user_uuid.to_string(), TwoFactorType::Webauthn, serde_json::to_string(&registrations)?)
.save(conn)?; .save(conn)
.await?;
return Ok(()); return Ok(());
} }
} }

27
src/api/core/two_factor/yubikey.rs

@ -1,5 +1,5 @@
use rocket::serde::json::Json;
use rocket::Route; use rocket::Route;
use rocket_contrib::json::Json;
use serde_json::Value; use serde_json::Value;
use yubico::{config::Config, verify}; use yubico::{config::Config, verify};
@ -78,7 +78,7 @@ fn verify_yubikey_otp(otp: String) -> EmptyResult {
} }
#[post("/two-factor/get-yubikey", data = "<data>")] #[post("/two-factor/get-yubikey", data = "<data>")]
fn generate_yubikey(data: JsonUpcase<PasswordData>, headers: Headers, conn: DbConn) -> JsonResult { async fn generate_yubikey(data: JsonUpcase<PasswordData>, headers: Headers, conn: DbConn) -> JsonResult {
// Make sure the credentials are set // Make sure the credentials are set
get_yubico_credentials()?; get_yubico_credentials()?;
@ -92,7 +92,7 @@ fn generate_yubikey(data: JsonUpcase<PasswordData>, headers: Headers, conn: DbCo
let user_uuid = &user.uuid; let user_uuid = &user.uuid;
let yubikey_type = TwoFactorType::YubiKey as i32; let yubikey_type = TwoFactorType::YubiKey as i32;
let r = TwoFactor::find_by_user_and_type(user_uuid, yubikey_type, &conn); let r = TwoFactor::find_by_user_and_type(user_uuid, yubikey_type, &conn).await;
if let Some(r) = r { if let Some(r) = r {
let yubikey_metadata: YubikeyMetadata = serde_json::from_str(&r.data)?; let yubikey_metadata: YubikeyMetadata = serde_json::from_str(&r.data)?;
@ -113,7 +113,7 @@ fn generate_yubikey(data: JsonUpcase<PasswordData>, headers: Headers, conn: DbCo
} }
#[post("/two-factor/yubikey", data = "<data>")] #[post("/two-factor/yubikey", data = "<data>")]
fn activate_yubikey(data: JsonUpcase<EnableYubikeyData>, headers: Headers, conn: DbConn) -> JsonResult { async fn activate_yubikey(data: JsonUpcase<EnableYubikeyData>, headers: Headers, conn: DbConn) -> JsonResult {
let data: EnableYubikeyData = data.into_inner().data; let data: EnableYubikeyData = data.into_inner().data;
let mut user = headers.user; let mut user = headers.user;
@ -122,10 +122,11 @@ fn activate_yubikey(data: JsonUpcase<EnableYubikeyData>, headers: Headers, conn:
} }
// Check if we already have some data // Check if we already have some data
let mut yubikey_data = match TwoFactor::find_by_user_and_type(&user.uuid, TwoFactorType::YubiKey as i32, &conn) { let mut yubikey_data =
Some(data) => data, match TwoFactor::find_by_user_and_type(&user.uuid, TwoFactorType::YubiKey as i32, &conn).await {
None => TwoFactor::new(user.uuid.clone(), TwoFactorType::YubiKey, String::new()), Some(data) => data,
}; None => TwoFactor::new(user.uuid.clone(), TwoFactorType::YubiKey, String::new()),
};
let yubikeys = parse_yubikeys(&data); let yubikeys = parse_yubikeys(&data);
@ -146,7 +147,7 @@ fn activate_yubikey(data: JsonUpcase<EnableYubikeyData>, headers: Headers, conn:
verify_yubikey_otp(yubikey.to_owned()).map_res("Invalid Yubikey OTP provided")?; verify_yubikey_otp(yubikey.to_owned()).map_res("Invalid Yubikey OTP provided")?;
} }
let yubikey_ids: Vec<String> = yubikeys.into_iter().map(|x| (&x[..12]).to_owned()).collect(); let yubikey_ids: Vec<String> = yubikeys.into_iter().map(|x| (x[..12]).to_owned()).collect();
let yubikey_metadata = YubikeyMetadata { let yubikey_metadata = YubikeyMetadata {
Keys: yubikey_ids, Keys: yubikey_ids,
@ -154,9 +155,9 @@ fn activate_yubikey(data: JsonUpcase<EnableYubikeyData>, headers: Headers, conn:
}; };
yubikey_data.data = serde_json::to_string(&yubikey_metadata).unwrap(); yubikey_data.data = serde_json::to_string(&yubikey_metadata).unwrap();
yubikey_data.save(&conn)?; yubikey_data.save(&conn).await?;
_generate_recover_code(&mut user, &conn); _generate_recover_code(&mut user, &conn).await;
let mut result = jsonify_yubikeys(yubikey_metadata.Keys); let mut result = jsonify_yubikeys(yubikey_metadata.Keys);
@ -168,8 +169,8 @@ fn activate_yubikey(data: JsonUpcase<EnableYubikeyData>, headers: Headers, conn:
} }
#[put("/two-factor/yubikey", data = "<data>")] #[put("/two-factor/yubikey", data = "<data>")]
fn activate_yubikey_put(data: JsonUpcase<EnableYubikeyData>, headers: Headers, conn: DbConn) -> JsonResult { async fn activate_yubikey_put(data: JsonUpcase<EnableYubikeyData>, headers: Headers, conn: DbConn) -> JsonResult {
activate_yubikey(data, headers, conn) activate_yubikey(data, headers, conn).await
} }
pub fn validate_yubikey_login(response: &str, twofactor_data: &str) -> EmptyResult { pub fn validate_yubikey_login(response: &str, twofactor_data: &str) -> EmptyResult {

627
src/api/icons.rs

@ -1,21 +1,26 @@
use std::{ use std::{
collections::HashMap, net::IpAddr,
fs::{create_dir_all, remove_file, symlink_metadata, File}, sync::Arc,
io::prelude::*,
net::{IpAddr, ToSocketAddrs},
sync::{Arc, RwLock},
time::{Duration, SystemTime}, time::{Duration, SystemTime},
}; };
use bytes::{Bytes, BytesMut};
use futures::{stream::StreamExt, TryFutureExt};
use once_cell::sync::Lazy; use once_cell::sync::Lazy;
use regex::Regex; use regex::Regex;
use reqwest::{blocking::Client, blocking::Response, header}; use reqwest::{
use rocket::{ header::{self, HeaderMap, HeaderValue},
http::ContentType, Client, Response,
response::{Content, Redirect}, };
Route, use rocket::{http::ContentType, response::Redirect, Route};
use tokio::{
fs::{create_dir_all, remove_file, symlink_metadata, File},
io::{AsyncReadExt, AsyncWriteExt},
net::lookup_host,
}; };
use html5gum::{Emitter, EndTag, HtmlString, InfallibleTokenizer, Readable, StartTag, StringReader, Tokenizer};
use crate::{ use crate::{
error::Error, error::Error,
util::{get_reqwest_client_builder, Cached}, util::{get_reqwest_client_builder, Cached},
@ -25,48 +30,56 @@ use crate::{
pub fn routes() -> Vec<Route> { pub fn routes() -> Vec<Route> {
match CONFIG.icon_service().as_str() { match CONFIG.icon_service().as_str() {
"internal" => routes![icon_internal], "internal" => routes![icon_internal],
"bitwarden" => routes![icon_bitwarden], _ => routes![icon_external],
"duckduckgo" => routes![icon_duckduckgo],
"google" => routes![icon_google],
_ => routes![icon_custom],
} }
} }
static CLIENT: Lazy<Client> = Lazy::new(|| { static CLIENT: Lazy<Client> = Lazy::new(|| {
// Generate the default headers // Generate the default headers
let mut default_headers = header::HeaderMap::new(); let mut default_headers = HeaderMap::new();
default_headers default_headers.insert(header::USER_AGENT, HeaderValue::from_static("Links (2.22; Linux X86_64; GNU C; text)"));
.insert(header::USER_AGENT, header::HeaderValue::from_static("Links (2.22; Linux X86_64; GNU C; text)")); default_headers.insert(header::ACCEPT, HeaderValue::from_static("text/html, text/*;q=0.5, image/*, */*;q=0.1"));
default_headers default_headers.insert(header::ACCEPT_LANGUAGE, HeaderValue::from_static("en,*;q=0.1"));
.insert(header::ACCEPT, header::HeaderValue::from_static("text/html, text/*;q=0.5, image/*, */*;q=0.1")); default_headers.insert(header::CACHE_CONTROL, HeaderValue::from_static("no-cache"));
default_headers.insert(header::ACCEPT_LANGUAGE, header::HeaderValue::from_static("en,*;q=0.1")); default_headers.insert(header::PRAGMA, HeaderValue::from_static("no-cache"));
default_headers.insert(header::CACHE_CONTROL, header::HeaderValue::from_static("no-cache"));
default_headers.insert(header::PRAGMA, header::HeaderValue::from_static("no-cache")); // Generate the cookie store
let cookie_store = Arc::new(Jar::default());
// Reuse the client between requests // Reuse the client between requests
get_reqwest_client_builder() let client = get_reqwest_client_builder()
.cookie_provider(Arc::new(Jar::default())) .cookie_provider(Arc::clone(&cookie_store))
.timeout(Duration::from_secs(CONFIG.icon_download_timeout())) .timeout(Duration::from_secs(CONFIG.icon_download_timeout()))
.default_headers(default_headers) .default_headers(default_headers.clone());
.build()
.expect("Failed to build icon client") match client.build() {
Ok(client) => client,
Err(e) => {
error!("Possible trust-dns error, trying with trust-dns disabled: '{e}'");
get_reqwest_client_builder()
.cookie_provider(cookie_store)
.timeout(Duration::from_secs(CONFIG.icon_download_timeout()))
.default_headers(default_headers)
.trust_dns(false)
.build()
.expect("Failed to build client")
}
}
}); });
// Build Regex only once since this takes a lot of time. // Build Regex only once since this takes a lot of time.
static ICON_REL_REGEX: Lazy<Regex> = Lazy::new(|| Regex::new(r"(?i)icon$|apple.*icon").unwrap());
static ICON_REL_BLACKLIST: Lazy<Regex> = Lazy::new(|| Regex::new(r"(?i)mask-icon").unwrap());
static ICON_SIZE_REGEX: Lazy<Regex> = Lazy::new(|| Regex::new(r"(?x)(\d+)\D*(\d+)").unwrap()); static ICON_SIZE_REGEX: Lazy<Regex> = Lazy::new(|| Regex::new(r"(?x)(\d+)\D*(\d+)").unwrap());
// Special HashMap which holds the user defined Regex to speedup matching the regex. // Special HashMap which holds the user defined Regex to speedup matching the regex.
static ICON_BLACKLIST_REGEX: Lazy<RwLock<HashMap<String, Regex>>> = Lazy::new(|| RwLock::new(HashMap::new())); static ICON_BLACKLIST_REGEX: Lazy<dashmap::DashMap<String, Regex>> = Lazy::new(dashmap::DashMap::new);
fn icon_redirect(domain: &str, template: &str) -> Option<Redirect> { async fn icon_redirect(domain: &str, template: &str) -> Option<Redirect> {
if !is_valid_domain(domain) { if !is_valid_domain(domain) {
warn!("Invalid domain: {}", domain); warn!("Invalid domain: {}", domain);
return None; return None;
} }
if is_domain_blacklisted(domain) { if is_domain_blacklisted(domain).await {
return None; return None;
} }
@ -84,47 +97,28 @@ fn icon_redirect(domain: &str, template: &str) -> Option<Redirect> {
} }
#[get("/<domain>/icon.png")] #[get("/<domain>/icon.png")]
fn icon_custom(domain: String) -> Option<Redirect> { async fn icon_external(domain: String) -> Option<Redirect> {
icon_redirect(&domain, &CONFIG.icon_service()) icon_redirect(&domain, &CONFIG._icon_service_url()).await
}
#[get("/<domain>/icon.png")]
fn icon_bitwarden(domain: String) -> Option<Redirect> {
icon_redirect(&domain, "https://icons.bitwarden.net/{}/icon.png")
} }
#[get("/<domain>/icon.png")] #[get("/<domain>/icon.png")]
fn icon_duckduckgo(domain: String) -> Option<Redirect> { async fn icon_internal(domain: String) -> Cached<(ContentType, Vec<u8>)> {
icon_redirect(&domain, "https://icons.duckduckgo.com/ip3/{}.ico")
}
#[get("/<domain>/icon.png")]
fn icon_google(domain: String) -> Option<Redirect> {
icon_redirect(&domain, "https://www.google.com/s2/favicons?domain={}&sz=32")
}
#[get("/<domain>/icon.png")]
fn icon_internal(domain: String) -> Cached<Content<Vec<u8>>> {
const FALLBACK_ICON: &[u8] = include_bytes!("../static/images/fallback-icon.png"); const FALLBACK_ICON: &[u8] = include_bytes!("../static/images/fallback-icon.png");
if !is_valid_domain(&domain) { if !is_valid_domain(&domain) {
warn!("Invalid domain: {}", domain); warn!("Invalid domain: {}", domain);
return Cached::ttl( return Cached::ttl(
Content(ContentType::new("image", "png"), FALLBACK_ICON.to_vec()), (ContentType::new("image", "png"), FALLBACK_ICON.to_vec()),
CONFIG.icon_cache_negttl(), CONFIG.icon_cache_negttl(),
true, true,
); );
} }
match get_icon(&domain) { match get_icon(&domain).await {
Some((icon, icon_type)) => { Some((icon, icon_type)) => {
Cached::ttl(Content(ContentType::new("image", icon_type), icon), CONFIG.icon_cache_ttl(), true) Cached::ttl((ContentType::new("image", icon_type), icon), CONFIG.icon_cache_ttl(), true)
} }
_ => Cached::ttl( _ => Cached::ttl((ContentType::new("image", "png"), FALLBACK_ICON.to_vec()), CONFIG.icon_cache_negttl(), true),
Content(ContentType::new("image", "png"), FALLBACK_ICON.to_vec()),
CONFIG.icon_cache_negttl(),
true,
),
} }
} }
@ -264,68 +258,57 @@ mod tests {
} }
} }
fn is_domain_blacklisted(domain: &str) -> bool { use cached::proc_macro::cached;
let mut is_blacklisted = CONFIG.icon_blacklist_non_global_ips() #[cached(key = "String", convert = r#"{ domain.to_string() }"#, size = 16, time = 60)]
&& (domain, 0) #[allow(clippy::unused_async)] // This is needed because cached causes a false-positive here.
.to_socket_addrs() async fn is_domain_blacklisted(domain: &str) -> bool {
.map(|x| { if CONFIG.icon_blacklist_non_global_ips() {
for ip_port in x { if let Ok(s) = lookup_host((domain, 0)).await {
if !is_global(ip_port.ip()) { for addr in s {
warn!("IP {} for domain '{}' is not a global IP!", ip_port.ip(), domain); if !is_global(addr.ip()) {
return true; debug!("IP {} for domain '{}' is not a global IP!", addr.ip(), domain);
} return true;
} }
false }
}) }
.unwrap_or(false); }
// Skip the regex check if the previous one is true already
if !is_blacklisted {
if let Some(blacklist) = CONFIG.icon_blacklist_regex() {
let mut regex_hashmap = ICON_BLACKLIST_REGEX.read().unwrap();
// Use the pre-generate Regex stored in a Lazy HashMap if there's one, else generate it.
let regex = if let Some(regex) = regex_hashmap.get(&blacklist) {
regex
} else {
drop(regex_hashmap);
let mut regex_hashmap_write = ICON_BLACKLIST_REGEX.write().unwrap(); if let Some(blacklist) = CONFIG.icon_blacklist_regex() {
// Clear the current list if the previous key doesn't exists. // Use the pre-generate Regex stored in a Lazy HashMap if there's one, else generate it.
// To prevent growing of the HashMap after someone has changed it via the admin interface. let is_match = if let Some(regex) = ICON_BLACKLIST_REGEX.get(&blacklist) {
if regex_hashmap_write.len() >= 1 { regex.is_match(domain)
regex_hashmap_write.clear(); } else {
} // Clear the current list if the previous key doesn't exists.
// To prevent growing of the HashMap after someone has changed it via the admin interface.
if ICON_BLACKLIST_REGEX.len() >= 1 {
ICON_BLACKLIST_REGEX.clear();
}
// Generate the regex to store in too the Lazy Static HashMap. // Generate the regex to store in too the Lazy Static HashMap.
let blacklist_regex = Regex::new(&blacklist).unwrap(); let blacklist_regex = Regex::new(&blacklist).unwrap();
regex_hashmap_write.insert(blacklist.to_string(), blacklist_regex); let is_match = blacklist_regex.is_match(domain);
drop(regex_hashmap_write); ICON_BLACKLIST_REGEX.insert(blacklist.clone(), blacklist_regex);
regex_hashmap = ICON_BLACKLIST_REGEX.read().unwrap(); is_match
regex_hashmap.get(&blacklist).unwrap() };
};
// Use the pre-generate Regex stored in a Lazy HashMap. if is_match {
if regex.is_match(domain) { debug!("Blacklisted domain: {} matched ICON_BLACKLIST_REGEX", domain);
debug!("Blacklisted domain: {} matched ICON_BLACKLIST_REGEX", domain); return true;
is_blacklisted = true;
}
} }
} }
false
is_blacklisted
} }
fn get_icon(domain: &str) -> Option<(Vec<u8>, String)> { async fn get_icon(domain: &str) -> Option<(Vec<u8>, String)> {
let path = format!("{}/{}.png", CONFIG.icon_cache_folder(), domain); let path = format!("{}/{}.png", CONFIG.icon_cache_folder(), domain);
// Check for expiration of negatively cached copy // Check for expiration of negatively cached copy
if icon_is_negcached(&path) { if icon_is_negcached(&path).await {
return None; return None;
} }
if let Some(icon) = get_cached_icon(&path) { if let Some(icon) = get_cached_icon(&path).await {
let icon_type = match get_icon_type(&icon) { let icon_type = match get_icon_type(&icon) {
Some(x) => x, Some(x) => x,
_ => "x-icon", _ => "x-icon",
@ -338,31 +321,31 @@ fn get_icon(domain: &str) -> Option<(Vec<u8>, String)> {
} }
// Get the icon, or None in case of error // Get the icon, or None in case of error
match download_icon(domain) { match download_icon(domain).await {
Ok((icon, icon_type)) => { Ok((icon, icon_type)) => {
save_icon(&path, &icon); save_icon(&path, &icon).await;
Some((icon, icon_type.unwrap_or("x-icon").to_string())) Some((icon.to_vec(), icon_type.unwrap_or("x-icon").to_string()))
} }
Err(e) => { Err(e) => {
warn!("Unable to download icon: {:?}", e); warn!("Unable to download icon: {:?}", e);
let miss_indicator = path + ".miss"; let miss_indicator = path + ".miss";
save_icon(&miss_indicator, &[]); save_icon(&miss_indicator, &[]).await;
None None
} }
} }
} }
fn get_cached_icon(path: &str) -> Option<Vec<u8>> { async fn get_cached_icon(path: &str) -> Option<Vec<u8>> {
// Check for expiration of successfully cached copy // Check for expiration of successfully cached copy
if icon_is_expired(path) { if icon_is_expired(path).await {
return None; return None;
} }
// Try to read the cached icon, and return it if it exists // Try to read the cached icon, and return it if it exists
if let Ok(mut f) = File::open(path) { if let Ok(mut f) = File::open(path).await {
let mut buffer = Vec::new(); let mut buffer = Vec::new();
if f.read_to_end(&mut buffer).is_ok() { if f.read_to_end(&mut buffer).await.is_ok() {
return Some(buffer); return Some(buffer);
} }
} }
@ -370,22 +353,22 @@ fn get_cached_icon(path: &str) -> Option<Vec<u8>> {
None None
} }
fn file_is_expired(path: &str, ttl: u64) -> Result<bool, Error> { async fn file_is_expired(path: &str, ttl: u64) -> Result<bool, Error> {
let meta = symlink_metadata(path)?; let meta = symlink_metadata(path).await?;
let modified = meta.modified()?; let modified = meta.modified()?;
let age = SystemTime::now().duration_since(modified)?; let age = SystemTime::now().duration_since(modified)?;
Ok(ttl > 0 && ttl <= age.as_secs()) Ok(ttl > 0 && ttl <= age.as_secs())
} }
fn icon_is_negcached(path: &str) -> bool { async fn icon_is_negcached(path: &str) -> bool {
let miss_indicator = path.to_owned() + ".miss"; let miss_indicator = path.to_owned() + ".miss";
let expired = file_is_expired(&miss_indicator, CONFIG.icon_cache_negttl()); let expired = file_is_expired(&miss_indicator, CONFIG.icon_cache_negttl()).await;
match expired { match expired {
// No longer negatively cached, drop the marker // No longer negatively cached, drop the marker
Ok(true) => { Ok(true) => {
if let Err(e) = remove_file(&miss_indicator) { if let Err(e) = remove_file(&miss_indicator).await {
error!("Could not remove negative cache indicator for icon {:?}: {:?}", path, e); error!("Could not remove negative cache indicator for icon {:?}: {:?}", path, e);
} }
false false
@ -397,8 +380,8 @@ fn icon_is_negcached(path: &str) -> bool {
} }
} }
fn icon_is_expired(path: &str) -> bool { async fn icon_is_expired(path: &str) -> bool {
let expired = file_is_expired(path, CONFIG.icon_cache_ttl()); let expired = file_is_expired(path, CONFIG.icon_cache_ttl()).await;
expired.unwrap_or(true) expired.unwrap_or(true)
} }
@ -416,91 +399,62 @@ impl Icon {
} }
} }
/// Iterates over the HTML document to find <base href="http://domain.tld"> fn get_favicons_node(
/// When found it will stop the iteration and the found base href will be shared deref via `base_href`. dom: InfallibleTokenizer<StringReader<'_>, FaviconEmitter>,
/// icons: &mut Vec<Icon>,
/// # Arguments url: &url::Url,
/// * `node` - A Parsed HTML document via html5ever::parse_document() ) {
/// * `base_href` - a mutable url::Url which will be overwritten when a base href tag has been found. const TAG_LINK: &[u8] = b"link";
/// const TAG_BASE: &[u8] = b"base";
fn get_base_href(node: &std::rc::Rc<markup5ever_rcdom::Node>, base_href: &mut url::Url) -> bool { const TAG_HEAD: &[u8] = b"head";
if let markup5ever_rcdom::NodeData::Element { const ATTR_REL: &[u8] = b"rel";
name, const ATTR_HREF: &[u8] = b"href";
attrs, const ATTR_SIZES: &[u8] = b"sizes";
..
} = &node.data let mut base_url = url.clone();
{ let mut icon_tags: Vec<StartTag> = Vec::new();
if name.local.as_ref() == "base" { for token in dom {
let attrs = attrs.borrow(); match token {
for attr in attrs.iter() { FaviconToken::StartTag(tag) => {
let attr_name = attr.name.local.as_ref(); if *tag.name == TAG_LINK
let attr_value = attr.value.as_ref(); && tag.attributes.contains_key(ATTR_REL)
&& tag.attributes.contains_key(ATTR_HREF)
if attr_name == "href" {
debug!("Found base href: {}", attr_value);
*base_href = match base_href.join(attr_value) {
Ok(href) => href,
_ => base_href.clone(),
};
return true;
}
}
return true;
}
}
// TODO: Might want to limit the recursion depth?
for child in node.children.borrow().iter() {
// Check if we got a true back and stop the iter.
// This means we found a <base> tag and can stop processing the html.
if get_base_href(child, base_href) {
return true;
}
}
false
}
fn get_favicons_node(node: &std::rc::Rc<markup5ever_rcdom::Node>, icons: &mut Vec<Icon>, url: &url::Url) {
if let markup5ever_rcdom::NodeData::Element {
name,
attrs,
..
} = &node.data
{
if name.local.as_ref() == "link" {
let mut has_rel = false;
let mut href = None;
let mut sizes = None;
let attrs = attrs.borrow();
for attr in attrs.iter() {
let attr_name = attr.name.local.as_ref();
let attr_value = attr.value.as_ref();
if attr_name == "rel" && ICON_REL_REGEX.is_match(attr_value) && !ICON_REL_BLACKLIST.is_match(attr_value)
{ {
has_rel = true; let rel_value = std::str::from_utf8(tag.attributes.get(ATTR_REL).unwrap())
} else if attr_name == "href" { .unwrap_or_default()
href = Some(attr_value); .to_ascii_lowercase();
} else if attr_name == "sizes" { if rel_value.contains("icon") && !rel_value.contains("mask-icon") {
sizes = Some(attr_value); icon_tags.push(tag);
}
} else if *tag.name == TAG_BASE && tag.attributes.contains_key(ATTR_HREF) {
let href = std::str::from_utf8(tag.attributes.get(ATTR_HREF).unwrap()).unwrap_or_default();
debug!("Found base href: {href}");
base_url = match base_url.join(href) {
Ok(inner_url) => inner_url,
_ => url.clone(),
};
} }
} }
FaviconToken::EndTag(tag) => {
if has_rel { if *tag.name == TAG_HEAD {
if let Some(inner_href) = href { break;
if let Ok(full_href) = url.join(inner_href).map(String::from) {
let priority = get_icon_priority(&full_href, sizes);
icons.push(Icon::new(priority, full_href));
}
} }
} }
} }
} }
// TODO: Might want to limit the recursion depth? for icon_tag in icon_tags {
for child in node.children.borrow().iter() { if let Some(icon_href) = icon_tag.attributes.get(ATTR_HREF) {
get_favicons_node(child, icons, url); if let Ok(full_href) = base_url.join(std::str::from_utf8(icon_href).unwrap_or_default()) {
let sizes = if let Some(v) = icon_tag.attributes.get(ATTR_SIZES) {
std::str::from_utf8(v).unwrap_or_default()
} else {
""
};
let priority = get_icon_priority(full_href.as_str(), sizes);
icons.push(Icon::new(priority, full_href.to_string()));
}
};
} }
} }
@ -518,16 +472,16 @@ struct IconUrlResult {
/// ///
/// # Example /// # Example
/// ``` /// ```
/// let icon_result = get_icon_url("github.com")?; /// let icon_result = get_icon_url("github.com").await?;
/// let icon_result = get_icon_url("vaultwarden.discourse.group")?; /// let icon_result = get_icon_url("vaultwarden.discourse.group").await?;
/// ``` /// ```
fn get_icon_url(domain: &str) -> Result<IconUrlResult, Error> { async fn get_icon_url(domain: &str) -> Result<IconUrlResult, Error> {
// Default URL with secure and insecure schemes // Default URL with secure and insecure schemes
let ssldomain = format!("https://{}", domain); let ssldomain = format!("https://{domain}");
let httpdomain = format!("http://{}", domain); let httpdomain = format!("http://{domain}");
// First check the domain as given during the request for both HTTPS and HTTP. // First check the domain as given during the request for both HTTPS and HTTP.
let resp = match get_page(&ssldomain).or_else(|_| get_page(&httpdomain)) { let resp = match get_page(&ssldomain).or_else(|_| get_page(&httpdomain)).await {
Ok(c) => Ok(c), Ok(c) => Ok(c),
Err(e) => { Err(e) => {
let mut sub_resp = Err(e); let mut sub_resp = Err(e);
@ -542,25 +496,24 @@ fn get_icon_url(domain: &str) -> Result<IconUrlResult, Error> {
base = domain_parts.next_back().unwrap() base = domain_parts.next_back().unwrap()
); );
if is_valid_domain(&base_domain) { if is_valid_domain(&base_domain) {
let sslbase = format!("https://{}", base_domain); let sslbase = format!("https://{base_domain}");
let httpbase = format!("http://{}", base_domain); let httpbase = format!("http://{base_domain}");
debug!("[get_icon_url]: Trying without subdomains '{}'", base_domain); debug!("[get_icon_url]: Trying without subdomains '{base_domain}'");
sub_resp = get_page(&sslbase).or_else(|_| get_page(&httpbase)); sub_resp = get_page(&sslbase).or_else(|_| get_page(&httpbase)).await;
} }
// When the domain is not an IP, and has less then 2 dots, try to add www. infront of it. // When the domain is not an IP, and has less then 2 dots, try to add www. infront of it.
} else if is_ip.is_err() && domain.matches('.').count() < 2 { } else if is_ip.is_err() && domain.matches('.').count() < 2 {
let www_domain = format!("www.{}", domain); let www_domain = format!("www.{domain}");
if is_valid_domain(&www_domain) { if is_valid_domain(&www_domain) {
let sslwww = format!("https://{}", www_domain); let sslwww = format!("https://{www_domain}");
let httpwww = format!("http://{}", www_domain); let httpwww = format!("http://{www_domain}");
debug!("[get_icon_url]: Trying with www. prefix '{}'", www_domain); debug!("[get_icon_url]: Trying with www. prefix '{www_domain}'");
sub_resp = get_page(&sslwww).or_else(|_| get_page(&httpwww)); sub_resp = get_page(&sslwww).or_else(|_| get_page(&httpwww)).await;
} }
} }
sub_resp sub_resp
} }
}; };
@ -575,26 +528,23 @@ fn get_icon_url(domain: &str) -> Result<IconUrlResult, Error> {
// Set the referer to be used on the final request, some sites check this. // Set the referer to be used on the final request, some sites check this.
// Mostly used to prevent direct linking and other security resons. // Mostly used to prevent direct linking and other security resons.
referer = url.as_str().to_string(); referer = url.to_string();
// Add the default favicon.ico to the list with the domain the content responded from. // Add the fallback favicon.ico and apple-touch-icon.png to the list with the domain the content responded from.
iconlist.push(Icon::new(35, String::from(url.join("/favicon.ico").unwrap()))); iconlist.push(Icon::new(35, String::from(url.join("/favicon.ico").unwrap())));
iconlist.push(Icon::new(40, String::from(url.join("/apple-touch-icon.png").unwrap())));
// 384KB should be more than enough for the HTML, though as we only really need the HTML header. // 384KB should be more than enough for the HTML, though as we only really need the HTML header.
let mut limited_reader = content.take(384 * 1024); let limited_reader = stream_to_bytes_limit(content, 384 * 1024).await?.to_vec();
use html5ever::tendril::TendrilSink;
let dom = html5ever::parse_document(markup5ever_rcdom::RcDom::default(), Default::default())
.from_utf8()
.read_from(&mut limited_reader)?;
let mut base_url: url::Url = url; let dom = Tokenizer::new_with_emitter(limited_reader.to_reader(), FaviconEmitter::default()).infallible();
get_base_href(&dom.document, &mut base_url); get_favicons_node(dom, &mut iconlist, &url);
get_favicons_node(&dom.document, &mut iconlist, &base_url);
} else { } else {
// Add the default favicon.ico to the list with just the given domain // Add the default favicon.ico to the list with just the given domain
iconlist.push(Icon::new(35, format!("{}/favicon.ico", ssldomain))); iconlist.push(Icon::new(35, format!("{ssldomain}/favicon.ico")));
iconlist.push(Icon::new(35, format!("{}/favicon.ico", httpdomain))); iconlist.push(Icon::new(40, format!("{ssldomain}/apple-touch-icon.png")));
iconlist.push(Icon::new(35, format!("{httpdomain}/favicon.ico")));
iconlist.push(Icon::new(40, format!("{httpdomain}/apple-touch-icon.png")));
} }
// Sort the iconlist by priority // Sort the iconlist by priority
@ -607,12 +557,12 @@ fn get_icon_url(domain: &str) -> Result<IconUrlResult, Error> {
}) })
} }
fn get_page(url: &str) -> Result<Response, Error> { async fn get_page(url: &str) -> Result<Response, Error> {
get_page_with_referer(url, "") get_page_with_referer(url, "").await
} }
fn get_page_with_referer(url: &str, referer: &str) -> Result<Response, Error> { async fn get_page_with_referer(url: &str, referer: &str) -> Result<Response, Error> {
if is_domain_blacklisted(url::Url::parse(url).unwrap().host_str().unwrap_or_default()) { if is_domain_blacklisted(url::Url::parse(url).unwrap().host_str().unwrap_or_default()).await {
warn!("Favicon '{}' resolves to a blacklisted domain or IP!", url); warn!("Favicon '{}' resolves to a blacklisted domain or IP!", url);
} }
@ -621,7 +571,7 @@ fn get_page_with_referer(url: &str, referer: &str) -> Result<Response, Error> {
client = client.header("Referer", referer) client = client.header("Referer", referer)
} }
match client.send() { match client.send().await {
Ok(c) => c.error_for_status().map_err(Into::into), Ok(c) => c.error_for_status().map_err(Into::into),
Err(e) => err_silent!(format!("{}", e)), Err(e) => err_silent!(format!("{}", e)),
} }
@ -639,7 +589,7 @@ fn get_page_with_referer(url: &str, referer: &str) -> Result<Response, Error> {
/// priority1 = get_icon_priority("http://example.com/path/to/a/favicon.png", "32x32"); /// priority1 = get_icon_priority("http://example.com/path/to/a/favicon.png", "32x32");
/// priority2 = get_icon_priority("https://example.com/path/to/a/favicon.ico", ""); /// priority2 = get_icon_priority("https://example.com/path/to/a/favicon.ico", "");
/// ``` /// ```
fn get_icon_priority(href: &str, sizes: Option<&str>) -> u8 { fn get_icon_priority(href: &str, sizes: &str) -> u8 {
// Check if there is a dimension set // Check if there is a dimension set
let (width, height) = parse_sizes(sizes); let (width, height) = parse_sizes(sizes);
@ -687,11 +637,11 @@ fn get_icon_priority(href: &str, sizes: Option<&str>) -> u8 {
/// let (width, height) = parse_sizes("x128x128"); // (128, 128) /// let (width, height) = parse_sizes("x128x128"); // (128, 128)
/// let (width, height) = parse_sizes("32"); // (0, 0) /// let (width, height) = parse_sizes("32"); // (0, 0)
/// ``` /// ```
fn parse_sizes(sizes: Option<&str>) -> (u16, u16) { fn parse_sizes(sizes: &str) -> (u16, u16) {
let mut width: u16 = 0; let mut width: u16 = 0;
let mut height: u16 = 0; let mut height: u16 = 0;
if let Some(sizes) = sizes { if !sizes.is_empty() {
match ICON_SIZE_REGEX.captures(sizes.trim()) { match ICON_SIZE_REGEX.captures(sizes.trim()) {
None => {} None => {}
Some(dimensions) => { Some(dimensions) => {
@ -706,14 +656,14 @@ fn parse_sizes(sizes: Option<&str>) -> (u16, u16) {
(width, height) (width, height)
} }
fn download_icon(domain: &str) -> Result<(Vec<u8>, Option<&str>), Error> { async fn download_icon(domain: &str) -> Result<(Bytes, Option<&str>), Error> {
if is_domain_blacklisted(domain) { if is_domain_blacklisted(domain).await {
err_silent!("Domain is blacklisted", domain) err_silent!("Domain is blacklisted", domain)
} }
let icon_result = get_icon_url(domain)?; let icon_result = get_icon_url(domain).await?;
let mut buffer = Vec::new(); let mut buffer = Bytes::new();
let mut icon_type: Option<&str> = None; let mut icon_type: Option<&str> = None;
use data_url::DataUrl; use data_url::DataUrl;
@ -722,8 +672,12 @@ fn download_icon(domain: &str) -> Result<(Vec<u8>, Option<&str>), Error> {
if icon.href.starts_with("data:image") { if icon.href.starts_with("data:image") {
let datauri = DataUrl::process(&icon.href).unwrap(); let datauri = DataUrl::process(&icon.href).unwrap();
// Check if we are able to decode the data uri // Check if we are able to decode the data uri
match datauri.decode_to_vec() { let mut body = BytesMut::new();
Ok((body, _fragment)) => { match datauri.decode::<_, ()>(|bytes| {
body.extend_from_slice(bytes);
Ok(())
}) {
Ok(_) => {
// Also check if the size is atleast 67 bytes, which seems to be the smallest png i could create // Also check if the size is atleast 67 bytes, which seems to be the smallest png i could create
if body.len() >= 67 { if body.len() >= 67 {
// Check if the icon type is allowed, else try an icon from the list. // Check if the icon type is allowed, else try an icon from the list.
@ -733,16 +687,17 @@ fn download_icon(domain: &str) -> Result<(Vec<u8>, Option<&str>), Error> {
continue; continue;
} }
info!("Extracted icon from data:image uri for {}", domain); info!("Extracted icon from data:image uri for {}", domain);
buffer = body; buffer = body.freeze();
break; break;
} }
} }
_ => debug!("Extracted icon from data:image uri is invalid"), _ => debug!("Extracted icon from data:image uri is invalid"),
}; };
} else { } else {
match get_page_with_referer(&icon.href, &icon_result.referer) { match get_page_with_referer(&icon.href, &icon_result.referer).await {
Ok(mut res) => { Ok(res) => {
res.copy_to(&mut buffer)?; buffer = stream_to_bytes_limit(res, 5120 * 1024).await?; // 5120KB/5MB for each icon max (Same as icons.bitwarden.net)
// Check if the icon type is allowed, else try an icon from the list. // Check if the icon type is allowed, else try an icon from the list.
icon_type = get_icon_type(&buffer); icon_type = get_icon_type(&buffer);
if icon_type.is_none() { if icon_type.is_none() {
@ -765,13 +720,13 @@ fn download_icon(domain: &str) -> Result<(Vec<u8>, Option<&str>), Error> {
Ok((buffer, icon_type)) Ok((buffer, icon_type))
} }
fn save_icon(path: &str, icon: &[u8]) { async fn save_icon(path: &str, icon: &[u8]) {
match File::create(path) { match File::create(path).await {
Ok(mut f) => { Ok(mut f) => {
f.write_all(icon).expect("Error writing icon file"); f.write_all(icon).await.expect("Error writing icon file");
} }
Err(ref e) if e.kind() == std::io::ErrorKind::NotFound => { Err(ref e) if e.kind() == std::io::ErrorKind::NotFound => {
create_dir_all(&CONFIG.icon_cache_folder()).expect("Error creating icon cache folder"); create_dir_all(&CONFIG.icon_cache_folder()).await.expect("Error creating icon cache folder");
} }
Err(e) => { Err(e) => {
warn!("Unable to save icon: {:?}", e); warn!("Unable to save icon: {:?}", e);
@ -791,13 +746,30 @@ fn get_icon_type(bytes: &[u8]) -> Option<&'static str> {
} }
} }
/// Minimize the amount of bytes to be parsed from a reqwest result.
/// This prevents very long parsing and memory usage.
async fn stream_to_bytes_limit(res: Response, max_size: usize) -> Result<Bytes, reqwest::Error> {
let mut stream = res.bytes_stream().take(max_size);
let mut buf = BytesMut::new();
let mut size = 0;
while let Some(chunk) = stream.next().await {
let chunk = &chunk?;
size += chunk.len();
buf.extend(chunk);
if size >= max_size {
break;
}
}
Ok(buf.freeze())
}
/// This is an implementation of the default Cookie Jar from Reqwest and reqwest_cookie_store build by pfernie. /// This is an implementation of the default Cookie Jar from Reqwest and reqwest_cookie_store build by pfernie.
/// The default cookie jar used by Reqwest keeps all the cookies based upon the Max-Age or Expires which could be a long time. /// The default cookie jar used by Reqwest keeps all the cookies based upon the Max-Age or Expires which could be a long time.
/// That could be used for tracking, to prevent this we force the lifespan of the cookies to always be max two minutes. /// That could be used for tracking, to prevent this we force the lifespan of the cookies to always be max two minutes.
/// A Cookie Jar is needed because some sites force a redirect with cookies to verify if a request uses cookies or not. /// A Cookie Jar is needed because some sites force a redirect with cookies to verify if a request uses cookies or not.
use cookie_store::CookieStore; use cookie_store::CookieStore;
#[derive(Default)] #[derive(Default)]
pub struct Jar(RwLock<CookieStore>); pub struct Jar(std::sync::RwLock<CookieStore>);
impl reqwest::cookie::CookieStore for Jar { impl reqwest::cookie::CookieStore for Jar {
fn set_cookies(&self, cookie_headers: &mut dyn Iterator<Item = &header::HeaderValue>, url: &url::Url) { fn set_cookies(&self, cookie_headers: &mut dyn Iterator<Item = &header::HeaderValue>, url: &url::Url) {
@ -820,8 +792,6 @@ impl reqwest::cookie::CookieStore for Jar {
} }
fn cookies(&self, url: &url::Url) -> Option<header::HeaderValue> { fn cookies(&self, url: &url::Url) -> Option<header::HeaderValue> {
use bytes::Bytes;
let cookie_store = self.0.read().unwrap(); let cookie_store = self.0.read().unwrap();
let s = cookie_store let s = cookie_store
.get_request_values(url) .get_request_values(url)
@ -836,3 +806,158 @@ impl reqwest::cookie::CookieStore for Jar {
header::HeaderValue::from_maybe_shared(Bytes::from(s)).ok() header::HeaderValue::from_maybe_shared(Bytes::from(s)).ok()
} }
} }
/// Custom FaviconEmitter for the html5gum parser.
/// The FaviconEmitter is using an almost 1:1 copy of the DefaultEmitter with some small changes.
/// This prevents emitting tags like comments, doctype and also strings between the tags.
/// Therefor parsing the HTML content is faster.
use std::collections::{BTreeSet, VecDeque};
#[derive(Debug)]
enum FaviconToken {
StartTag(StartTag),
EndTag(EndTag),
}
#[derive(Default, Debug)]
struct FaviconEmitter {
current_token: Option<FaviconToken>,
last_start_tag: HtmlString,
current_attribute: Option<(HtmlString, HtmlString)>,
seen_attributes: BTreeSet<HtmlString>,
emitted_tokens: VecDeque<FaviconToken>,
}
impl FaviconEmitter {
fn emit_token(&mut self, token: FaviconToken) {
self.emitted_tokens.push_front(token);
}
fn flush_current_attribute(&mut self) {
if let Some((k, v)) = self.current_attribute.take() {
match self.current_token {
Some(FaviconToken::StartTag(ref mut tag)) => {
tag.attributes.entry(k).and_modify(|_| {}).or_insert(v);
}
Some(FaviconToken::EndTag(_)) => {
self.seen_attributes.insert(k);
}
_ => {
debug_assert!(false);
}
}
}
}
}
impl Emitter for FaviconEmitter {
type Token = FaviconToken;
fn set_last_start_tag(&mut self, last_start_tag: Option<&[u8]>) {
self.last_start_tag.clear();
self.last_start_tag.extend(last_start_tag.unwrap_or_default());
}
fn pop_token(&mut self) -> Option<Self::Token> {
self.emitted_tokens.pop_back()
}
fn init_start_tag(&mut self) {
self.current_token = Some(FaviconToken::StartTag(StartTag::default()));
}
fn init_end_tag(&mut self) {
self.current_token = Some(FaviconToken::EndTag(EndTag::default()));
self.seen_attributes.clear();
}
fn emit_current_tag(&mut self) -> Option<html5gum::State> {
self.flush_current_attribute();
let mut token = self.current_token.take().unwrap();
let mut emit = false;
match token {
FaviconToken::EndTag(ref mut tag) => {
// Always clean seen attributes
self.seen_attributes.clear();
// Only trigger an emit for the </head> tag.
// This is matched, and will break the for-loop.
if *tag.name == b"head" {
emit = true;
}
}
FaviconToken::StartTag(ref mut tag) => {
// Only trriger an emit for <link> and <base> tags.
// These are the only tags we want to parse.
if *tag.name == b"link" || *tag.name == b"base" {
self.set_last_start_tag(Some(&tag.name));
emit = true;
} else {
self.set_last_start_tag(None);
}
}
}
// Only emit the tags we want to parse.
if emit {
self.emit_token(token);
}
None
}
fn push_tag_name(&mut self, s: &[u8]) {
match self.current_token {
Some(
FaviconToken::StartTag(StartTag {
ref mut name,
..
})
| FaviconToken::EndTag(EndTag {
ref mut name,
..
}),
) => {
name.extend(s);
}
_ => debug_assert!(false),
}
}
fn init_attribute(&mut self) {
self.flush_current_attribute();
self.current_attribute = Some(Default::default());
}
fn push_attribute_name(&mut self, s: &[u8]) {
self.current_attribute.as_mut().unwrap().0.extend(s);
}
fn push_attribute_value(&mut self, s: &[u8]) {
self.current_attribute.as_mut().unwrap().1.extend(s);
}
fn current_is_appropriate_end_tag_token(&mut self) -> bool {
match self.current_token {
Some(FaviconToken::EndTag(ref tag)) => !self.last_start_tag.is_empty() && self.last_start_tag == tag.name,
_ => false,
}
}
// We do not want and need these parts of the HTML document
// These will be skipped and ignored during the tokenization and iteration.
fn emit_current_comment(&mut self) {}
fn emit_current_doctype(&mut self) {}
fn emit_eof(&mut self) {}
fn emit_error(&mut self, _: html5gum::Error) {}
fn emit_string(&mut self, _: &[u8]) {}
fn init_comment(&mut self) {}
fn init_doctype(&mut self) {}
fn push_comment(&mut self, _: &[u8]) {}
fn push_doctype_name(&mut self, _: &[u8]) {}
fn push_doctype_public_identifier(&mut self, _: &[u8]) {}
fn push_doctype_system_identifier(&mut self, _: &[u8]) {}
fn set_doctype_public_identifier(&mut self, _: &[u8]) {}
fn set_doctype_system_identifier(&mut self, _: &[u8]) {}
fn set_force_quirks(&mut self) {}
fn set_self_closing(&mut self) {}
}

209
src/api/identity.rs

@ -1,16 +1,17 @@
use chrono::Utc; use chrono::Utc;
use num_traits::FromPrimitive; use num_traits::FromPrimitive;
use rocket::serde::json::Json;
use rocket::{ use rocket::{
request::{Form, FormItems, FromForm}, form::{Form, FromForm},
Route, Route,
}; };
use rocket_contrib::json::Json;
use serde_json::Value; use serde_json::Value;
use crate::{ use crate::{
api::{ api::{
core::accounts::{PreloginData, _prelogin},
core::two_factor::{duo, email, email::EmailTokenData, yubikey}, core::two_factor::{duo, email, email::EmailTokenData, yubikey},
ApiResult, EmptyResult, JsonResult, ApiResult, EmptyResult, JsonResult, JsonUpcase,
}, },
auth::ClientIp, auth::ClientIp,
db::{models::*, DbConn}, db::{models::*, DbConn},
@ -19,17 +20,17 @@ use crate::{
}; };
pub fn routes() -> Vec<Route> { pub fn routes() -> Vec<Route> {
routes![login] routes![login, prelogin]
} }
#[post("/connect/token", data = "<data>")] #[post("/connect/token", data = "<data>")]
fn login(data: Form<ConnectData>, conn: DbConn, ip: ClientIp) -> JsonResult { async fn login(data: Form<ConnectData>, conn: DbConn, ip: ClientIp) -> JsonResult {
let data: ConnectData = data.into_inner(); let data: ConnectData = data.into_inner();
match data.grant_type.as_ref() { match data.grant_type.as_ref() {
"refresh_token" => { "refresh_token" => {
_check_is_some(&data.refresh_token, "refresh_token cannot be blank")?; _check_is_some(&data.refresh_token, "refresh_token cannot be blank")?;
_refresh_login(data, conn) _refresh_login(data, conn).await
} }
"password" => { "password" => {
_check_is_some(&data.client_id, "client_id cannot be blank")?; _check_is_some(&data.client_id, "client_id cannot be blank")?;
@ -41,34 +42,34 @@ fn login(data: Form<ConnectData>, conn: DbConn, ip: ClientIp) -> JsonResult {
_check_is_some(&data.device_name, "device_name cannot be blank")?; _check_is_some(&data.device_name, "device_name cannot be blank")?;
_check_is_some(&data.device_type, "device_type cannot be blank")?; _check_is_some(&data.device_type, "device_type cannot be blank")?;
_password_login(data, conn, &ip) _password_login(data, conn, &ip).await
} }
"client_credentials" => { "client_credentials" => {
_check_is_some(&data.client_id, "client_id cannot be blank")?; _check_is_some(&data.client_id, "client_id cannot be blank")?;
_check_is_some(&data.client_secret, "client_secret cannot be blank")?; _check_is_some(&data.client_secret, "client_secret cannot be blank")?;
_check_is_some(&data.scope, "scope cannot be blank")?; _check_is_some(&data.scope, "scope cannot be blank")?;
_api_key_login(data, conn, &ip) _api_key_login(data, conn, &ip).await
} }
t => err!("Invalid type", t), t => err!("Invalid type", t),
} }
} }
fn _refresh_login(data: ConnectData, conn: DbConn) -> JsonResult { async fn _refresh_login(data: ConnectData, conn: DbConn) -> JsonResult {
// Extract token // Extract token
let token = data.refresh_token.unwrap(); let token = data.refresh_token.unwrap();
// Get device by refresh token // Get device by refresh token
let mut device = Device::find_by_refresh_token(&token, &conn).map_res("Invalid refresh token")?; let mut device = Device::find_by_refresh_token(&token, &conn).await.map_res("Invalid refresh token")?;
let scope = "api offline_access"; let scope = "api offline_access";
let scope_vec = vec!["api".into(), "offline_access".into()]; let scope_vec = vec!["api".into(), "offline_access".into()];
// Common // Common
let user = User::find_by_uuid(&device.user_uuid, &conn).unwrap(); let user = User::find_by_uuid(&device.user_uuid, &conn).await.unwrap();
let orgs = UserOrganization::find_confirmed_by_user(&user.uuid, &conn); let orgs = UserOrganization::find_confirmed_by_user(&user.uuid, &conn).await;
let (access_token, expires_in) = device.refresh_tokens(&user, orgs, scope_vec); let (access_token, expires_in) = device.refresh_tokens(&user, orgs, scope_vec);
device.save(&conn)?; device.save(&conn).await?;
Ok(Json(json!({ Ok(Json(json!({
"access_token": access_token, "access_token": access_token,
@ -86,7 +87,7 @@ fn _refresh_login(data: ConnectData, conn: DbConn) -> JsonResult {
}))) })))
} }
fn _password_login(data: ConnectData, conn: DbConn, ip: &ClientIp) -> JsonResult { async fn _password_login(data: ConnectData, conn: DbConn, ip: &ClientIp) -> JsonResult {
// Validate scope // Validate scope
let scope = data.scope.as_ref().unwrap(); let scope = data.scope.as_ref().unwrap();
if scope != "api offline_access" { if scope != "api offline_access" {
@ -98,8 +99,8 @@ fn _password_login(data: ConnectData, conn: DbConn, ip: &ClientIp) -> JsonResult
crate::ratelimit::check_limit_login(&ip.ip)?; crate::ratelimit::check_limit_login(&ip.ip)?;
// Get the user // Get the user
let username = data.username.as_ref().unwrap(); let username = data.username.as_ref().unwrap().trim();
let user = match User::find_by_mail(username, &conn) { let user = match User::find_by_mail(username, &conn).await {
Some(user) => user, Some(user) => user,
None => err!("Username or password is incorrect. Try again", format!("IP: {}. Username: {}.", ip.ip, username)), None => err!("Username or password is incorrect. Try again", format!("IP: {}. Username: {}.", ip.ip, username)),
}; };
@ -130,11 +131,11 @@ fn _password_login(data: ConnectData, conn: DbConn, ip: &ClientIp) -> JsonResult
user.last_verifying_at = Some(now); user.last_verifying_at = Some(now);
user.login_verify_count += 1; user.login_verify_count += 1;
if let Err(e) = user.save(&conn) { if let Err(e) = user.save(&conn).await {
error!("Error updating user: {:#?}", e); error!("Error updating user: {:#?}", e);
} }
if let Err(e) = mail::send_verify_email(&user.email, &user.uuid) { if let Err(e) = mail::send_verify_email(&user.email, &user.uuid).await {
error!("Error auto-sending email verification email: {:#?}", e); error!("Error auto-sending email verification email: {:#?}", e);
} }
} }
@ -144,12 +145,12 @@ fn _password_login(data: ConnectData, conn: DbConn, ip: &ClientIp) -> JsonResult
err!("Please verify your email before trying again.", format!("IP: {}. Username: {}.", ip.ip, username)) err!("Please verify your email before trying again.", format!("IP: {}. Username: {}.", ip.ip, username))
} }
let (mut device, new_device) = get_device(&data, &conn, &user); let (mut device, new_device) = get_device(&data, &conn, &user).await;
let twofactor_token = twofactor_auth(&user.uuid, &data, &mut device, ip, &conn)?; let twofactor_token = twofactor_auth(&user.uuid, &data, &mut device, ip, &conn).await?;
if CONFIG.mail_enabled() && new_device { if CONFIG.mail_enabled() && new_device {
if let Err(e) = mail::send_new_device_logged_in(&user.email, &ip.ip.to_string(), &now, &device.name) { if let Err(e) = mail::send_new_device_logged_in(&user.email, &ip.ip.to_string(), &now, &device.name).await {
error!("Error sending new device email: {:#?}", e); error!("Error sending new device email: {:#?}", e);
if CONFIG.require_device_email() { if CONFIG.require_device_email() {
@ -159,9 +160,9 @@ fn _password_login(data: ConnectData, conn: DbConn, ip: &ClientIp) -> JsonResult
} }
// Common // Common
let orgs = UserOrganization::find_confirmed_by_user(&user.uuid, &conn); let orgs = UserOrganization::find_confirmed_by_user(&user.uuid, &conn).await;
let (access_token, expires_in) = device.refresh_tokens(&user, orgs, scope_vec); let (access_token, expires_in) = device.refresh_tokens(&user, orgs, scope_vec);
device.save(&conn)?; device.save(&conn).await?;
let mut result = json!({ let mut result = json!({
"access_token": access_token, "access_token": access_token,
@ -187,7 +188,7 @@ fn _password_login(data: ConnectData, conn: DbConn, ip: &ClientIp) -> JsonResult
Ok(Json(result)) Ok(Json(result))
} }
fn _api_key_login(data: ConnectData, conn: DbConn, ip: &ClientIp) -> JsonResult { async fn _api_key_login(data: ConnectData, conn: DbConn, ip: &ClientIp) -> JsonResult {
// Validate scope // Validate scope
let scope = data.scope.as_ref().unwrap(); let scope = data.scope.as_ref().unwrap();
if scope != "api" { if scope != "api" {
@ -204,7 +205,7 @@ fn _api_key_login(data: ConnectData, conn: DbConn, ip: &ClientIp) -> JsonResult
Some(uuid) => uuid, Some(uuid) => uuid,
None => err!("Malformed client_id", format!("IP: {}.", ip.ip)), None => err!("Malformed client_id", format!("IP: {}.", ip.ip)),
}; };
let user = match User::find_by_uuid(user_uuid, &conn) { let user = match User::find_by_uuid(user_uuid, &conn).await {
Some(user) => user, Some(user) => user,
None => err!("Invalid client_id", format!("IP: {}.", ip.ip)), None => err!("Invalid client_id", format!("IP: {}.", ip.ip)),
}; };
@ -220,11 +221,11 @@ fn _api_key_login(data: ConnectData, conn: DbConn, ip: &ClientIp) -> JsonResult
err!("Incorrect client_secret", format!("IP: {}. Username: {}.", ip.ip, user.email)) err!("Incorrect client_secret", format!("IP: {}. Username: {}.", ip.ip, user.email))
} }
let (mut device, new_device) = get_device(&data, &conn, &user); let (mut device, new_device) = get_device(&data, &conn, &user).await;
if CONFIG.mail_enabled() && new_device { if CONFIG.mail_enabled() && new_device {
let now = Utc::now().naive_utc(); let now = Utc::now().naive_utc();
if let Err(e) = mail::send_new_device_logged_in(&user.email, &ip.ip.to_string(), &now, &device.name) { if let Err(e) = mail::send_new_device_logged_in(&user.email, &ip.ip.to_string(), &now, &device.name).await {
error!("Error sending new device email: {:#?}", e); error!("Error sending new device email: {:#?}", e);
if CONFIG.require_device_email() { if CONFIG.require_device_email() {
@ -234,9 +235,9 @@ fn _api_key_login(data: ConnectData, conn: DbConn, ip: &ClientIp) -> JsonResult
} }
// Common // Common
let orgs = UserOrganization::find_confirmed_by_user(&user.uuid, &conn); let orgs = UserOrganization::find_confirmed_by_user(&user.uuid, &conn).await;
let (access_token, expires_in) = device.refresh_tokens(&user, orgs, scope_vec); let (access_token, expires_in) = device.refresh_tokens(&user, orgs, scope_vec);
device.save(&conn)?; device.save(&conn).await?;
info!("User {} logged in successfully via API key. IP: {}", user.email, ip.ip); info!("User {} logged in successfully via API key. IP: {}", user.email, ip.ip);
@ -258,7 +259,7 @@ fn _api_key_login(data: ConnectData, conn: DbConn, ip: &ClientIp) -> JsonResult
} }
/// Retrieves an existing device or creates a new device from ConnectData and the User /// Retrieves an existing device or creates a new device from ConnectData and the User
fn get_device(data: &ConnectData, conn: &DbConn, user: &User) -> (Device, bool) { async fn get_device(data: &ConnectData, conn: &DbConn, user: &User) -> (Device, bool) {
// On iOS, device_type sends "iOS", on others it sends a number // On iOS, device_type sends "iOS", on others it sends a number
let device_type = util::try_parse_string(data.device_type.as_ref()).unwrap_or(0); let device_type = util::try_parse_string(data.device_type.as_ref()).unwrap_or(0);
let device_id = data.device_identifier.clone().expect("No device id provided"); let device_id = data.device_identifier.clone().expect("No device id provided");
@ -266,17 +267,8 @@ fn get_device(data: &ConnectData, conn: &DbConn, user: &User) -> (Device, bool)
let mut new_device = false; let mut new_device = false;
// Find device or create new // Find device or create new
let device = match Device::find_by_uuid(&device_id, conn) { let device = match Device::find_by_uuid_and_user(&device_id, &user.uuid, conn).await {
Some(device) => { Some(device) => device,
// Check if owned device, and recreate if not
if device.user_uuid != user.uuid {
info!("Device exists but is owned by another user. The old device will be discarded");
new_device = true;
Device::new(device_id, user.uuid.clone(), device_name, device_type)
} else {
device
}
}
None => { None => {
new_device = true; new_device = true;
Device::new(device_id, user.uuid.clone(), device_name, device_type) Device::new(device_id, user.uuid.clone(), device_name, device_type)
@ -286,28 +278,28 @@ fn get_device(data: &ConnectData, conn: &DbConn, user: &User) -> (Device, bool)
(device, new_device) (device, new_device)
} }
fn twofactor_auth( async fn twofactor_auth(
user_uuid: &str, user_uuid: &str,
data: &ConnectData, data: &ConnectData,
device: &mut Device, device: &mut Device,
ip: &ClientIp, ip: &ClientIp,
conn: &DbConn, conn: &DbConn,
) -> ApiResult<Option<String>> { ) -> ApiResult<Option<String>> {
let twofactors = TwoFactor::find_by_user(user_uuid, conn); let twofactors = TwoFactor::find_by_user(user_uuid, conn).await;
// No twofactor token if twofactor is disabled // No twofactor token if twofactor is disabled
if twofactors.is_empty() { if twofactors.is_empty() {
return Ok(None); return Ok(None);
} }
TwoFactorIncomplete::mark_incomplete(user_uuid, &device.uuid, &device.name, ip, conn)?; TwoFactorIncomplete::mark_incomplete(user_uuid, &device.uuid, &device.name, ip, conn).await?;
let twofactor_ids: Vec<_> = twofactors.iter().map(|tf| tf.atype).collect(); let twofactor_ids: Vec<_> = twofactors.iter().map(|tf| tf.atype).collect();
let selected_id = data.two_factor_provider.unwrap_or(twofactor_ids[0]); // If we aren't given a two factor provider, asume the first one let selected_id = data.two_factor_provider.unwrap_or(twofactor_ids[0]); // If we aren't given a two factor provider, asume the first one
let twofactor_code = match data.two_factor_token { let twofactor_code = match data.two_factor_token {
Some(ref code) => code, Some(ref code) => code,
None => err_json!(_json_err_twofactor(&twofactor_ids, user_uuid, conn)?, "2FA token not provided"), None => err_json!(_json_err_twofactor(&twofactor_ids, user_uuid, conn).await?, "2FA token not provided"),
}; };
let selected_twofactor = twofactors.into_iter().find(|tf| tf.atype == selected_id && tf.enabled); let selected_twofactor = twofactors.into_iter().find(|tf| tf.atype == selected_id && tf.enabled);
@ -320,16 +312,17 @@ fn twofactor_auth(
match TwoFactorType::from_i32(selected_id) { match TwoFactorType::from_i32(selected_id) {
Some(TwoFactorType::Authenticator) => { Some(TwoFactorType::Authenticator) => {
_tf::authenticator::validate_totp_code_str(user_uuid, twofactor_code, &selected_data?, ip, conn)? _tf::authenticator::validate_totp_code_str(user_uuid, twofactor_code, &selected_data?, ip, conn).await?
}
Some(TwoFactorType::Webauthn) => {
_tf::webauthn::validate_webauthn_login(user_uuid, twofactor_code, conn).await?
} }
Some(TwoFactorType::U2f) => _tf::u2f::validate_u2f_login(user_uuid, twofactor_code, conn)?,
Some(TwoFactorType::Webauthn) => _tf::webauthn::validate_webauthn_login(user_uuid, twofactor_code, conn)?,
Some(TwoFactorType::YubiKey) => _tf::yubikey::validate_yubikey_login(twofactor_code, &selected_data?)?, Some(TwoFactorType::YubiKey) => _tf::yubikey::validate_yubikey_login(twofactor_code, &selected_data?)?,
Some(TwoFactorType::Duo) => { Some(TwoFactorType::Duo) => {
_tf::duo::validate_duo_login(data.username.as_ref().unwrap(), twofactor_code, conn)? _tf::duo::validate_duo_login(data.username.as_ref().unwrap().trim(), twofactor_code, conn).await?
} }
Some(TwoFactorType::Email) => { Some(TwoFactorType::Email) => {
_tf::email::validate_email_code_str(user_uuid, twofactor_code, &selected_data?, conn)? _tf::email::validate_email_code_str(user_uuid, twofactor_code, &selected_data?, conn).await?
} }
Some(TwoFactorType::Remember) => { Some(TwoFactorType::Remember) => {
@ -338,14 +331,17 @@ fn twofactor_auth(
remember = 1; // Make sure we also return the token here, otherwise it will only remember the first time remember = 1; // Make sure we also return the token here, otherwise it will only remember the first time
} }
_ => { _ => {
err_json!(_json_err_twofactor(&twofactor_ids, user_uuid, conn)?, "2FA Remember token not provided") err_json!(
_json_err_twofactor(&twofactor_ids, user_uuid, conn).await?,
"2FA Remember token not provided"
)
} }
} }
} }
_ => err!("Invalid two factor provider"), _ => err!("Invalid two factor provider"),
} }
TwoFactorIncomplete::mark_complete(user_uuid, &device.uuid, conn)?; TwoFactorIncomplete::mark_complete(user_uuid, &device.uuid, conn).await?;
if !CONFIG.disable_2fa_remember() && remember == 1 { if !CONFIG.disable_2fa_remember() && remember == 1 {
Ok(Some(device.refresh_twofactor_remember())) Ok(Some(device.refresh_twofactor_remember()))
@ -359,7 +355,7 @@ fn _selected_data(tf: Option<TwoFactor>) -> ApiResult<String> {
tf.map(|t| t.data).map_res("Two factor doesn't exist") tf.map(|t| t.data).map_res("Two factor doesn't exist")
} }
fn _json_err_twofactor(providers: &[i32], user_uuid: &str, conn: &DbConn) -> ApiResult<Value> { async fn _json_err_twofactor(providers: &[i32], user_uuid: &str, conn: &DbConn) -> ApiResult<Value> {
use crate::api::core::two_factor; use crate::api::core::two_factor;
let mut result = json!({ let mut result = json!({
@ -375,38 +371,18 @@ fn _json_err_twofactor(providers: &[i32], user_uuid: &str, conn: &DbConn) -> Api
match TwoFactorType::from_i32(*provider) { match TwoFactorType::from_i32(*provider) {
Some(TwoFactorType::Authenticator) => { /* Nothing to do for TOTP */ } Some(TwoFactorType::Authenticator) => { /* Nothing to do for TOTP */ }
Some(TwoFactorType::U2f) if CONFIG.domain_set() => {
let request = two_factor::u2f::generate_u2f_login(user_uuid, conn)?;
let mut challenge_list = Vec::new();
for key in request.registered_keys {
challenge_list.push(json!({
"appId": request.app_id,
"challenge": request.challenge,
"version": key.version,
"keyHandle": key.key_handle,
}));
}
let challenge_list_str = serde_json::to_string(&challenge_list).unwrap();
result["TwoFactorProviders2"][provider.to_string()] = json!({
"Challenges": challenge_list_str,
});
}
Some(TwoFactorType::Webauthn) if CONFIG.domain_set() => { Some(TwoFactorType::Webauthn) if CONFIG.domain_set() => {
let request = two_factor::webauthn::generate_webauthn_login(user_uuid, conn)?; let request = two_factor::webauthn::generate_webauthn_login(user_uuid, conn).await?;
result["TwoFactorProviders2"][provider.to_string()] = request.0; result["TwoFactorProviders2"][provider.to_string()] = request.0;
} }
Some(TwoFactorType::Duo) => { Some(TwoFactorType::Duo) => {
let email = match User::find_by_uuid(user_uuid, conn) { let email = match User::find_by_uuid(user_uuid, conn).await {
Some(u) => u.email, Some(u) => u.email,
None => err!("User does not exist"), None => err!("User does not exist"),
}; };
let (signature, host) = duo::generate_duo_signature(&email, conn)?; let (signature, host) = duo::generate_duo_signature(&email, conn).await?;
result["TwoFactorProviders2"][provider.to_string()] = json!({ result["TwoFactorProviders2"][provider.to_string()] = json!({
"Host": host, "Host": host,
@ -415,7 +391,7 @@ fn _json_err_twofactor(providers: &[i32], user_uuid: &str, conn: &DbConn) -> Api
} }
Some(tf_type @ TwoFactorType::YubiKey) => { Some(tf_type @ TwoFactorType::YubiKey) => {
let twofactor = match TwoFactor::find_by_user_and_type(user_uuid, tf_type as i32, conn) { let twofactor = match TwoFactor::find_by_user_and_type(user_uuid, tf_type as i32, conn).await {
Some(tf) => tf, Some(tf) => tf,
None => err!("No YubiKey devices registered"), None => err!("No YubiKey devices registered"),
}; };
@ -430,14 +406,14 @@ fn _json_err_twofactor(providers: &[i32], user_uuid: &str, conn: &DbConn) -> Api
Some(tf_type @ TwoFactorType::Email) => { Some(tf_type @ TwoFactorType::Email) => {
use crate::api::core::two_factor as _tf; use crate::api::core::two_factor as _tf;
let twofactor = match TwoFactor::find_by_user_and_type(user_uuid, tf_type as i32, conn) { let twofactor = match TwoFactor::find_by_user_and_type(user_uuid, tf_type as i32, conn).await {
Some(tf) => tf, Some(tf) => tf,
None => err!("No twofactor email registered"), None => err!("No twofactor email registered"),
}; };
// Send email immediately if email is the only 2FA option // Send email immediately if email is the only 2FA option
if providers.len() == 1 { if providers.len() == 1 {
_tf::email::send_token(user_uuid, conn)? _tf::email::send_token(user_uuid, conn).await?
} }
let email_data = EmailTokenData::from_json(&twofactor.data)?; let email_data = EmailTokenData::from_json(&twofactor.data)?;
@ -453,68 +429,65 @@ fn _json_err_twofactor(providers: &[i32], user_uuid: &str, conn: &DbConn) -> Api
Ok(result) Ok(result)
} }
#[post("/accounts/prelogin", data = "<data>")]
async fn prelogin(data: JsonUpcase<PreloginData>, conn: DbConn) -> Json<Value> {
_prelogin(data, conn).await
}
// https://github.com/bitwarden/jslib/blob/master/common/src/models/request/tokenRequest.ts // https://github.com/bitwarden/jslib/blob/master/common/src/models/request/tokenRequest.ts
// https://github.com/bitwarden/mobile/blob/master/src/Core/Models/Request/TokenRequest.cs // https://github.com/bitwarden/mobile/blob/master/src/Core/Models/Request/TokenRequest.cs
#[derive(Debug, Clone, Default)] #[derive(Debug, Clone, Default, FromForm)]
#[allow(non_snake_case)] #[allow(non_snake_case)]
struct ConnectData { struct ConnectData {
// refresh_token, password, client_credentials (API key) #[field(name = uncased("grant_type"))]
grant_type: String, #[field(name = uncased("granttype"))]
grant_type: String, // refresh_token, password, client_credentials (API key)
// Needed for grant_type="refresh_token" // Needed for grant_type="refresh_token"
#[field(name = uncased("refresh_token"))]
#[field(name = uncased("refreshtoken"))]
refresh_token: Option<String>, refresh_token: Option<String>,
// Needed for grant_type = "password" | "client_credentials" // Needed for grant_type = "password" | "client_credentials"
client_id: Option<String>, // web, cli, desktop, browser, mobile #[field(name = uncased("client_id"))]
client_secret: Option<String>, // API key login (cli only) #[field(name = uncased("clientid"))]
client_id: Option<String>, // web, cli, desktop, browser, mobile
#[field(name = uncased("client_secret"))]
#[field(name = uncased("clientsecret"))]
client_secret: Option<String>,
#[field(name = uncased("password"))]
password: Option<String>, password: Option<String>,
#[field(name = uncased("scope"))]
scope: Option<String>, scope: Option<String>,
#[field(name = uncased("username"))]
username: Option<String>, username: Option<String>,
#[field(name = uncased("device_identifier"))]
#[field(name = uncased("deviceidentifier"))]
device_identifier: Option<String>, device_identifier: Option<String>,
#[field(name = uncased("device_name"))]
#[field(name = uncased("devicename"))]
device_name: Option<String>, device_name: Option<String>,
#[field(name = uncased("device_type"))]
#[field(name = uncased("devicetype"))]
device_type: Option<String>, device_type: Option<String>,
device_push_token: Option<String>, // Unused; mobile device push not yet supported. #[allow(unused)]
#[field(name = uncased("device_push_token"))]
#[field(name = uncased("devicepushtoken"))]
_device_push_token: Option<String>, // Unused; mobile device push not yet supported.
// Needed for two-factor auth // Needed for two-factor auth
#[field(name = uncased("two_factor_provider"))]
#[field(name = uncased("twofactorprovider"))]
two_factor_provider: Option<i32>, two_factor_provider: Option<i32>,
#[field(name = uncased("two_factor_token"))]
#[field(name = uncased("twofactortoken"))]
two_factor_token: Option<String>, two_factor_token: Option<String>,
#[field(name = uncased("two_factor_remember"))]
#[field(name = uncased("twofactorremember"))]
two_factor_remember: Option<i32>, two_factor_remember: Option<i32>,
} }
impl<'f> FromForm<'f> for ConnectData {
type Error = String;
fn from_form(items: &mut FormItems<'f>, _strict: bool) -> Result<Self, Self::Error> {
let mut form = Self::default();
for item in items {
let (key, value) = item.key_value_decoded();
let mut normalized_key = key.to_lowercase();
normalized_key.retain(|c| c != '_'); // Remove '_'
match normalized_key.as_ref() {
"granttype" => form.grant_type = value,
"refreshtoken" => form.refresh_token = Some(value),
"clientid" => form.client_id = Some(value),
"clientsecret" => form.client_secret = Some(value),
"password" => form.password = Some(value),
"scope" => form.scope = Some(value),
"username" => form.username = Some(value),
"deviceidentifier" => form.device_identifier = Some(value),
"devicename" => form.device_name = Some(value),
"devicetype" => form.device_type = Some(value),
"devicepushtoken" => form.device_push_token = Some(value),
"twofactorprovider" => form.two_factor_provider = value.parse().ok(),
"twofactortoken" => form.two_factor_token = Some(value),
"twofactorremember" => form.two_factor_remember = value.parse().ok(),
key => warn!("Detected unexpected parameter during login: {}", key),
}
}
Ok(form)
}
}
fn _check_is_some<T>(value: &Option<T>, msg: &str) -> EmptyResult { fn _check_is_some<T>(value: &Option<T>, msg: &str) -> EmptyResult {
if value.is_none() { if value.is_none() {
err!(msg) err!(msg)

2
src/api/mod.rs

@ -5,7 +5,7 @@ mod identity;
mod notifications; mod notifications;
mod web; mod web;
use rocket_contrib::json::Json; use rocket::serde::json::Json;
use serde_json::Value; use serde_json::Value;
pub use crate::api::{ pub use crate::api::{

391
src/api/notifications.rs

@ -1,19 +1,41 @@
use std::sync::atomic::{AtomicBool, Ordering}; use std::{
net::SocketAddr,
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
time::Duration,
};
use rocket::Route; use chrono::NaiveDateTime;
use rocket_contrib::json::Json; use futures::{SinkExt, StreamExt};
use rmpv::Value;
use rocket::{serde::json::Json, Route};
use serde_json::Value as JsonValue; use serde_json::Value as JsonValue;
use tokio::{
use crate::{api::EmptyResult, auth::Headers, Error, CONFIG}; net::{TcpListener, TcpStream},
sync::mpsc::Sender,
};
use tokio_tungstenite::{
accept_hdr_async,
tungstenite::{handshake, Message},
};
use crate::{
api::EmptyResult,
auth::Headers,
db::models::{Cipher, Folder, Send, User},
Error, CONFIG,
};
pub fn routes() -> Vec<Route> { pub fn routes() -> Vec<Route> {
routes![negotiate, websockets_err] routes![negotiate, websockets_err]
} }
static SHOW_WEBSOCKETS_MSG: AtomicBool = AtomicBool::new(true);
#[get("/hub")] #[get("/hub")]
fn websockets_err() -> EmptyResult { fn websockets_err() -> EmptyResult {
static SHOW_WEBSOCKETS_MSG: AtomicBool = AtomicBool::new(true);
if CONFIG.websocket_enabled() if CONFIG.websocket_enabled()
&& SHOW_WEBSOCKETS_MSG.compare_exchange(true, false, Ordering::Relaxed, Ordering::Relaxed).is_ok() && SHOW_WEBSOCKETS_MSG.compare_exchange(true, false, Ordering::Relaxed, Ordering::Relaxed).is_ok()
{ {
@ -55,19 +77,6 @@ fn negotiate(_headers: Headers) -> Json<JsonValue> {
// //
// Websockets server // Websockets server
// //
use std::io;
use std::sync::Arc;
use std::thread;
use ws::{self, util::Token, Factory, Handler, Handshake, Message, Sender};
use chashmap::CHashMap;
use chrono::NaiveDateTime;
use serde_json::from_str;
use crate::db::models::{Cipher, Folder, Send, User};
use rmpv::Value;
fn serialize(val: Value) -> Vec<u8> { fn serialize(val: Value) -> Vec<u8> {
use rmpv::encode::write_value; use rmpv::encode::write_value;
@ -118,192 +127,49 @@ fn convert_option<T: Into<Value>>(option: Option<T>) -> Value {
} }
} }
// Server WebSocket handler
pub struct WsHandler {
out: Sender,
user_uuid: Option<String>,
users: WebSocketUsers,
}
const RECORD_SEPARATOR: u8 = 0x1e; const RECORD_SEPARATOR: u8 = 0x1e;
const INITIAL_RESPONSE: [u8; 3] = [0x7b, 0x7d, RECORD_SEPARATOR]; // {, }, <RS> const INITIAL_RESPONSE: [u8; 3] = [0x7b, 0x7d, RECORD_SEPARATOR]; // {, }, <RS>
#[derive(Deserialize)] #[derive(Deserialize, Copy, Clone, Eq, PartialEq)]
struct InitialMessage { struct InitialMessage<'a> {
protocol: String, protocol: &'a str,
version: i32, version: i32,
} }
const PING_MS: u64 = 15_000; static INITIAL_MESSAGE: InitialMessage<'static> = InitialMessage {
const PING: Token = Token(1); protocol: "messagepack",
version: 1,
const ACCESS_TOKEN_KEY: &str = "access_token="; };
impl WsHandler {
fn err(&self, msg: &'static str) -> ws::Result<()> {
self.out.close(ws::CloseCode::Invalid)?;
// We need to specifically return an IO error so ws closes the connection
let io_error = io::Error::from(io::ErrorKind::InvalidData);
Err(ws::Error::new(ws::ErrorKind::Io(io_error), msg))
}
fn get_request_token(&self, hs: Handshake) -> Option<String> {
use std::str::from_utf8;
// Verify we have a token header
if let Some(header_value) = hs.request.header("Authorization") {
if let Ok(converted) = from_utf8(header_value) {
if let Some(token_part) = converted.split("Bearer ").nth(1) {
return Some(token_part.into());
}
}
};
// Otherwise verify the query parameter value
let path = hs.request.resource();
if let Some(params) = path.split('?').nth(1) {
let params_iter = params.split('&').take(1);
for val in params_iter {
if let Some(stripped) = val.strip_prefix(ACCESS_TOKEN_KEY) {
return Some(stripped.into());
}
}
};
None
}
}
impl Handler for WsHandler {
fn on_open(&mut self, hs: Handshake) -> ws::Result<()> {
// Path == "/notifications/hub?id=<id>==&access_token=<access_token>"
//
// We don't use `id`, and as of around 2020-03-25, the official clients
// no longer seem to pass `id` (only `access_token`).
// Get user token from header or query parameter
let access_token = match self.get_request_token(hs) {
Some(token) => token,
_ => return self.err("Missing access token"),
};
// Validate the user
use crate::auth;
let claims = match auth::decode_login(access_token.as_str()) {
Ok(claims) => claims,
Err(_) => return self.err("Invalid access token provided"),
};
// Assign the user to the handler
let user_uuid = claims.sub;
self.user_uuid = Some(user_uuid.clone());
// Add the current Sender to the user list
let handler_insert = self.out.clone();
let handler_update = self.out.clone();
self.users.map.upsert(user_uuid, || vec![handler_insert], |ref mut v| v.push(handler_update));
// Schedule a ping to keep the connection alive
self.out.timeout(PING_MS, PING)
}
fn on_message(&mut self, msg: Message) -> ws::Result<()> {
if let Message::Text(text) = msg.clone() {
let json = &text[..text.len() - 1]; // Remove last char
if let Ok(InitialMessage {
protocol,
version,
}) = from_str::<InitialMessage>(json)
{
if &protocol == "messagepack" && version == 1 {
return self.out.send(&INITIAL_RESPONSE[..]); // Respond to initial message
}
}
}
// If it's not the initial message, just echo the message
self.out.send(msg)
}
fn on_timeout(&mut self, event: Token) -> ws::Result<()> {
if event == PING {
// send ping
self.out.send(create_ping())?;
// reschedule the timeout
self.out.timeout(PING_MS, PING)
} else {
Ok(())
}
}
}
struct WsFactory {
pub users: WebSocketUsers,
}
impl WsFactory {
pub fn init() -> Self {
WsFactory {
users: WebSocketUsers {
map: Arc::new(CHashMap::new()),
},
}
}
}
impl Factory for WsFactory {
type Handler = WsHandler;
fn connection_made(&mut self, out: Sender) -> Self::Handler {
WsHandler {
out,
user_uuid: None,
users: self.users.clone(),
}
}
fn connection_lost(&mut self, handler: Self::Handler) {
// Remove handler
if let Some(user_uuid) = &handler.user_uuid {
if let Some(mut user_conn) = self.users.map.get_mut(user_uuid) {
if let Some(pos) = user_conn.iter().position(|x| x == &handler.out) {
user_conn.remove(pos);
}
}
}
}
}
// We attach the UUID to the sender so we can differentiate them when we need to remove them from the Vec
type UserSenders = (uuid::Uuid, Sender<Message>);
#[derive(Clone)] #[derive(Clone)]
pub struct WebSocketUsers { pub struct WebSocketUsers {
map: Arc<CHashMap<String, Vec<Sender>>>, map: Arc<dashmap::DashMap<String, Vec<UserSenders>>>,
} }
impl WebSocketUsers { impl WebSocketUsers {
fn send_update(&self, user_uuid: &str, data: &[u8]) -> ws::Result<()> { async fn send_update(&self, user_uuid: &str, data: &[u8]) {
if let Some(user) = self.map.get(user_uuid) { if let Some(user) = self.map.get(user_uuid).map(|v| v.clone()) {
for sender in user.iter() { for (_, sender) in user.iter() {
sender.send(data)?; if sender.send(Message::binary(data)).await.is_err() {
// TODO: Delete from map here too?
}
} }
} }
Ok(())
} }
// NOTE: The last modified date needs to be updated before calling these methods // NOTE: The last modified date needs to be updated before calling these methods
pub fn send_user_update(&self, ut: UpdateType, user: &User) { pub async fn send_user_update(&self, ut: UpdateType, user: &User) {
let data = create_update( let data = create_update(
vec![("UserId".into(), user.uuid.clone().into()), ("Date".into(), serialize_date(user.updated_at))], vec![("UserId".into(), user.uuid.clone().into()), ("Date".into(), serialize_date(user.updated_at))],
ut, ut,
); );
self.send_update(&user.uuid, &data).ok(); self.send_update(&user.uuid, &data).await;
} }
pub fn send_folder_update(&self, ut: UpdateType, folder: &Folder) { pub async fn send_folder_update(&self, ut: UpdateType, folder: &Folder) {
let data = create_update( let data = create_update(
vec![ vec![
("Id".into(), folder.uuid.clone().into()), ("Id".into(), folder.uuid.clone().into()),
@ -313,10 +179,10 @@ impl WebSocketUsers {
ut, ut,
); );
self.send_update(&folder.user_uuid, &data).ok(); self.send_update(&folder.user_uuid, &data).await;
} }
pub fn send_cipher_update(&self, ut: UpdateType, cipher: &Cipher, user_uuids: &[String]) { pub async fn send_cipher_update(&self, ut: UpdateType, cipher: &Cipher, user_uuids: &[String]) {
let user_uuid = convert_option(cipher.user_uuid.clone()); let user_uuid = convert_option(cipher.user_uuid.clone());
let org_uuid = convert_option(cipher.organization_uuid.clone()); let org_uuid = convert_option(cipher.organization_uuid.clone());
@ -332,11 +198,11 @@ impl WebSocketUsers {
); );
for uuid in user_uuids { for uuid in user_uuids {
self.send_update(uuid, &data).ok(); self.send_update(uuid, &data).await;
} }
} }
pub fn send_send_update(&self, ut: UpdateType, send: &Send, user_uuids: &[String]) { pub async fn send_send_update(&self, ut: UpdateType, send: &Send, user_uuids: &[String]) {
let user_uuid = convert_option(send.user_uuid.clone()); let user_uuid = convert_option(send.user_uuid.clone());
let data = create_update( let data = create_update(
@ -349,7 +215,7 @@ impl WebSocketUsers {
); );
for uuid in user_uuids { for uuid in user_uuids {
self.send_update(uuid, &data).ok(); self.send_update(uuid, &data).await;
} }
} }
} }
@ -392,7 +258,7 @@ fn create_ping() -> Vec<u8> {
} }
#[allow(dead_code)] #[allow(dead_code)]
#[derive(PartialEq)] #[derive(Eq, PartialEq)]
pub enum UpdateType { pub enum UpdateType {
CipherUpdate = 0, CipherUpdate = 0,
CipherCreate = 1, CipherCreate = 1,
@ -416,28 +282,145 @@ pub enum UpdateType {
None = 100, None = 100,
} }
use rocket::State; pub type Notify<'a> = &'a rocket::State<WebSocketUsers>;
pub type Notify<'a> = State<'a, WebSocketUsers>;
pub fn start_notification_server() -> WebSocketUsers { pub fn start_notification_server() -> WebSocketUsers {
let factory = WsFactory::init(); let users = WebSocketUsers {
let users = factory.users.clone(); map: Arc::new(dashmap::DashMap::new()),
};
if CONFIG.websocket_enabled() { if CONFIG.websocket_enabled() {
thread::spawn(move || { let users2 = users.clone();
let mut settings = ws::Settings::default(); tokio::spawn(async move {
settings.max_connections = 500; let addr = (CONFIG.websocket_address(), CONFIG.websocket_port());
settings.queue_size = 2; info!("Starting WebSockets server on {}:{}", addr.0, addr.1);
settings.panic_on_internal = false; let listener = TcpListener::bind(addr).await.expect("Can't listen on websocket port");
ws::Builder::new() let (shutdown_tx, mut shutdown_rx) = tokio::sync::oneshot::channel::<()>();
.with_settings(settings) CONFIG.set_ws_shutdown_handle(shutdown_tx);
.build(factory)
.unwrap() loop {
.listen((CONFIG.websocket_address().as_str(), CONFIG.websocket_port())) tokio::select! {
.unwrap(); Ok((stream, addr)) = listener.accept() => {
tokio::spawn(handle_connection(stream, users2.clone(), addr));
}
_ = &mut shutdown_rx => {
break;
}
}
}
info!("Shutting down WebSockets server!")
}); });
} }
users users
} }
async fn handle_connection(stream: TcpStream, users: WebSocketUsers, addr: SocketAddr) -> Result<(), Error> {
let mut user_uuid: Option<String> = None;
info!("Accepting WS connection from {addr}");
// Accept connection, do initial handshake, validate auth token and get the user ID
use handshake::server::{Request, Response};
let mut stream = accept_hdr_async(stream, |req: &Request, res: Response| {
if let Some(token) = get_request_token(req) {
if let Ok(claims) = crate::auth::decode_login(&token) {
user_uuid = Some(claims.sub);
return Ok(res);
}
}
Err(Response::builder().status(401).body(None).unwrap())
})
.await?;
let user_uuid = user_uuid.expect("User UUID should be set after the handshake");
// Add a channel to send messages to this client to the map
let entry_uuid = uuid::Uuid::new_v4();
let (tx, mut rx) = tokio::sync::mpsc::channel(100);
users.map.entry(user_uuid.clone()).or_default().push((entry_uuid, tx));
let mut interval = tokio::time::interval(Duration::from_secs(15));
loop {
tokio::select! {
res = stream.next() => {
match res {
Some(Ok(message)) => {
// Respond to any pings
if let Message::Ping(ping) = message {
if stream.send(Message::Pong(ping)).await.is_err() {
break;
}
continue;
} else if let Message::Pong(_) = message {
/* Ignored */
continue;
}
// We should receive an initial message with the protocol and version, and we will reply to it
if let Message::Text(ref message) = message {
let msg = message.strip_suffix(RECORD_SEPARATOR as char).unwrap_or(message);
if serde_json::from_str(msg).ok() == Some(INITIAL_MESSAGE) {
stream.send(Message::binary(INITIAL_RESPONSE)).await?;
continue;
}
}
// Just echo anything else the client sends
if stream.send(message).await.is_err() {
break;
}
}
_ => break,
}
}
res = rx.recv() => {
match res {
Some(res) => {
if stream.send(res).await.is_err() {
break;
}
},
None => break,
}
}
_= interval.tick() => {
if stream.send(Message::Ping(create_ping())).await.is_err() {
break;
}
}
}
}
info!("Closing WS connection from {addr}");
// Delete from map
users.map.entry(user_uuid).or_default().retain(|(uuid, _)| uuid != &entry_uuid);
Ok(())
}
fn get_request_token(req: &handshake::server::Request) -> Option<String> {
const ACCESS_TOKEN_KEY: &str = "access_token=";
if let Some(Ok(auth)) = req.headers().get("Authorization").map(|a| a.to_str()) {
if let Some(token_part) = auth.strip_prefix("Bearer ") {
return Some(token_part.to_owned());
}
}
if let Some(params) = req.uri().query() {
let params_iter = params.split('&').take(1);
for val in params_iter {
if let Some(stripped) = val.strip_prefix(ACCESS_TOKEN_KEY) {
return Some(stripped.to_owned());
}
}
}
None
}

55
src/api/web.rs

@ -1,10 +1,11 @@
use std::path::{Path, PathBuf}; use std::path::{Path, PathBuf};
use rocket::{http::ContentType, response::content::Content, response::NamedFile, Route}; use rocket::serde::json::Json;
use rocket_contrib::json::Json; use rocket::{fs::NamedFile, http::ContentType, Route};
use serde_json::Value; use serde_json::Value;
use crate::{ use crate::{
api::core::now,
error::Error, error::Error,
util::{Cached, SafeString}, util::{Cached, SafeString},
CONFIG, CONFIG,
@ -21,16 +22,16 @@ pub fn routes() -> Vec<Route> {
} }
#[get("/")] #[get("/")]
fn web_index() -> Cached<Option<NamedFile>> { async fn web_index() -> Cached<Option<NamedFile>> {
Cached::short(NamedFile::open(Path::new(&CONFIG.web_vault_folder()).join("index.html")).ok(), false) Cached::short(NamedFile::open(Path::new(&CONFIG.web_vault_folder()).join("index.html")).await.ok(), false)
} }
#[get("/app-id.json")] #[get("/app-id.json")]
fn app_id() -> Cached<Content<Json<Value>>> { fn app_id() -> Cached<(ContentType, Json<Value>)> {
let content_type = ContentType::new("application", "fido.trusted-apps+json"); let content_type = ContentType::new("application", "fido.trusted-apps+json");
Cached::long( Cached::long(
Content( (
content_type, content_type,
Json(json!({ Json(json!({
"trustedFacets": [ "trustedFacets": [
@ -58,45 +59,37 @@ fn app_id() -> Cached<Content<Json<Value>>> {
} }
#[get("/<p..>", rank = 10)] // Only match this if the other routes don't match #[get("/<p..>", rank = 10)] // Only match this if the other routes don't match
fn web_files(p: PathBuf) -> Cached<Option<NamedFile>> { async fn web_files(p: PathBuf) -> Cached<Option<NamedFile>> {
Cached::long(NamedFile::open(Path::new(&CONFIG.web_vault_folder()).join(p)).ok(), true) Cached::long(NamedFile::open(Path::new(&CONFIG.web_vault_folder()).join(p)).await.ok(), true)
} }
#[get("/attachments/<uuid>/<file_id>")] #[get("/attachments/<uuid>/<file_id>")]
fn attachments(uuid: SafeString, file_id: SafeString) -> Option<NamedFile> { async fn attachments(uuid: SafeString, file_id: SafeString) -> Option<NamedFile> {
NamedFile::open(Path::new(&CONFIG.attachments_folder()).join(uuid).join(file_id)).ok() NamedFile::open(Path::new(&CONFIG.attachments_folder()).join(uuid).join(file_id)).await.ok()
} }
// We use DbConn here to let the alive healthcheck also verify the database connection. // We use DbConn here to let the alive healthcheck also verify the database connection.
use crate::db::DbConn; use crate::db::DbConn;
#[get("/alive")] #[get("/alive")]
fn alive(_conn: DbConn) -> Json<String> { fn alive(_conn: DbConn) -> Json<String> {
use crate::util::format_date; now()
use chrono::Utc;
Json(format_date(&Utc::now().naive_utc()))
} }
#[get("/vw_static/<filename>")] #[get("/vw_static/<filename>")]
fn static_files(filename: String) -> Result<Content<&'static [u8]>, Error> { fn static_files(filename: String) -> Result<(ContentType, &'static [u8]), Error> {
match filename.as_ref() { match filename.as_ref() {
"mail-github.png" => Ok(Content(ContentType::PNG, include_bytes!("../static/images/mail-github.png"))), "mail-github.png" => Ok((ContentType::PNG, include_bytes!("../static/images/mail-github.png"))),
"logo-gray.png" => Ok(Content(ContentType::PNG, include_bytes!("../static/images/logo-gray.png"))), "logo-gray.png" => Ok((ContentType::PNG, include_bytes!("../static/images/logo-gray.png"))),
"error-x.svg" => Ok(Content(ContentType::SVG, include_bytes!("../static/images/error-x.svg"))), "error-x.svg" => Ok((ContentType::SVG, include_bytes!("../static/images/error-x.svg"))),
"hibp.png" => Ok(Content(ContentType::PNG, include_bytes!("../static/images/hibp.png"))), "hibp.png" => Ok((ContentType::PNG, include_bytes!("../static/images/hibp.png"))),
"vaultwarden-icon.png" => { "vaultwarden-icon.png" => Ok((ContentType::PNG, include_bytes!("../static/images/vaultwarden-icon.png"))),
Ok(Content(ContentType::PNG, include_bytes!("../static/images/vaultwarden-icon.png"))) "bootstrap.css" => Ok((ContentType::CSS, include_bytes!("../static/scripts/bootstrap.css"))),
} "bootstrap-native.js" => Ok((ContentType::JavaScript, include_bytes!("../static/scripts/bootstrap-native.js"))),
"identicon.js" => Ok((ContentType::JavaScript, include_bytes!("../static/scripts/identicon.js"))),
"bootstrap.css" => Ok(Content(ContentType::CSS, include_bytes!("../static/scripts/bootstrap.css"))), "datatables.js" => Ok((ContentType::JavaScript, include_bytes!("../static/scripts/datatables.js"))),
"bootstrap-native.js" => { "datatables.css" => Ok((ContentType::CSS, include_bytes!("../static/scripts/datatables.css"))),
Ok(Content(ContentType::JavaScript, include_bytes!("../static/scripts/bootstrap-native.js")))
}
"identicon.js" => Ok(Content(ContentType::JavaScript, include_bytes!("../static/scripts/identicon.js"))),
"datatables.js" => Ok(Content(ContentType::JavaScript, include_bytes!("../static/scripts/datatables.js"))),
"datatables.css" => Ok(Content(ContentType::CSS, include_bytes!("../static/scripts/datatables.css"))),
"jquery-3.6.0.slim.js" => { "jquery-3.6.0.slim.js" => {
Ok(Content(ContentType::JavaScript, include_bytes!("../static/scripts/jquery-3.6.0.slim.js"))) Ok((ContentType::JavaScript, include_bytes!("../static/scripts/jquery-3.6.0.slim.js")))
} }
_ => err!(format!("Static file not found: {}", filename)), _ => err!(format!("Static file not found: {}", filename)),
} }

297
src/auth.rs

@ -11,7 +11,6 @@ use serde::ser::Serialize;
use crate::{ use crate::{
error::{Error, MapResult}, error::{Error, MapResult},
util::read_file,
CONFIG, CONFIG,
}; };
@ -30,13 +29,13 @@ static JWT_ADMIN_ISSUER: Lazy<String> = Lazy::new(|| format!("{}|admin", CONFIG.
static JWT_SEND_ISSUER: Lazy<String> = Lazy::new(|| format!("{}|send", CONFIG.domain_origin())); static JWT_SEND_ISSUER: Lazy<String> = Lazy::new(|| format!("{}|send", CONFIG.domain_origin()));
static PRIVATE_RSA_KEY_VEC: Lazy<Vec<u8>> = Lazy::new(|| { static PRIVATE_RSA_KEY_VEC: Lazy<Vec<u8>> = Lazy::new(|| {
read_file(&CONFIG.private_rsa_key()).unwrap_or_else(|e| panic!("Error loading private RSA Key.\n{}", e)) std::fs::read(&CONFIG.private_rsa_key()).unwrap_or_else(|e| panic!("Error loading private RSA Key.\n{}", e))
}); });
static PRIVATE_RSA_KEY: Lazy<EncodingKey> = Lazy::new(|| { static PRIVATE_RSA_KEY: Lazy<EncodingKey> = Lazy::new(|| {
EncodingKey::from_rsa_pem(&PRIVATE_RSA_KEY_VEC).unwrap_or_else(|e| panic!("Error decoding private RSA Key.\n{}", e)) EncodingKey::from_rsa_pem(&PRIVATE_RSA_KEY_VEC).unwrap_or_else(|e| panic!("Error decoding private RSA Key.\n{}", e))
}); });
static PUBLIC_RSA_KEY_VEC: Lazy<Vec<u8>> = Lazy::new(|| { static PUBLIC_RSA_KEY_VEC: Lazy<Vec<u8>> = Lazy::new(|| {
read_file(&CONFIG.public_rsa_key()).unwrap_or_else(|e| panic!("Error loading public RSA Key.\n{}", e)) std::fs::read(&CONFIG.public_rsa_key()).unwrap_or_else(|e| panic!("Error loading public RSA Key.\n{}", e))
}); });
static PUBLIC_RSA_KEY: Lazy<DecodingKey> = Lazy::new(|| { static PUBLIC_RSA_KEY: Lazy<DecodingKey> = Lazy::new(|| {
DecodingKey::from_rsa_pem(&PUBLIC_RSA_KEY_VEC).unwrap_or_else(|e| panic!("Error decoding public RSA Key.\n{}", e)) DecodingKey::from_rsa_pem(&PUBLIC_RSA_KEY_VEC).unwrap_or_else(|e| panic!("Error decoding public RSA Key.\n{}", e))
@ -55,15 +54,11 @@ pub fn encode_jwt<T: Serialize>(claims: &T) -> String {
} }
fn decode_jwt<T: DeserializeOwned>(token: &str, issuer: String) -> Result<T, Error> { fn decode_jwt<T: DeserializeOwned>(token: &str, issuer: String) -> Result<T, Error> {
let validation = jsonwebtoken::Validation { let mut validation = jsonwebtoken::Validation::new(JWT_ALGORITHM);
leeway: 30, // 30 seconds validation.leeway = 30; // 30 seconds
validate_exp: true, validation.validate_exp = true;
validate_nbf: true, validation.validate_nbf = true;
aud: None, validation.set_issuer(&[issuer]);
iss: Some(issuer),
sub: None,
algorithms: vec![JWT_ALGORITHM],
};
let token = token.replace(char::is_whitespace, ""); let token = token.replace(char::is_whitespace, "");
jsonwebtoken::decode(&token, &PUBLIC_RSA_KEY, &validation).map(|d| d.claims).map_res("Error decoding JWT") jsonwebtoken::decode(&token, &PUBLIC_RSA_KEY, &validation).map(|d| d.claims).map_res("Error decoding JWT")
@ -257,7 +252,10 @@ pub fn generate_send_claims(send_id: &str, file_id: &str) -> BasicJwtClaims {
// //
// Bearer token authentication // Bearer token authentication
// //
use rocket::request::{FromRequest, Outcome, Request}; use rocket::{
outcome::try_outcome,
request::{FromRequest, Outcome, Request},
};
use crate::db::{ use crate::db::{
models::{CollectionUser, Device, User, UserOrgStatus, UserOrgType, UserOrganization, UserStampException}, models::{CollectionUser, Device, User, UserOrgStatus, UserOrgType, UserOrganization, UserStampException},
@ -268,10 +266,11 @@ pub struct Host {
pub host: String, pub host: String,
} }
impl<'a, 'r> FromRequest<'a, 'r> for Host { #[rocket::async_trait]
impl<'r> FromRequest<'r> for Host {
type Error = &'static str; type Error = &'static str;
fn from_request(request: &'a Request<'r>) -> Outcome<Self, Self::Error> { async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
let headers = request.headers(); let headers = request.headers();
// Get host // Get host
@ -314,17 +313,14 @@ pub struct Headers {
pub user: User, pub user: User,
} }
impl<'a, 'r> FromRequest<'a, 'r> for Headers { #[rocket::async_trait]
impl<'r> FromRequest<'r> for Headers {
type Error = &'static str; type Error = &'static str;
fn from_request(request: &'a Request<'r>) -> Outcome<Self, Self::Error> { async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
let headers = request.headers(); let headers = request.headers();
let host = match Host::from_request(request) { let host = try_outcome!(Host::from_request(request).await).host;
Outcome::Forward(_) => return Outcome::Forward(()),
Outcome::Failure(f) => return Outcome::Failure(f),
Outcome::Success(host) => host.host,
};
// Get access_token // Get access_token
let access_token: &str = match headers.get_one("Authorization") { let access_token: &str = match headers.get_one("Authorization") {
@ -344,17 +340,17 @@ impl<'a, 'r> FromRequest<'a, 'r> for Headers {
let device_uuid = claims.device; let device_uuid = claims.device;
let user_uuid = claims.sub; let user_uuid = claims.sub;
let conn = match request.guard::<DbConn>() { let conn = match DbConn::from_request(request).await {
Outcome::Success(conn) => conn, Outcome::Success(conn) => conn,
_ => err_handler!("Error getting DB"), _ => err_handler!("Error getting DB"),
}; };
let device = match Device::find_by_uuid(&device_uuid, &conn) { let device = match Device::find_by_uuid_and_user(&device_uuid, &user_uuid, &conn).await {
Some(device) => device, Some(device) => device,
None => err_handler!("Invalid device id"), None => err_handler!("Invalid device id"),
}; };
let user = match User::find_by_uuid(&user_uuid, &conn) { let user = match User::find_by_uuid(&user_uuid, &conn).await {
Some(user) => user, Some(user) => user,
None => err_handler!("Device has no user associated"), None => err_handler!("Device has no user associated"),
}; };
@ -363,7 +359,7 @@ impl<'a, 'r> FromRequest<'a, 'r> for Headers {
if let Some(stamp_exception) = if let Some(stamp_exception) =
user.stamp_exception.as_deref().and_then(|s| serde_json::from_str::<UserStampException>(s).ok()) user.stamp_exception.as_deref().and_then(|s| serde_json::from_str::<UserStampException>(s).ok())
{ {
let current_route = match request.route().and_then(|r| r.name) { let current_route = match request.route().and_then(|r| r.name.as_deref()) {
Some(name) => name, Some(name) => name,
_ => err_handler!("Error getting current route for stamp exception"), _ => err_handler!("Error getting current route for stamp exception"),
}; };
@ -376,7 +372,7 @@ impl<'a, 'r> FromRequest<'a, 'r> for Headers {
// This prevents checking this stamp exception for new requests. // This prevents checking this stamp exception for new requests.
let mut user = user; let mut user = user;
user.reset_stamp_exception(); user.reset_stamp_exception();
if let Err(e) = user.save(&conn) { if let Err(e) = user.save(&conn).await {
error!("Error updating user: {:#?}", e); error!("Error updating user: {:#?}", e);
} }
err_handler!("Stamp exception is expired") err_handler!("Stamp exception is expired")
@ -410,14 +406,14 @@ pub struct OrgHeaders {
// org_id is usually the second path param ("/organizations/<org_id>"), // org_id is usually the second path param ("/organizations/<org_id>"),
// but there are cases where it is a query value. // but there are cases where it is a query value.
// First check the path, if this is not a valid uuid, try the query values. // First check the path, if this is not a valid uuid, try the query values.
fn get_org_id(request: &Request) -> Option<String> { fn get_org_id(request: &Request<'_>) -> Option<String> {
if let Some(Ok(org_id)) = request.get_param::<String>(1) { if let Some(Ok(org_id)) = request.param::<String>(1) {
if uuid::Uuid::parse_str(&org_id).is_ok() { if uuid::Uuid::parse_str(&org_id).is_ok() {
return Some(org_id); return Some(org_id);
} }
} }
if let Some(Ok(org_id)) = request.get_query_value::<String>("organizationId") { if let Some(Ok(org_id)) = request.query_value::<String>("organizationId") {
if uuid::Uuid::parse_str(&org_id).is_ok() { if uuid::Uuid::parse_str(&org_id).is_ok() {
return Some(org_id); return Some(org_id);
} }
@ -426,52 +422,48 @@ fn get_org_id(request: &Request) -> Option<String> {
None None
} }
impl<'a, 'r> FromRequest<'a, 'r> for OrgHeaders { #[rocket::async_trait]
impl<'r> FromRequest<'r> for OrgHeaders {
type Error = &'static str; type Error = &'static str;
fn from_request(request: &'a Request<'r>) -> Outcome<Self, Self::Error> { async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
match request.guard::<Headers>() { let headers = try_outcome!(Headers::from_request(request).await);
Outcome::Forward(_) => Outcome::Forward(()), match get_org_id(request) {
Outcome::Failure(f) => Outcome::Failure(f), Some(org_id) => {
Outcome::Success(headers) => { let conn = match DbConn::from_request(request).await {
match get_org_id(request) { Outcome::Success(conn) => conn,
Some(org_id) => { _ => err_handler!("Error getting DB"),
let conn = match request.guard::<DbConn>() { };
Outcome::Success(conn) => conn,
_ => err_handler!("Error getting DB"), let user = headers.user;
}; let org_user = match UserOrganization::find_by_user_and_org(&user.uuid, &org_id, &conn).await {
Some(user) => {
let user = headers.user; if user.status == UserOrgStatus::Confirmed as i32 {
let org_user = match UserOrganization::find_by_user_and_org(&user.uuid, &org_id, &conn) { user
Some(user) => { } else {
if user.status == UserOrgStatus::Confirmed as i32 { err_handler!("The current user isn't confirmed member of the organization")
user }
} else {
err_handler!("The current user isn't confirmed member of the organization")
}
}
None => err_handler!("The current user isn't member of the organization"),
};
Outcome::Success(Self {
host: headers.host,
device: headers.device,
user,
org_user_type: {
if let Some(org_usr_type) = UserOrgType::from_i32(org_user.atype) {
org_usr_type
} else {
// This should only happen if the DB is corrupted
err_handler!("Unknown user type in the database")
}
},
org_user,
org_id,
})
} }
_ => err_handler!("Error getting the organization id"), None => err_handler!("The current user isn't member of the organization"),
} };
Outcome::Success(Self {
host: headers.host,
device: headers.device,
user,
org_user_type: {
if let Some(org_usr_type) = UserOrgType::from_i32(org_user.atype) {
org_usr_type
} else {
// This should only happen if the DB is corrupted
err_handler!("Unknown user type in the database")
}
},
org_user,
org_id,
})
} }
_ => err_handler!("Error getting the organization id"),
} }
} }
} }
@ -483,25 +475,21 @@ pub struct AdminHeaders {
pub org_user_type: UserOrgType, pub org_user_type: UserOrgType,
} }
impl<'a, 'r> FromRequest<'a, 'r> for AdminHeaders { #[rocket::async_trait]
impl<'r> FromRequest<'r> for AdminHeaders {
type Error = &'static str; type Error = &'static str;
fn from_request(request: &'a Request<'r>) -> Outcome<Self, Self::Error> { async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
match request.guard::<OrgHeaders>() { let headers = try_outcome!(OrgHeaders::from_request(request).await);
Outcome::Forward(_) => Outcome::Forward(()), if headers.org_user_type >= UserOrgType::Admin {
Outcome::Failure(f) => Outcome::Failure(f), Outcome::Success(Self {
Outcome::Success(headers) => { host: headers.host,
if headers.org_user_type >= UserOrgType::Admin { device: headers.device,
Outcome::Success(Self { user: headers.user,
host: headers.host, org_user_type: headers.org_user_type,
device: headers.device, })
user: headers.user, } else {
org_user_type: headers.org_user_type, err_handler!("You need to be Admin or Owner to call this endpoint")
})
} else {
err_handler!("You need to be Admin or Owner to call this endpoint")
}
}
} }
} }
} }
@ -519,14 +507,14 @@ impl From<AdminHeaders> for Headers {
// col_id is usually the fourth path param ("/organizations/<org_id>/collections/<col_id>"), // col_id is usually the fourth path param ("/organizations/<org_id>/collections/<col_id>"),
// but there could be cases where it is a query value. // but there could be cases where it is a query value.
// First check the path, if this is not a valid uuid, try the query values. // First check the path, if this is not a valid uuid, try the query values.
fn get_col_id(request: &Request) -> Option<String> { fn get_col_id(request: &Request<'_>) -> Option<String> {
if let Some(Ok(col_id)) = request.get_param::<String>(3) { if let Some(Ok(col_id)) = request.param::<String>(3) {
if uuid::Uuid::parse_str(&col_id).is_ok() { if uuid::Uuid::parse_str(&col_id).is_ok() {
return Some(col_id); return Some(col_id);
} }
} }
if let Some(Ok(col_id)) = request.get_query_value::<String>("collectionId") { if let Some(Ok(col_id)) = request.query_value::<String>("collectionId") {
if uuid::Uuid::parse_str(&col_id).is_ok() { if uuid::Uuid::parse_str(&col_id).is_ok() {
return Some(col_id); return Some(col_id);
} }
@ -545,46 +533,40 @@ pub struct ManagerHeaders {
pub org_user_type: UserOrgType, pub org_user_type: UserOrgType,
} }
impl<'a, 'r> FromRequest<'a, 'r> for ManagerHeaders { #[rocket::async_trait]
impl<'r> FromRequest<'r> for ManagerHeaders {
type Error = &'static str; type Error = &'static str;
fn from_request(request: &'a Request<'r>) -> Outcome<Self, Self::Error> { async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
match request.guard::<OrgHeaders>() { let headers = try_outcome!(OrgHeaders::from_request(request).await);
Outcome::Forward(_) => Outcome::Forward(()), if headers.org_user_type >= UserOrgType::Manager {
Outcome::Failure(f) => Outcome::Failure(f), match get_col_id(request) {
Outcome::Success(headers) => { Some(col_id) => {
if headers.org_user_type >= UserOrgType::Manager { let conn = match DbConn::from_request(request).await {
match get_col_id(request) { Outcome::Success(conn) => conn,
Some(col_id) => { _ => err_handler!("Error getting DB"),
let conn = match request.guard::<DbConn>() { };
Outcome::Success(conn) => conn,
_ => err_handler!("Error getting DB"), if !headers.org_user.has_full_access() {
}; match CollectionUser::find_by_collection_and_user(&col_id, &headers.org_user.user_uuid, &conn)
.await
if !headers.org_user.has_full_access() { {
match CollectionUser::find_by_collection_and_user( Some(_) => (),
&col_id, None => err_handler!("The current user isn't a manager for this collection"),
&headers.org_user.user_uuid,
&conn,
) {
Some(_) => (),
None => err_handler!("The current user isn't a manager for this collection"),
}
}
} }
_ => err_handler!("Error getting the collection id"),
} }
Outcome::Success(Self {
host: headers.host,
device: headers.device,
user: headers.user,
org_user_type: headers.org_user_type,
})
} else {
err_handler!("You need to be a Manager, Admin or Owner to call this endpoint")
} }
_ => err_handler!("Error getting the collection id"),
} }
Outcome::Success(Self {
host: headers.host,
device: headers.device,
user: headers.user,
org_user_type: headers.org_user_type,
})
} else {
err_handler!("You need to be a Manager, Admin or Owner to call this endpoint")
} }
} }
} }
@ -608,25 +590,21 @@ pub struct ManagerHeadersLoose {
pub org_user_type: UserOrgType, pub org_user_type: UserOrgType,
} }
impl<'a, 'r> FromRequest<'a, 'r> for ManagerHeadersLoose { #[rocket::async_trait]
impl<'r> FromRequest<'r> for ManagerHeadersLoose {
type Error = &'static str; type Error = &'static str;
fn from_request(request: &'a Request<'r>) -> Outcome<Self, Self::Error> { async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
match request.guard::<OrgHeaders>() { let headers = try_outcome!(OrgHeaders::from_request(request).await);
Outcome::Forward(_) => Outcome::Forward(()), if headers.org_user_type >= UserOrgType::Manager {
Outcome::Failure(f) => Outcome::Failure(f), Outcome::Success(Self {
Outcome::Success(headers) => { host: headers.host,
if headers.org_user_type >= UserOrgType::Manager { device: headers.device,
Outcome::Success(Self { user: headers.user,
host: headers.host, org_user_type: headers.org_user_type,
device: headers.device, })
user: headers.user, } else {
org_user_type: headers.org_user_type, err_handler!("You need to be a Manager, Admin or Owner to call this endpoint")
})
} else {
err_handler!("You need to be a Manager, Admin or Owner to call this endpoint")
}
}
} }
} }
} }
@ -647,24 +625,20 @@ pub struct OwnerHeaders {
pub user: User, pub user: User,
} }
impl<'a, 'r> FromRequest<'a, 'r> for OwnerHeaders { #[rocket::async_trait]
impl<'r> FromRequest<'r> for OwnerHeaders {
type Error = &'static str; type Error = &'static str;
fn from_request(request: &'a Request<'r>) -> Outcome<Self, Self::Error> { async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
match request.guard::<OrgHeaders>() { let headers = try_outcome!(OrgHeaders::from_request(request).await);
Outcome::Forward(_) => Outcome::Forward(()), if headers.org_user_type == UserOrgType::Owner {
Outcome::Failure(f) => Outcome::Failure(f), Outcome::Success(Self {
Outcome::Success(headers) => { host: headers.host,
if headers.org_user_type == UserOrgType::Owner { device: headers.device,
Outcome::Success(Self { user: headers.user,
host: headers.host, })
device: headers.device, } else {
user: headers.user, err_handler!("You need to be Owner to call this endpoint")
})
} else {
err_handler!("You need to be Owner to call this endpoint")
}
}
} }
} }
} }
@ -678,10 +652,11 @@ pub struct ClientIp {
pub ip: IpAddr, pub ip: IpAddr,
} }
impl<'a, 'r> FromRequest<'a, 'r> for ClientIp { #[rocket::async_trait]
impl<'r> FromRequest<'r> for ClientIp {
type Error = (); type Error = ();
fn from_request(req: &'a Request<'r>) -> Outcome<Self, Self::Error> { async fn from_request(req: &'r Request<'_>) -> Outcome<Self, Self::Error> {
let ip = if CONFIG._ip_header_enabled() { let ip = if CONFIG._ip_header_enabled() {
req.headers().get_one(&CONFIG.ip_header()).and_then(|ip| { req.headers().get_one(&CONFIG.ip_header()).and_then(|ip| {
match ip.find(',') { match ip.find(',') {

126
src/config.rs

@ -36,6 +36,9 @@ macro_rules! make_config {
pub struct Config { inner: RwLock<Inner> } pub struct Config { inner: RwLock<Inner> }
struct Inner { struct Inner {
rocket_shutdown_handle: Option<rocket::Shutdown>,
ws_shutdown_handle: Option<tokio::sync::oneshot::Sender<()>>,
templates: Handlebars<'static>, templates: Handlebars<'static>,
config: ConfigItems, config: ConfigItems,
@ -56,13 +59,13 @@ macro_rules! make_config {
impl ConfigBuilder { impl ConfigBuilder {
#[allow(clippy::field_reassign_with_default)] #[allow(clippy::field_reassign_with_default)]
fn from_env() -> Self { fn from_env() -> Self {
match dotenv::from_path(".env") { match dotenvy::from_path(get_env("ENV_FILE").unwrap_or_else(|| String::from(".env"))) {
Ok(_) => (), Ok(_) => (),
Err(e) => match e { Err(e) => match e {
dotenv::Error::LineParse(msg, pos) => { dotenvy::Error::LineParse(msg, pos) => {
panic!("Error loading the .env file:\nNear {:?} on position {}\nPlease fix and restart!\n", msg, pos); panic!("Error loading the .env file:\nNear {:?} on position {}\nPlease fix and restart!\n", msg, pos);
}, },
dotenv::Error::Io(ioerr) => match ioerr.kind() { dotenvy::Error::Io(ioerr) => match ioerr.kind() {
std::io::ErrorKind::NotFound => { std::io::ErrorKind::NotFound => {
println!("[INFO] No .env file found.\n"); println!("[INFO] No .env file found.\n");
}, },
@ -88,8 +91,7 @@ macro_rules! make_config {
} }
fn from_file(path: &str) -> Result<Self, Error> { fn from_file(path: &str) -> Result<Self, Error> {
use crate::util::read_file_string; let config_str = std::fs::read_to_string(path)?;
let config_str = read_file_string(path)?;
serde_json::from_str(&config_str).map_err(Into::into) serde_json::from_str(&config_str).map_err(Into::into)
} }
@ -332,6 +334,8 @@ make_config! {
attachments_folder: String, false, auto, |c| format!("{}/{}", c.data_folder, "attachments"); attachments_folder: String, false, auto, |c| format!("{}/{}", c.data_folder, "attachments");
/// Sends folder /// Sends folder
sends_folder: String, false, auto, |c| format!("{}/{}", c.data_folder, "sends"); sends_folder: String, false, auto, |c| format!("{}/{}", c.data_folder, "sends");
/// Temp folder |> Used for storing temporary file uploads
tmp_folder: String, false, auto, |c| format!("{}/{}", c.data_folder, "tmp");
/// Templates folder /// Templates folder
templates_folder: String, false, auto, |c| format!("{}/{}", c.data_folder, "templates"); templates_folder: String, false, auto, |c| format!("{}/{}", c.data_folder, "templates");
/// Session JWT key /// Session JWT key
@ -431,6 +435,8 @@ make_config! {
/// Password iterations |> Number of server-side passwords hashing iterations. /// Password iterations |> Number of server-side passwords hashing iterations.
/// The changes only apply when a user changes their password. Not recommended to lower the value /// The changes only apply when a user changes their password. Not recommended to lower the value
password_iterations: i32, true, def, 100_000; password_iterations: i32, true, def, 100_000;
/// Allow password hints |> Controls whether users can set password hints. This setting applies globally to all users.
password_hints_allowed: bool, true, def, true;
/// Show password hint |> Controls whether a password hint should be shown directly in the web page /// Show password hint |> Controls whether a password hint should be shown directly in the web page
/// if SMTP service is not configured. Not recommended for publicly-accessible instances as this /// if SMTP service is not configured. Not recommended for publicly-accessible instances as this
/// provides unauthenticated access to potentially sensitive data. /// provides unauthenticated access to potentially sensitive data.
@ -457,6 +463,10 @@ make_config! {
/// service is set, an icon request to Vaultwarden will return an HTTP redirect to the /// service is set, an icon request to Vaultwarden will return an HTTP redirect to the
/// corresponding icon at the external service. /// corresponding icon at the external service.
icon_service: String, false, def, "internal".to_string(); icon_service: String, false, def, "internal".to_string();
/// Internal
_icon_service_url: String, false, gen, |c| generate_icon_service_url(&c.icon_service);
/// Internal
_icon_service_csp: String, false, gen, |c| generate_icon_service_csp(&c.icon_service, &c._icon_service_url);
/// Icon redirect code |> The HTTP status code to use for redirects to an external icon service. /// Icon redirect code |> The HTTP status code to use for redirects to an external icon service.
/// The supported codes are 301 (legacy permanent), 302 (legacy temporary), 307 (temporary), and 308 (permanent). /// The supported codes are 301 (legacy permanent), 302 (legacy temporary), 307 (temporary), and 308 (permanent).
/// Temporary redirects are useful while testing different icon services, but once a service /// Temporary redirects are useful while testing different icon services, but once a service
@ -509,9 +519,15 @@ make_config! {
/// Max database connection retries |> Number of times to retry the database connection during startup, with 1 second between each retry, set to 0 to retry indefinitely /// Max database connection retries |> Number of times to retry the database connection during startup, with 1 second between each retry, set to 0 to retry indefinitely
db_connection_retries: u32, false, def, 15; db_connection_retries: u32, false, def, 15;
/// Timeout when aquiring database connection
database_timeout: u64, false, def, 30;
/// Database connection pool size /// Database connection pool size
database_max_conns: u32, false, def, 10; database_max_conns: u32, false, def, 10;
/// Database connection init |> SQL statements to run when creating a new database connection, mainly useful for connection-scoped pragmas. If empty, a database-specific default is used.
database_conn_init: String, false, def, "".to_string();
/// Bypass admin page security (Know the risks!) |> Disables the Admin Token for the admin page so you may use your own auth in-front /// Bypass admin page security (Know the risks!) |> Disables the Admin Token for the admin page so you may use your own auth in-front
disable_admin_token: bool, true, def, false; disable_admin_token: bool, true, def, false;
@ -561,12 +577,14 @@ make_config! {
_enable_smtp: bool, true, def, true; _enable_smtp: bool, true, def, true;
/// Host /// Host
smtp_host: String, true, option; smtp_host: String, true, option;
/// Enable Secure SMTP |> (Explicit) - Enabling this by default would use STARTTLS (Standard ports 587 or 25) /// DEPRECATED smtp_ssl |> DEPRECATED - Please use SMTP_SECURITY
smtp_ssl: bool, true, def, true; smtp_ssl: bool, false, option;
/// Force TLS |> (Implicit) - Enabling this would force the use of an SSL/TLS connection, instead of upgrading an insecure one with STARTTLS (Standard port 465) /// DEPRECATED smtp_explicit_tls |> DEPRECATED - Please use SMTP_SECURITY
smtp_explicit_tls: bool, true, def, false; smtp_explicit_tls: bool, false, option;
/// Secure SMTP |> ("starttls", "force_tls", "off") Enable a secure connection. Default is "starttls" (Explicit - ports 587 or 25), "force_tls" (Implicit - port 465) or "off", no encryption
smtp_security: String, true, auto, |c| smtp_convert_deprecated_ssl_options(c.smtp_ssl, c.smtp_explicit_tls); // TODO: After deprecation make it `def, "starttls".to_string()`
/// Port /// Port
smtp_port: u16, true, auto, |c| if c.smtp_explicit_tls {465} else if c.smtp_ssl {587} else {25}; smtp_port: u16, true, auto, |c| if c.smtp_security == *"force_tls" {465} else if c.smtp_security == *"starttls" {587} else {25};
/// From Address /// From Address
smtp_from: String, true, def, String::new(); smtp_from: String, true, def, String::new();
/// From Name /// From Name
@ -593,8 +611,8 @@ make_config! {
email_2fa: _enable_email_2fa { email_2fa: _enable_email_2fa {
/// Enabled |> Disabling will prevent users from setting up new email 2FA and using existing email 2FA configured /// Enabled |> Disabling will prevent users from setting up new email 2FA and using existing email 2FA configured
_enable_email_2fa: bool, true, auto, |c| c._enable_smtp && c.smtp_host.is_some(); _enable_email_2fa: bool, true, auto, |c| c._enable_smtp && c.smtp_host.is_some();
/// Email token size |> Number of digits in an email token (min: 6, max: 19). Note that the Bitwarden clients are hardcoded to mention 6 digit codes regardless of this setting. /// Email token size |> Number of digits in an email 2FA token (min: 6, max: 255). Note that the Bitwarden clients are hardcoded to mention 6 digit codes regardless of this setting.
email_token_size: u32, true, def, 6; email_token_size: u8, true, def, 6;
/// Token expiration time |> Maximum time in seconds a token is valid. The time the user has to open email client and copy token. /// Token expiration time |> Maximum time in seconds a token is valid. The time the user has to open email client and copy token.
email_expiration_time: u64, true, def, 600; email_expiration_time: u64, true, def, 600;
/// Maximum attempts |> Maximum attempts before an email token is reset and a new email will need to be sent /// Maximum attempts |> Maximum attempts before an email token is reset and a new email will need to be sent
@ -649,6 +667,13 @@ fn validate_config(cfg: &ConfigItems) -> Result<(), Error> {
} }
if cfg._enable_smtp { if cfg._enable_smtp {
match cfg.smtp_security.as_str() {
"off" | "starttls" | "force_tls" => (),
_ => err!(
"`SMTP_SECURITY` is invalid. It needs to be one of the following options: starttls, force_tls or off"
),
}
if cfg.smtp_host.is_some() == cfg.smtp_from.is_empty() { if cfg.smtp_host.is_some() == cfg.smtp_from.is_empty() {
err!("Both `SMTP_HOST` and `SMTP_FROM` need to be set for email support") err!("Both `SMTP_HOST` and `SMTP_FROM` need to be set for email support")
} }
@ -668,10 +693,6 @@ fn validate_config(cfg: &ConfigItems) -> Result<(), Error> {
if cfg._enable_email_2fa && cfg.email_token_size < 6 { if cfg._enable_email_2fa && cfg.email_token_size < 6 {
err!("`EMAIL_TOKEN_SIZE` has a minimum size of 6") err!("`EMAIL_TOKEN_SIZE` has a minimum size of 6")
} }
if cfg._enable_email_2fa && cfg.email_token_size > 19 {
err!("`EMAIL_TOKEN_SIZE` has a maximum size of 19")
}
} }
// Check if the icon blacklist regex is valid // Check if the icon blacklist regex is valid
@ -731,6 +752,48 @@ fn extract_url_path(url: &str) -> String {
} }
} }
/// Generate the correct URL for the icon service.
/// This will be used within icons.rs to call the external icon service.
fn generate_icon_service_url(icon_service: &str) -> String {
match icon_service {
"internal" => "".to_string(),
"bitwarden" => "https://icons.bitwarden.net/{}/icon.png".to_string(),
"duckduckgo" => "https://icons.duckduckgo.com/ip3/{}.ico".to_string(),
"google" => "https://www.google.com/s2/favicons?domain={}&sz=32".to_string(),
_ => icon_service.to_string(),
}
}
/// Generate the CSP string needed to allow redirected icon fetching
fn generate_icon_service_csp(icon_service: &str, icon_service_url: &str) -> String {
// We split on the first '{', since that is the variable delimiter for an icon service URL.
// Everything up until the first '{' should be fixed and can be used as an CSP string.
let csp_string = match icon_service_url.split_once('{') {
Some((c, _)) => c.to_string(),
None => "".to_string(),
};
// Because Google does a second redirect to there gstatic.com domain, we need to add an extra csp string.
match icon_service {
"google" => csp_string + " https://*.gstatic.com/favicon",
_ => csp_string,
}
}
/// Convert the old SMTP_SSL and SMTP_EXPLICIT_TLS options
fn smtp_convert_deprecated_ssl_options(smtp_ssl: Option<bool>, smtp_explicit_tls: Option<bool>) -> String {
if smtp_explicit_tls.is_some() || smtp_ssl.is_some() {
println!("[DEPRECATED]: `SMTP_SSL` or `SMTP_EXPLICIT_TLS` is set. Please use `SMTP_SECURITY` instead.");
}
if smtp_explicit_tls.is_some() && smtp_explicit_tls.unwrap() {
return "force_tls".to_string();
} else if smtp_ssl.is_some() && !smtp_ssl.unwrap() {
return "off".to_string();
}
// Return the default `starttls` in all other cases
"starttls".to_string()
}
impl Config { impl Config {
pub fn load() -> Result<Self, Error> { pub fn load() -> Result<Self, Error> {
// Loading from env and file // Loading from env and file
@ -747,6 +810,8 @@ impl Config {
Ok(Config { Ok(Config {
inner: RwLock::new(Inner { inner: RwLock::new(Inner {
rocket_shutdown_handle: None,
ws_shutdown_handle: None,
templates: load_templates(&config.templates_folder), templates: load_templates(&config.templates_folder),
config, config,
_env, _env,
@ -911,6 +976,26 @@ impl Config {
hb.render(name, data).map_err(Into::into) hb.render(name, data).map_err(Into::into)
} }
} }
pub fn set_rocket_shutdown_handle(&self, handle: rocket::Shutdown) {
self.inner.write().unwrap().rocket_shutdown_handle = Some(handle);
}
pub fn set_ws_shutdown_handle(&self, handle: tokio::sync::oneshot::Sender<()>) {
self.inner.write().unwrap().ws_shutdown_handle = Some(handle);
}
pub fn shutdown(&self) {
if let Ok(mut c) = self.inner.write() {
if let Some(handle) = c.ws_shutdown_handle.take() {
handle.send(()).ok();
}
if let Some(handle) = c.rocket_shutdown_handle.take() {
handle.notify();
}
}
}
} }
use handlebars::{Context, Handlebars, Helper, HelperResult, Output, RenderContext, RenderError, Renderable}; use handlebars::{Context, Handlebars, Helper, HelperResult, Output, RenderContext, RenderError, Renderable};
@ -984,7 +1069,7 @@ where
fn case_helper<'reg, 'rc>( fn case_helper<'reg, 'rc>(
h: &Helper<'reg, 'rc>, h: &Helper<'reg, 'rc>,
r: &'reg Handlebars, r: &'reg Handlebars<'_>,
ctx: &'rc Context, ctx: &'rc Context,
rc: &mut RenderContext<'reg, 'rc>, rc: &mut RenderContext<'reg, 'rc>,
out: &mut dyn Output, out: &mut dyn Output,
@ -1001,17 +1086,16 @@ fn case_helper<'reg, 'rc>(
fn js_escape_helper<'reg, 'rc>( fn js_escape_helper<'reg, 'rc>(
h: &Helper<'reg, 'rc>, h: &Helper<'reg, 'rc>,
_r: &'reg Handlebars, _r: &'reg Handlebars<'_>,
_ctx: &'rc Context, _ctx: &'rc Context,
_rc: &mut RenderContext<'reg, 'rc>, _rc: &mut RenderContext<'reg, 'rc>,
out: &mut dyn Output, out: &mut dyn Output,
) -> HelperResult { ) -> HelperResult {
let param = h.param(0).ok_or_else(|| RenderError::new("Param not found for helper \"js_escape\""))?; let param = h.param(0).ok_or_else(|| RenderError::new("Param not found for helper \"jsesc\""))?;
let no_quote = h.param(1).is_some(); let no_quote = h.param(1).is_some();
let value = let value = param.value().as_str().ok_or_else(|| RenderError::new("Param for helper \"jsesc\" is not a String"))?;
param.value().as_str().ok_or_else(|| RenderError::new("Param for helper \"js_escape\" is not a String"))?;
let mut escaped_value = value.replace('\\', "").replace('\'', "\\x22").replace('\"', "\\x27"); let mut escaped_value = value.replace('\\', "").replace('\'', "\\x22").replace('\"', "\\x27");
if !no_quote { if !no_quote {

28
src/crypto.rs

@ -6,8 +6,6 @@ use std::num::NonZeroU32;
use data_encoding::HEXLOWER; use data_encoding::HEXLOWER;
use ring::{digest, hmac, pbkdf2}; use ring::{digest, hmac, pbkdf2};
use crate::error::Error;
static DIGEST_ALG: pbkdf2::Algorithm = pbkdf2::PBKDF2_HMAC_SHA256; static DIGEST_ALG: pbkdf2::Algorithm = pbkdf2::PBKDF2_HMAC_SHA256;
const OUTPUT_LEN: usize = digest::SHA256_OUTPUT_LEN; const OUTPUT_LEN: usize = digest::SHA256_OUTPUT_LEN;
@ -65,6 +63,12 @@ pub fn get_random_string(alphabet: &[u8], num_chars: usize) -> String {
.collect() .collect()
} }
/// Generates a random numeric string.
pub fn get_random_string_numeric(num_chars: usize) -> String {
const ALPHABET: &[u8] = b"0123456789";
get_random_string(ALPHABET, num_chars)
}
/// Generates a random alphanumeric string. /// Generates a random alphanumeric string.
pub fn get_random_string_alphanum(num_chars: usize) -> String { pub fn get_random_string_alphanum(num_chars: usize) -> String {
const ALPHABET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ\ const ALPHABET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ\
@ -87,23 +91,9 @@ pub fn generate_attachment_id() -> String {
generate_id(10) // 80 bits generate_id(10) // 80 bits
} }
pub fn generate_token(token_size: u32) -> Result<String, Error> { /// Generates a numeric token for email-based verifications.
// A u64 can represent all whole numbers up to 19 digits long. pub fn generate_email_token(token_size: u8) -> String {
if token_size > 19 { get_random_string_numeric(token_size as usize)
err!("Token size is limited to 19 digits")
}
let low: u64 = 0;
let high: u64 = 10u64.pow(token_size);
// Generate a random number in the range [low, high), then format it as a
// token of fixed width, left-padding with 0 as needed.
use rand::{thread_rng, Rng};
let mut rng = thread_rng();
let number: u64 = rng.gen_range(low..high);
let token = format!("{:0size$}", number, size = token_size as usize);
Ok(token)
} }
/// Generates a personal API key. /// Generates a personal API key.

220
src/db/mod.rs

@ -1,8 +1,20 @@
use diesel::r2d2::{ConnectionManager, Pool, PooledConnection}; use std::{sync::Arc, time::Duration};
use diesel::{
connection::SimpleConnection,
r2d2::{ConnectionManager, CustomizeConnection, Pool, PooledConnection},
};
use rocket::{ use rocket::{
http::Status, http::Status,
outcome::IntoOutcome,
request::{FromRequest, Outcome}, request::{FromRequest, Outcome},
Request, State, Request,
};
use tokio::{
sync::{Mutex, OwnedSemaphorePermit, Semaphore},
time::timeout,
}; };
use crate::{ use crate::{
@ -22,6 +34,23 @@ pub mod __mysql_schema;
#[path = "schemas/postgresql/schema.rs"] #[path = "schemas/postgresql/schema.rs"]
pub mod __postgresql_schema; pub mod __postgresql_schema;
// These changes are based on Rocket 0.5-rc wrapper of Diesel: https://github.com/SergioBenitez/Rocket/blob/v0.5-rc/contrib/sync_db_pools
// A wrapper around spawn_blocking that propagates panics to the calling code.
pub async fn run_blocking<F, R>(job: F) -> R
where
F: FnOnce() -> R + Send + 'static,
R: Send + 'static,
{
match tokio::task::spawn_blocking(job).await {
Ok(ret) => ret,
Err(e) => match e.try_into_panic() {
Ok(panic) => std::panic::resume_unwind(panic),
Err(_) => unreachable!("spawn_blocking tasks are never cancelled"),
},
}
}
// This is used to generate the main DbConn and DbPool enums, which contain one variant for each database supported // This is used to generate the main DbConn and DbPool enums, which contain one variant for each database supported
macro_rules! generate_connections { macro_rules! generate_connections {
( $( $name:ident: $ty:ty ),+ ) => { ( $( $name:ident: $ty:ty ),+ ) => {
@ -29,15 +58,74 @@ macro_rules! generate_connections {
#[derive(Eq, PartialEq)] #[derive(Eq, PartialEq)]
pub enum DbConnType { $( $name, )+ } pub enum DbConnType { $( $name, )+ }
pub struct DbConn {
conn: Arc<Mutex<Option<DbConnInner>>>,
permit: Option<OwnedSemaphorePermit>,
}
#[allow(non_camel_case_types)] #[allow(non_camel_case_types)]
pub enum DbConn { $( #[cfg($name)] $name(PooledConnection<ConnectionManager< $ty >>), )+ } pub enum DbConnInner { $( #[cfg($name)] $name(PooledConnection<ConnectionManager< $ty >>), )+ }
#[derive(Debug)]
pub struct DbConnOptions {
pub init_stmts: String,
}
$( // Based on <https://stackoverflow.com/a/57717533>.
#[cfg($name)]
impl CustomizeConnection<$ty, diesel::r2d2::Error> for DbConnOptions {
fn on_acquire(&self, conn: &mut $ty) -> Result<(), diesel::r2d2::Error> {
(|| {
if !self.init_stmts.is_empty() {
conn.batch_execute(&self.init_stmts)?;
}
Ok(())
})().map_err(diesel::r2d2::Error::QueryError)
}
})+
#[derive(Clone)]
pub struct DbPool {
// This is an 'Option' so that we can drop the pool in a 'spawn_blocking'.
pool: Option<DbPoolInner>,
semaphore: Arc<Semaphore>
}
#[allow(non_camel_case_types)] #[allow(non_camel_case_types)]
#[derive(Clone)] #[derive(Clone)]
pub enum DbPool { $( #[cfg($name)] $name(Pool<ConnectionManager< $ty >>), )+ } pub enum DbPoolInner { $( #[cfg($name)] $name(Pool<ConnectionManager< $ty >>), )+ }
impl Drop for DbConn {
fn drop(&mut self) {
let conn = self.conn.clone();
let permit = self.permit.take();
// Since connection can't be on the stack in an async fn during an
// await, we have to spawn a new blocking-safe thread...
tokio::task::spawn_blocking(move || {
// And then re-enter the runtime to wait on the async mutex, but in a blocking fashion.
let mut conn = tokio::runtime::Handle::current().block_on(conn.lock_owned());
if let Some(conn) = conn.take() {
drop(conn);
}
// Drop permit after the connection is dropped
drop(permit);
});
}
}
impl Drop for DbPool {
fn drop(&mut self) {
let pool = self.pool.take();
tokio::task::spawn_blocking(move || drop(pool));
}
}
impl DbPool { impl DbPool {
// For the given database URL, guess it's type, run migrations create pool and return it // For the given database URL, guess its type, run migrations, create pool, and return it
#[allow(clippy::diverging_sub_expression)]
pub fn from_config() -> Result<Self, Error> { pub fn from_config() -> Result<Self, Error> {
let url = CONFIG.database_url(); let url = CONFIG.database_url();
let conn_type = DbConnType::from_url(&url)?; let conn_type = DbConnType::from_url(&url)?;
@ -50,9 +138,16 @@ macro_rules! generate_connections {
let manager = ConnectionManager::new(&url); let manager = ConnectionManager::new(&url);
let pool = Pool::builder() let pool = Pool::builder()
.max_size(CONFIG.database_max_conns()) .max_size(CONFIG.database_max_conns())
.connection_timeout(Duration::from_secs(CONFIG.database_timeout()))
.connection_customizer(Box::new(DbConnOptions{
init_stmts: conn_type.get_init_stmts()
}))
.build(manager) .build(manager)
.map_res("Failed to create pool")?; .map_res("Failed to create pool")?;
return Ok(Self::$name(pool)); return Ok(DbPool {
pool: Some(DbPoolInner::$name(pool)),
semaphore: Arc::new(Semaphore::new(CONFIG.database_max_conns() as usize)),
});
} }
#[cfg(not($name))] #[cfg(not($name))]
#[allow(unreachable_code)] #[allow(unreachable_code)]
@ -61,10 +156,26 @@ macro_rules! generate_connections {
)+ } )+ }
} }
// Get a connection from the pool // Get a connection from the pool
pub fn get(&self) -> Result<DbConn, Error> { pub async fn get(&self) -> Result<DbConn, Error> {
match self { $( let duration = Duration::from_secs(CONFIG.database_timeout());
let permit = match timeout(duration, self.semaphore.clone().acquire_owned()).await {
Ok(p) => p.expect("Semaphore should be open"),
Err(_) => {
err!("Timeout waiting for database connection");
}
};
match self.pool.as_ref().expect("DbPool.pool should always be Some()") { $(
#[cfg($name)] #[cfg($name)]
Self::$name(p) => Ok(DbConn::$name(p.get().map_res("Error retrieving connection from pool")?)), DbPoolInner::$name(p) => {
let pool = p.clone();
let c = run_blocking(move || pool.get_timeout(duration)).await.map_res("Error retrieving connection from pool")?;
return Ok(DbConn {
conn: Arc::new(Mutex::new(Some(DbConnInner::$name(c)))),
permit: Some(permit)
});
},
)+ } )+ }
} }
} }
@ -104,6 +215,23 @@ impl DbConnType {
err!("`DATABASE_URL` looks like a SQLite URL, but 'sqlite' feature is not enabled") err!("`DATABASE_URL` looks like a SQLite URL, but 'sqlite' feature is not enabled")
} }
} }
pub fn get_init_stmts(&self) -> String {
let init_stmts = CONFIG.database_conn_init();
if !init_stmts.is_empty() {
init_stmts
} else {
self.default_init_stmts()
}
}
pub fn default_init_stmts(&self) -> String {
match self {
Self::sqlite => "PRAGMA busy_timeout = 5000; PRAGMA synchronous = NORMAL;".to_string(),
Self::mysql => "".to_string(),
Self::postgresql => "".to_string(),
}
}
} }
#[macro_export] #[macro_export]
@ -113,42 +241,52 @@ macro_rules! db_run {
db_run! { $conn: sqlite, mysql, postgresql $body } db_run! { $conn: sqlite, mysql, postgresql $body }
}; };
( @raw $conn:ident: $body:block ) => {
db_run! { @raw $conn: sqlite, mysql, postgresql $body }
};
// Different code for each db // Different code for each db
( $conn:ident: $( $($db:ident),+ $body:block )+ ) => {{ ( $conn:ident: $( $($db:ident),+ $body:block )+ ) => {{
#[allow(unused)] use diesel::prelude::*; #[allow(unused)] use diesel::prelude::*;
match $conn { #[allow(unused)] use $crate::db::FromDb;
$($(
let conn = $conn.conn.clone();
let mut conn = conn.lock_owned().await;
match conn.as_mut().expect("internal invariant broken: self.connection is Some") {
$($(
#[cfg($db)] #[cfg($db)]
crate::db::DbConn::$db(ref $conn) => { $crate::db::DbConnInner::$db($conn) => {
paste::paste! { paste::paste! {
#[allow(unused)] use crate::db::[<__ $db _schema>]::{self as schema, *}; #[allow(unused)] use $crate::db::[<__ $db _schema>]::{self as schema, *};
#[allow(unused)] use [<__ $db _model>]::*; #[allow(unused)] use [<__ $db _model>]::*;
#[allow(unused)] use crate::db::FromDb;
} }
$body
tokio::task::block_in_place(move || { $body }) // Run blocking can't be used due to the 'static limitation, use block_in_place instead
}, },
)+)+ )+)+
}} }
}; }};
// Same for all dbs
( @raw $conn:ident: $body:block ) => {
db_run! { @raw $conn: sqlite, mysql, postgresql $body }
};
// Different code for each db ( @raw $conn:ident: $( $($db:ident),+ $body:block )+ ) => {{
( @raw $conn:ident: $( $($db:ident),+ $body:block )+ ) => {
#[allow(unused)] use diesel::prelude::*; #[allow(unused)] use diesel::prelude::*;
#[allow(unused_variables)] #[allow(unused)] use $crate::db::FromDb;
match $conn {
$($( let conn = $conn.conn.clone();
let mut conn = conn.lock_owned().await;
match conn.as_mut().expect("internal invariant broken: self.connection is Some") {
$($(
#[cfg($db)] #[cfg($db)]
crate::db::DbConn::$db(ref $conn) => { $crate::db::DbConnInner::$db($conn) => {
$body paste::paste! {
#[allow(unused)] use $crate::db::[<__ $db _schema>]::{self as schema, *};
// @ RAW: #[allow(unused)] use [<__ $db _model>]::*;
}
tokio::task::block_in_place(move || { $body }) // Run blocking can't be used due to the 'static limitation, use block_in_place instead
}, },
)+)+ )+)+
} }
}; }};
} }
pub trait FromDb { pub trait FromDb {
@ -201,7 +339,7 @@ macro_rules! db_object {
paste::paste! { paste::paste! {
#[allow(unused)] use super::*; #[allow(unused)] use super::*;
#[allow(unused)] use diesel::prelude::*; #[allow(unused)] use diesel::prelude::*;
#[allow(unused)] use crate::db::[<__ $db _schema>]::*; #[allow(unused)] use $crate::db::[<__ $db _schema>]::*;
$( #[$attr] )* $( #[$attr] )*
pub struct [<$name Db>] { $( pub struct [<$name Db>] { $(
@ -213,7 +351,7 @@ macro_rules! db_object {
#[inline(always)] pub fn to_db(x: &super::$name) -> Self { Self { $( $field: x.$field.clone(), )+ } } #[inline(always)] pub fn to_db(x: &super::$name) -> Self { Self { $( $field: x.$field.clone(), )+ } }
} }
impl crate::db::FromDb for [<$name Db>] { impl $crate::db::FromDb for [<$name Db>] {
type Output = super::$name; type Output = super::$name;
#[allow(clippy::wrong_self_convention)] #[allow(clippy::wrong_self_convention)]
#[inline(always)] fn from_db(self) -> Self::Output { super::$name { $( $field: self.$field, )+ } } #[inline(always)] fn from_db(self) -> Self::Output { super::$name { $( $field: self.$field, )+ } }
@ -227,9 +365,10 @@ pub mod models;
/// Creates a back-up of the sqlite database /// Creates a back-up of the sqlite database
/// MySQL/MariaDB and PostgreSQL are not supported. /// MySQL/MariaDB and PostgreSQL are not supported.
pub fn backup_database(conn: &DbConn) -> Result<(), Error> { pub async fn backup_database(conn: &DbConn) -> Result<(), Error> {
db_run! {@raw conn: db_run! {@raw conn:
postgresql, mysql { postgresql, mysql {
let _ = conn;
err!("PostgreSQL and MySQL/MariaDB do not support this backup feature"); err!("PostgreSQL and MySQL/MariaDB do not support this backup feature");
} }
sqlite { sqlite {
@ -244,7 +383,7 @@ pub fn backup_database(conn: &DbConn) -> Result<(), Error> {
} }
/// Get the SQL Server version /// Get the SQL Server version
pub fn get_sql_server_version(conn: &DbConn) -> String { pub async fn get_sql_server_version(conn: &DbConn) -> String {
db_run! {@raw conn: db_run! {@raw conn:
postgresql, mysql { postgresql, mysql {
no_arg_sql_function!(version, diesel::sql_types::Text); no_arg_sql_function!(version, diesel::sql_types::Text);
@ -260,15 +399,14 @@ pub fn get_sql_server_version(conn: &DbConn) -> String {
/// Attempts to retrieve a single connection from the managed database pool. If /// Attempts to retrieve a single connection from the managed database pool. If
/// no pool is currently managed, fails with an `InternalServerError` status. If /// no pool is currently managed, fails with an `InternalServerError` status. If
/// no connections are available, fails with a `ServiceUnavailable` status. /// no connections are available, fails with a `ServiceUnavailable` status.
impl<'a, 'r> FromRequest<'a, 'r> for DbConn { #[rocket::async_trait]
impl<'r> FromRequest<'r> for DbConn {
type Error = (); type Error = ();
fn from_request(request: &'a Request<'r>) -> Outcome<DbConn, ()> { async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
// https://github.com/SergioBenitez/Rocket/commit/e3c1a4ad3ab9b840482ec6de4200d30df43e357c match request.rocket().state::<DbPool>() {
let pool = try_outcome!(request.guard::<State<DbPool>>()); Some(p) => p.get().await.map_err(|_| ()).into_outcome(Status::ServiceUnavailable),
match pool.get() { None => Outcome::Failure((Status::InternalServerError, ())),
Ok(conn) => Outcome::Success(conn),
Err(_) => Outcome::Failure((Status::ServiceUnavailable, ())),
} }
} }
} }

37
src/db/models/attachment.rs

@ -2,14 +2,12 @@ use std::io::ErrorKind;
use serde_json::Value; use serde_json::Value;
use super::Cipher;
use crate::CONFIG; use crate::CONFIG;
db_object! { db_object! {
#[derive(Identifiable, Queryable, Insertable, Associations, AsChangeset)] #[derive(Identifiable, Queryable, Insertable, AsChangeset)]
#[table_name = "attachments"] #[table_name = "attachments"]
#[changeset_options(treat_none_as_null="true")] #[changeset_options(treat_none_as_null="true")]
#[belongs_to(super::Cipher, foreign_key = "cipher_uuid")]
#[primary_key(id)] #[primary_key(id)]
pub struct Attachment { pub struct Attachment {
pub id: String, pub id: String,
@ -60,7 +58,7 @@ use crate::error::MapResult;
/// Database methods /// Database methods
impl Attachment { impl Attachment {
pub fn save(&self, conn: &DbConn) -> EmptyResult { pub async fn save(&self, conn: &DbConn) -> EmptyResult {
db_run! { conn: db_run! { conn:
sqlite, mysql { sqlite, mysql {
match diesel::replace_into(attachments::table) match diesel::replace_into(attachments::table)
@ -92,7 +90,7 @@ impl Attachment {
} }
} }
pub fn delete(&self, conn: &DbConn) -> EmptyResult { pub async fn delete(&self, conn: &DbConn) -> EmptyResult {
db_run! { conn: { db_run! { conn: {
crate::util::retry( crate::util::retry(
|| diesel::delete(attachments::table.filter(attachments::id.eq(&self.id))).execute(conn), || diesel::delete(attachments::table.filter(attachments::id.eq(&self.id))).execute(conn),
@ -116,14 +114,14 @@ impl Attachment {
}} }}
} }
pub fn delete_all_by_cipher(cipher_uuid: &str, conn: &DbConn) -> EmptyResult { pub async fn delete_all_by_cipher(cipher_uuid: &str, conn: &DbConn) -> EmptyResult {
for attachment in Attachment::find_by_cipher(cipher_uuid, conn) { for attachment in Attachment::find_by_cipher(cipher_uuid, conn).await {
attachment.delete(conn)?; attachment.delete(conn).await?;
} }
Ok(()) Ok(())
} }
pub fn find_by_id(id: &str, conn: &DbConn) -> Option<Self> { pub async fn find_by_id(id: &str, conn: &DbConn) -> Option<Self> {
db_run! { conn: { db_run! { conn: {
attachments::table attachments::table
.filter(attachments::id.eq(id.to_lowercase())) .filter(attachments::id.eq(id.to_lowercase()))
@ -133,7 +131,7 @@ impl Attachment {
}} }}
} }
pub fn find_by_cipher(cipher_uuid: &str, conn: &DbConn) -> Vec<Self> { pub async fn find_by_cipher(cipher_uuid: &str, conn: &DbConn) -> Vec<Self> {
db_run! { conn: { db_run! { conn: {
attachments::table attachments::table
.filter(attachments::cipher_uuid.eq(cipher_uuid)) .filter(attachments::cipher_uuid.eq(cipher_uuid))
@ -143,7 +141,7 @@ impl Attachment {
}} }}
} }
pub fn size_by_user(user_uuid: &str, conn: &DbConn) -> i64 { pub async fn size_by_user(user_uuid: &str, conn: &DbConn) -> i64 {
db_run! { conn: { db_run! { conn: {
let result: Option<i64> = attachments::table let result: Option<i64> = attachments::table
.left_join(ciphers::table.on(ciphers::uuid.eq(attachments::cipher_uuid))) .left_join(ciphers::table.on(ciphers::uuid.eq(attachments::cipher_uuid)))
@ -155,7 +153,7 @@ impl Attachment {
}} }}
} }
pub fn count_by_user(user_uuid: &str, conn: &DbConn) -> i64 { pub async fn count_by_user(user_uuid: &str, conn: &DbConn) -> i64 {
db_run! { conn: { db_run! { conn: {
attachments::table attachments::table
.left_join(ciphers::table.on(ciphers::uuid.eq(attachments::cipher_uuid))) .left_join(ciphers::table.on(ciphers::uuid.eq(attachments::cipher_uuid)))
@ -166,7 +164,7 @@ impl Attachment {
}} }}
} }
pub fn size_by_org(org_uuid: &str, conn: &DbConn) -> i64 { pub async fn size_by_org(org_uuid: &str, conn: &DbConn) -> i64 {
db_run! { conn: { db_run! { conn: {
let result: Option<i64> = attachments::table let result: Option<i64> = attachments::table
.left_join(ciphers::table.on(ciphers::uuid.eq(attachments::cipher_uuid))) .left_join(ciphers::table.on(ciphers::uuid.eq(attachments::cipher_uuid)))
@ -178,7 +176,7 @@ impl Attachment {
}} }}
} }
pub fn count_by_org(org_uuid: &str, conn: &DbConn) -> i64 { pub async fn count_by_org(org_uuid: &str, conn: &DbConn) -> i64 {
db_run! { conn: { db_run! { conn: {
attachments::table attachments::table
.left_join(ciphers::table.on(ciphers::uuid.eq(attachments::cipher_uuid))) .left_join(ciphers::table.on(ciphers::uuid.eq(attachments::cipher_uuid)))
@ -188,4 +186,15 @@ impl Attachment {
.unwrap_or(0) .unwrap_or(0)
}} }}
} }
pub async fn find_all_by_ciphers(cipher_uuids: &Vec<String>, conn: &DbConn) -> Vec<Self> {
db_run! { conn: {
attachments::table
.filter(attachments::cipher_uuid.eq_any(cipher_uuids))
.select(attachments::all_columns)
.load::<AttachmentDb>(conn)
.expect("Error loading attachments")
.from_db()
}}
}
} }

280
src/db/models/cipher.rs

@ -1,19 +1,17 @@
use crate::CONFIG;
use chrono::{Duration, NaiveDateTime, Utc}; use chrono::{Duration, NaiveDateTime, Utc};
use serde_json::Value; use serde_json::Value;
use crate::CONFIG; use super::{Attachment, CollectionCipher, Favorite, FolderCipher, User, UserOrgStatus, UserOrgType, UserOrganization};
use super::{ use crate::api::core::CipherSyncData;
Attachment, CollectionCipher, Favorite, FolderCipher, Organization, User, UserOrgStatus, UserOrgType,
UserOrganization, use std::borrow::Cow;
};
db_object! { db_object! {
#[derive(Identifiable, Queryable, Insertable, Associations, AsChangeset)] #[derive(Identifiable, Queryable, Insertable, AsChangeset)]
#[table_name = "ciphers"] #[table_name = "ciphers"]
#[changeset_options(treat_none_as_null="true")] #[changeset_options(treat_none_as_null="true")]
#[belongs_to(User, foreign_key = "user_uuid")]
#[belongs_to(Organization, foreign_key = "organization_uuid")]
#[primary_key(uuid)] #[primary_key(uuid)]
pub struct Cipher { pub struct Cipher {
pub uuid: String, pub uuid: String,
@ -82,22 +80,32 @@ use crate::error::MapResult;
/// Database methods /// Database methods
impl Cipher { impl Cipher {
pub fn to_json(&self, host: &str, user_uuid: &str, conn: &DbConn) -> Value { pub async fn to_json(
&self,
host: &str,
user_uuid: &str,
cipher_sync_data: Option<&CipherSyncData>,
conn: &DbConn,
) -> Value {
use crate::util::format_date; use crate::util::format_date;
let attachments = Attachment::find_by_cipher(&self.uuid, conn); let mut attachments_json: Value = Value::Null;
// When there are no attachments use null instead of an empty array if let Some(cipher_sync_data) = cipher_sync_data {
let attachments_json = if attachments.is_empty() { if let Some(attachments) = cipher_sync_data.cipher_attachments.get(&self.uuid) {
Value::Null attachments_json = attachments.iter().map(|c| c.to_json(host)).collect();
}
} else { } else {
attachments.iter().map(|c| c.to_json(host)).collect() let attachments = Attachment::find_by_cipher(&self.uuid, conn).await;
}; if !attachments.is_empty() {
attachments_json = attachments.iter().map(|c| c.to_json(host)).collect()
}
}
let fields_json = self.fields.as_ref().and_then(|s| serde_json::from_str(s).ok()).unwrap_or(Value::Null); let fields_json = self.fields.as_ref().and_then(|s| serde_json::from_str(s).ok()).unwrap_or(Value::Null);
let password_history_json = let password_history_json =
self.password_history.as_ref().and_then(|s| serde_json::from_str(s).ok()).unwrap_or(Value::Null); self.password_history.as_ref().and_then(|s| serde_json::from_str(s).ok()).unwrap_or(Value::Null);
let (read_only, hide_passwords) = match self.get_access_restrictions(user_uuid, conn) { let (read_only, hide_passwords) = match self.get_access_restrictions(user_uuid, cipher_sync_data, conn).await {
Some((ro, hp)) => (ro, hp), Some((ro, hp)) => (ro, hp),
None => { None => {
error!("Cipher ownership assertion failure"); error!("Cipher ownership assertion failure");
@ -109,7 +117,7 @@ impl Cipher {
// If not passing an empty object, mobile clients will crash. // If not passing an empty object, mobile clients will crash.
let mut type_data_json: Value = serde_json::from_str(&self.data).unwrap_or_else(|_| json!({})); let mut type_data_json: Value = serde_json::from_str(&self.data).unwrap_or_else(|_| json!({}));
// NOTE: This was marked as *Backwards Compatibilty Code*, but as of January 2021 this is still being used by upstream // NOTE: This was marked as *Backwards Compatibility Code*, but as of January 2021 this is still being used by upstream
// Set the first element of the Uris array as Uri, this is needed several (mobile) clients. // Set the first element of the Uris array as Uri, this is needed several (mobile) clients.
if self.atype == 1 { if self.atype == 1 {
if type_data_json["Uris"].is_array() { if type_data_json["Uris"].is_array() {
@ -124,13 +132,23 @@ impl Cipher {
// Clone the type_data and add some default value. // Clone the type_data and add some default value.
let mut data_json = type_data_json.clone(); let mut data_json = type_data_json.clone();
// NOTE: This was marked as *Backwards Compatibilty Code*, but as of January 2021 this is still being used by upstream // NOTE: This was marked as *Backwards Compatibility Code*, but as of January 2021 this is still being used by upstream
// data_json should always contain the following keys with every atype // data_json should always contain the following keys with every atype
data_json["Fields"] = json!(fields_json); data_json["Fields"] = json!(fields_json);
data_json["Name"] = json!(self.name); data_json["Name"] = json!(self.name);
data_json["Notes"] = json!(self.notes); data_json["Notes"] = json!(self.notes);
data_json["PasswordHistory"] = json!(password_history_json); data_json["PasswordHistory"] = json!(password_history_json);
let collection_ids = if let Some(cipher_sync_data) = cipher_sync_data {
if let Some(cipher_collections) = cipher_sync_data.cipher_collections.get(&self.uuid) {
Cow::from(cipher_collections)
} else {
Cow::from(Vec::with_capacity(0))
}
} else {
Cow::from(self.get_collections(user_uuid, conn).await)
};
// There are three types of cipher response models in upstream // There are three types of cipher response models in upstream
// Bitwarden: "cipherMini", "cipher", and "cipherDetails" (in order // Bitwarden: "cipherMini", "cipher", and "cipherDetails" (in order
// of increasing level of detail). vaultwarden currently only // of increasing level of detail). vaultwarden currently only
@ -144,8 +162,8 @@ impl Cipher {
"Type": self.atype, "Type": self.atype,
"RevisionDate": format_date(&self.updated_at), "RevisionDate": format_date(&self.updated_at),
"DeletedDate": self.deleted_at.map_or(Value::Null, |d| Value::String(format_date(&d))), "DeletedDate": self.deleted_at.map_or(Value::Null, |d| Value::String(format_date(&d))),
"FolderId": self.get_folder_uuid(user_uuid, conn), "FolderId": if let Some(cipher_sync_data) = cipher_sync_data { cipher_sync_data.cipher_folders.get(&self.uuid).map(|c| c.to_string() ) } else { self.get_folder_uuid(user_uuid, conn).await },
"Favorite": self.is_favorite(user_uuid, conn), "Favorite": if let Some(cipher_sync_data) = cipher_sync_data { cipher_sync_data.cipher_favorites.contains(&self.uuid) } else { self.is_favorite(user_uuid, conn).await },
"Reprompt": self.reprompt.unwrap_or(RepromptType::None as i32), "Reprompt": self.reprompt.unwrap_or(RepromptType::None as i32),
"OrganizationId": self.organization_uuid, "OrganizationId": self.organization_uuid,
"Attachments": attachments_json, "Attachments": attachments_json,
@ -154,7 +172,7 @@ impl Cipher {
"OrganizationUseTotp": true, "OrganizationUseTotp": true,
// This field is specific to the cipherDetails type. // This field is specific to the cipherDetails type.
"CollectionIds": self.get_collections(user_uuid, conn), "CollectionIds": collection_ids,
"Name": self.name, "Name": self.name,
"Notes": self.notes, "Notes": self.notes,
@ -189,28 +207,28 @@ impl Cipher {
json_object json_object
} }
pub fn update_users_revision(&self, conn: &DbConn) -> Vec<String> { pub async fn update_users_revision(&self, conn: &DbConn) -> Vec<String> {
let mut user_uuids = Vec::new(); let mut user_uuids = Vec::new();
match self.user_uuid { match self.user_uuid {
Some(ref user_uuid) => { Some(ref user_uuid) => {
User::update_uuid_revision(user_uuid, conn); User::update_uuid_revision(user_uuid, conn).await;
user_uuids.push(user_uuid.clone()) user_uuids.push(user_uuid.clone())
} }
None => { None => {
// Belongs to Organization, need to update affected users // Belongs to Organization, need to update affected users
if let Some(ref org_uuid) = self.organization_uuid { if let Some(ref org_uuid) = self.organization_uuid {
UserOrganization::find_by_cipher_and_org(&self.uuid, org_uuid, conn).iter().for_each(|user_org| { for user_org in UserOrganization::find_by_cipher_and_org(&self.uuid, org_uuid, conn).await.iter() {
User::update_uuid_revision(&user_org.user_uuid, conn); User::update_uuid_revision(&user_org.user_uuid, conn).await;
user_uuids.push(user_org.user_uuid.clone()) user_uuids.push(user_org.user_uuid.clone())
}); }
} }
} }
}; };
user_uuids user_uuids
} }
pub fn save(&mut self, conn: &DbConn) -> EmptyResult { pub async fn save(&mut self, conn: &DbConn) -> EmptyResult {
self.update_users_revision(conn); self.update_users_revision(conn).await;
self.updated_at = Utc::now().naive_utc(); self.updated_at = Utc::now().naive_utc();
db_run! { conn: db_run! { conn:
@ -244,13 +262,13 @@ impl Cipher {
} }
} }
pub fn delete(&self, conn: &DbConn) -> EmptyResult { pub async fn delete(&self, conn: &DbConn) -> EmptyResult {
self.update_users_revision(conn); self.update_users_revision(conn).await;
FolderCipher::delete_all_by_cipher(&self.uuid, conn)?; FolderCipher::delete_all_by_cipher(&self.uuid, conn).await?;
CollectionCipher::delete_all_by_cipher(&self.uuid, conn)?; CollectionCipher::delete_all_by_cipher(&self.uuid, conn).await?;
Attachment::delete_all_by_cipher(&self.uuid, conn)?; Attachment::delete_all_by_cipher(&self.uuid, conn).await?;
Favorite::delete_all_by_cipher(&self.uuid, conn)?; Favorite::delete_all_by_cipher(&self.uuid, conn).await?;
db_run! { conn: { db_run! { conn: {
diesel::delete(ciphers::table.filter(ciphers::uuid.eq(&self.uuid))) diesel::delete(ciphers::table.filter(ciphers::uuid.eq(&self.uuid)))
@ -259,54 +277,55 @@ impl Cipher {
}} }}
} }
pub fn delete_all_by_organization(org_uuid: &str, conn: &DbConn) -> EmptyResult { pub async fn delete_all_by_organization(org_uuid: &str, conn: &DbConn) -> EmptyResult {
for cipher in Self::find_by_org(org_uuid, conn) { // TODO: Optimize this by executing a DELETE directly on the database, instead of first fetching.
cipher.delete(conn)?; for cipher in Self::find_by_org(org_uuid, conn).await {
cipher.delete(conn).await?;
} }
Ok(()) Ok(())
} }
pub fn delete_all_by_user(user_uuid: &str, conn: &DbConn) -> EmptyResult { pub async fn delete_all_by_user(user_uuid: &str, conn: &DbConn) -> EmptyResult {
for cipher in Self::find_owned_by_user(user_uuid, conn) { for cipher in Self::find_owned_by_user(user_uuid, conn).await {
cipher.delete(conn)?; cipher.delete(conn).await?;
} }
Ok(()) Ok(())
} }
/// Purge all ciphers that are old enough to be auto-deleted. /// Purge all ciphers that are old enough to be auto-deleted.
pub fn purge_trash(conn: &DbConn) { pub async fn purge_trash(conn: &DbConn) {
if let Some(auto_delete_days) = CONFIG.trash_auto_delete_days() { if let Some(auto_delete_days) = CONFIG.trash_auto_delete_days() {
let now = Utc::now().naive_utc(); let now = Utc::now().naive_utc();
let dt = now - Duration::days(auto_delete_days); let dt = now - Duration::days(auto_delete_days);
for cipher in Self::find_deleted_before(&dt, conn) { for cipher in Self::find_deleted_before(&dt, conn).await {
cipher.delete(conn).ok(); cipher.delete(conn).await.ok();
} }
} }
} }
pub fn move_to_folder(&self, folder_uuid: Option<String>, user_uuid: &str, conn: &DbConn) -> EmptyResult { pub async fn move_to_folder(&self, folder_uuid: Option<String>, user_uuid: &str, conn: &DbConn) -> EmptyResult {
User::update_uuid_revision(user_uuid, conn); User::update_uuid_revision(user_uuid, conn).await;
match (self.get_folder_uuid(user_uuid, conn), folder_uuid) { match (self.get_folder_uuid(user_uuid, conn).await, folder_uuid) {
// No changes // No changes
(None, None) => Ok(()), (None, None) => Ok(()),
(Some(ref old), Some(ref new)) if old == new => Ok(()), (Some(ref old), Some(ref new)) if old == new => Ok(()),
// Add to folder // Add to folder
(None, Some(new)) => FolderCipher::new(&new, &self.uuid).save(conn), (None, Some(new)) => FolderCipher::new(&new, &self.uuid).save(conn).await,
// Remove from folder // Remove from folder
(Some(old), None) => match FolderCipher::find_by_folder_and_cipher(&old, &self.uuid, conn) { (Some(old), None) => match FolderCipher::find_by_folder_and_cipher(&old, &self.uuid, conn).await {
Some(old) => old.delete(conn), Some(old) => old.delete(conn).await,
None => err!("Couldn't move from previous folder"), None => err!("Couldn't move from previous folder"),
}, },
// Move to another folder // Move to another folder
(Some(old), Some(new)) => { (Some(old), Some(new)) => {
if let Some(old) = FolderCipher::find_by_folder_and_cipher(&old, &self.uuid, conn) { if let Some(old) = FolderCipher::find_by_folder_and_cipher(&old, &self.uuid, conn).await {
old.delete(conn)?; old.delete(conn).await?;
} }
FolderCipher::new(&new, &self.uuid).save(conn) FolderCipher::new(&new, &self.uuid).save(conn).await
} }
} }
} }
@ -317,13 +336,21 @@ impl Cipher {
} }
/// Returns whether this cipher is owned by an org in which the user has full access. /// Returns whether this cipher is owned by an org in which the user has full access.
pub fn is_in_full_access_org(&self, user_uuid: &str, conn: &DbConn) -> bool { pub async fn is_in_full_access_org(
&self,
user_uuid: &str,
cipher_sync_data: Option<&CipherSyncData>,
conn: &DbConn,
) -> bool {
if let Some(ref org_uuid) = self.organization_uuid { if let Some(ref org_uuid) = self.organization_uuid {
if let Some(user_org) = UserOrganization::find_by_user_and_org(user_uuid, org_uuid, conn) { if let Some(cipher_sync_data) = cipher_sync_data {
if let Some(cached_user_org) = cipher_sync_data.user_organizations.get(org_uuid) {
return cached_user_org.has_full_access();
}
} else if let Some(user_org) = UserOrganization::find_by_user_and_org(user_uuid, org_uuid, conn).await {
return user_org.has_full_access(); return user_org.has_full_access();
} }
} }
false false
} }
@ -332,18 +359,62 @@ impl Cipher {
/// not in any collection the user has access to. Otherwise, the user has /// not in any collection the user has access to. Otherwise, the user has
/// access to this cipher, and Some(read_only, hide_passwords) represents /// access to this cipher, and Some(read_only, hide_passwords) represents
/// the access restrictions. /// the access restrictions.
pub fn get_access_restrictions(&self, user_uuid: &str, conn: &DbConn) -> Option<(bool, bool)> { pub async fn get_access_restrictions(
&self,
user_uuid: &str,
cipher_sync_data: Option<&CipherSyncData>,
conn: &DbConn,
) -> Option<(bool, bool)> {
// Check whether this cipher is directly owned by the user, or is in // Check whether this cipher is directly owned by the user, or is in
// a collection that the user has full access to. If so, there are no // a collection that the user has full access to. If so, there are no
// access restrictions. // access restrictions.
if self.is_owned_by_user(user_uuid) || self.is_in_full_access_org(user_uuid, conn) { if self.is_owned_by_user(user_uuid) || self.is_in_full_access_org(user_uuid, cipher_sync_data, conn).await {
return Some((false, false)); return Some((false, false));
} }
let rows = if let Some(cipher_sync_data) = cipher_sync_data {
let mut rows: Vec<(bool, bool)> = Vec::new();
if let Some(collections) = cipher_sync_data.cipher_collections.get(&self.uuid) {
for collection in collections {
if let Some(uc) = cipher_sync_data.user_collections.get(collection) {
rows.push((uc.read_only, uc.hide_passwords));
}
}
}
rows
} else {
self.get_collections_access_flags(user_uuid, conn).await
};
if rows.is_empty() {
// This cipher isn't in any collections accessible to the user.
return None;
}
// A cipher can be in multiple collections with inconsistent access flags.
// For example, a cipher could be in one collection where the user has
// read-only access, but also in another collection where the user has
// read/write access. For a flag to be in effect for a cipher, upstream
// requires all collections the cipher is in to have that flag set.
// Therefore, we do a boolean AND of all values in each of the `read_only`
// and `hide_passwords` columns. This could ideally be done as part of the
// query, but Diesel doesn't support a min() or bool_and() function on
// booleans and this behavior isn't portable anyway.
let mut read_only = true;
let mut hide_passwords = true;
for (ro, hp) in rows.iter() {
read_only &= ro;
hide_passwords &= hp;
}
Some((read_only, hide_passwords))
}
pub async fn get_collections_access_flags(&self, user_uuid: &str, conn: &DbConn) -> Vec<(bool, bool)> {
db_run! {conn: { db_run! {conn: {
// Check whether this cipher is in any collections accessible to the // Check whether this cipher is in any collections accessible to the
// user. If so, retrieve the access flags for each collection. // user. If so, retrieve the access flags for each collection.
let rows = ciphers::table ciphers::table
.filter(ciphers::uuid.eq(&self.uuid)) .filter(ciphers::uuid.eq(&self.uuid))
.inner_join(ciphers_collections::table.on( .inner_join(ciphers_collections::table.on(
ciphers::uuid.eq(ciphers_collections::cipher_uuid))) ciphers::uuid.eq(ciphers_collections::cipher_uuid)))
@ -352,58 +423,35 @@ impl Cipher {
.and(users_collections::user_uuid.eq(user_uuid)))) .and(users_collections::user_uuid.eq(user_uuid))))
.select((users_collections::read_only, users_collections::hide_passwords)) .select((users_collections::read_only, users_collections::hide_passwords))
.load::<(bool, bool)>(conn) .load::<(bool, bool)>(conn)
.expect("Error getting access restrictions"); .expect("Error getting access restrictions")
if rows.is_empty() {
// This cipher isn't in any collections accessible to the user.
return None;
}
// A cipher can be in multiple collections with inconsistent access flags.
// For example, a cipher could be in one collection where the user has
// read-only access, but also in another collection where the user has
// read/write access. For a flag to be in effect for a cipher, upstream
// requires all collections the cipher is in to have that flag set.
// Therefore, we do a boolean AND of all values in each of the `read_only`
// and `hide_passwords` columns. This could ideally be done as part of the
// query, but Diesel doesn't support a min() or bool_and() function on
// booleans and this behavior isn't portable anyway.
let mut read_only = true;
let mut hide_passwords = true;
for (ro, hp) in rows.iter() {
read_only &= ro;
hide_passwords &= hp;
}
Some((read_only, hide_passwords))
}} }}
} }
pub fn is_write_accessible_to_user(&self, user_uuid: &str, conn: &DbConn) -> bool { pub async fn is_write_accessible_to_user(&self, user_uuid: &str, conn: &DbConn) -> bool {
match self.get_access_restrictions(user_uuid, conn) { match self.get_access_restrictions(user_uuid, None, conn).await {
Some((read_only, _hide_passwords)) => !read_only, Some((read_only, _hide_passwords)) => !read_only,
None => false, None => false,
} }
} }
pub fn is_accessible_to_user(&self, user_uuid: &str, conn: &DbConn) -> bool { pub async fn is_accessible_to_user(&self, user_uuid: &str, conn: &DbConn) -> bool {
self.get_access_restrictions(user_uuid, conn).is_some() self.get_access_restrictions(user_uuid, None, conn).await.is_some()
} }
// Returns whether this cipher is a favorite of the specified user. // Returns whether this cipher is a favorite of the specified user.
pub fn is_favorite(&self, user_uuid: &str, conn: &DbConn) -> bool { pub async fn is_favorite(&self, user_uuid: &str, conn: &DbConn) -> bool {
Favorite::is_favorite(&self.uuid, user_uuid, conn) Favorite::is_favorite(&self.uuid, user_uuid, conn).await
} }
// Sets whether this cipher is a favorite of the specified user. // Sets whether this cipher is a favorite of the specified user.
pub fn set_favorite(&self, favorite: Option<bool>, user_uuid: &str, conn: &DbConn) -> EmptyResult { pub async fn set_favorite(&self, favorite: Option<bool>, user_uuid: &str, conn: &DbConn) -> EmptyResult {
match favorite { match favorite {
None => Ok(()), // No change requested. None => Ok(()), // No change requested.
Some(status) => Favorite::set_favorite(status, &self.uuid, user_uuid, conn), Some(status) => Favorite::set_favorite(status, &self.uuid, user_uuid, conn).await,
} }
} }
pub fn get_folder_uuid(&self, user_uuid: &str, conn: &DbConn) -> Option<String> { pub async fn get_folder_uuid(&self, user_uuid: &str, conn: &DbConn) -> Option<String> {
db_run! {conn: { db_run! {conn: {
folders_ciphers::table folders_ciphers::table
.inner_join(folders::table) .inner_join(folders::table)
@ -415,7 +463,7 @@ impl Cipher {
}} }}
} }
pub fn find_by_uuid(uuid: &str, conn: &DbConn) -> Option<Self> { pub async fn find_by_uuid(uuid: &str, conn: &DbConn) -> Option<Self> {
db_run! {conn: { db_run! {conn: {
ciphers::table ciphers::table
.filter(ciphers::uuid.eq(uuid)) .filter(ciphers::uuid.eq(uuid))
@ -437,7 +485,7 @@ impl Cipher {
// true, then the non-interesting ciphers will not be returned. As a // true, then the non-interesting ciphers will not be returned. As a
// result, those ciphers will not appear in "My Vault" for the org // result, those ciphers will not appear in "My Vault" for the org
// owner/admin, but they can still be accessed via the org vault view. // owner/admin, but they can still be accessed via the org vault view.
pub fn find_by_user(user_uuid: &str, visible_only: bool, conn: &DbConn) -> Vec<Self> { pub async fn find_by_user(user_uuid: &str, visible_only: bool, conn: &DbConn) -> Vec<Self> {
db_run! {conn: { db_run! {conn: {
let mut query = ciphers::table let mut query = ciphers::table
.left_join(ciphers_collections::table.on( .left_join(ciphers_collections::table.on(
@ -472,12 +520,12 @@ impl Cipher {
} }
// Find all ciphers visible to the specified user. // Find all ciphers visible to the specified user.
pub fn find_by_user_visible(user_uuid: &str, conn: &DbConn) -> Vec<Self> { pub async fn find_by_user_visible(user_uuid: &str, conn: &DbConn) -> Vec<Self> {
Self::find_by_user(user_uuid, true, conn) Self::find_by_user(user_uuid, true, conn).await
} }
// Find all ciphers directly owned by the specified user. // Find all ciphers directly owned by the specified user.
pub fn find_owned_by_user(user_uuid: &str, conn: &DbConn) -> Vec<Self> { pub async fn find_owned_by_user(user_uuid: &str, conn: &DbConn) -> Vec<Self> {
db_run! {conn: { db_run! {conn: {
ciphers::table ciphers::table
.filter( .filter(
@ -488,7 +536,7 @@ impl Cipher {
}} }}
} }
pub fn count_owned_by_user(user_uuid: &str, conn: &DbConn) -> i64 { pub async fn count_owned_by_user(user_uuid: &str, conn: &DbConn) -> i64 {
db_run! {conn: { db_run! {conn: {
ciphers::table ciphers::table
.filter(ciphers::user_uuid.eq(user_uuid)) .filter(ciphers::user_uuid.eq(user_uuid))
@ -499,7 +547,7 @@ impl Cipher {
}} }}
} }
pub fn find_by_org(org_uuid: &str, conn: &DbConn) -> Vec<Self> { pub async fn find_by_org(org_uuid: &str, conn: &DbConn) -> Vec<Self> {
db_run! {conn: { db_run! {conn: {
ciphers::table ciphers::table
.filter(ciphers::organization_uuid.eq(org_uuid)) .filter(ciphers::organization_uuid.eq(org_uuid))
@ -507,7 +555,7 @@ impl Cipher {
}} }}
} }
pub fn count_by_org(org_uuid: &str, conn: &DbConn) -> i64 { pub async fn count_by_org(org_uuid: &str, conn: &DbConn) -> i64 {
db_run! {conn: { db_run! {conn: {
ciphers::table ciphers::table
.filter(ciphers::organization_uuid.eq(org_uuid)) .filter(ciphers::organization_uuid.eq(org_uuid))
@ -518,7 +566,7 @@ impl Cipher {
}} }}
} }
pub fn find_by_folder(folder_uuid: &str, conn: &DbConn) -> Vec<Self> { pub async fn find_by_folder(folder_uuid: &str, conn: &DbConn) -> Vec<Self> {
db_run! {conn: { db_run! {conn: {
folders_ciphers::table.inner_join(ciphers::table) folders_ciphers::table.inner_join(ciphers::table)
.filter(folders_ciphers::folder_uuid.eq(folder_uuid)) .filter(folders_ciphers::folder_uuid.eq(folder_uuid))
@ -528,7 +576,7 @@ impl Cipher {
} }
/// Find all ciphers that were deleted before the specified datetime. /// Find all ciphers that were deleted before the specified datetime.
pub fn find_deleted_before(dt: &NaiveDateTime, conn: &DbConn) -> Vec<Self> { pub async fn find_deleted_before(dt: &NaiveDateTime, conn: &DbConn) -> Vec<Self> {
db_run! {conn: { db_run! {conn: {
ciphers::table ciphers::table
.filter(ciphers::deleted_at.lt(dt)) .filter(ciphers::deleted_at.lt(dt))
@ -536,7 +584,7 @@ impl Cipher {
}} }}
} }
pub fn get_collections(&self, user_id: &str, conn: &DbConn) -> Vec<String> { pub async fn get_collections(&self, user_id: &str, conn: &DbConn) -> Vec<String> {
db_run! {conn: { db_run! {conn: {
ciphers_collections::table ciphers_collections::table
.inner_join(collections::table.on( .inner_join(collections::table.on(
@ -562,4 +610,32 @@ impl Cipher {
.load::<String>(conn).unwrap_or_default() .load::<String>(conn).unwrap_or_default()
}} }}
} }
/// Return a Vec with (cipher_uuid, collection_uuid)
/// This is used during a full sync so we only need one query for all collections accessible.
pub async fn get_collections_with_cipher_by_user(user_id: &str, conn: &DbConn) -> Vec<(String, String)> {
db_run! {conn: {
ciphers_collections::table
.inner_join(collections::table.on(
collections::uuid.eq(ciphers_collections::collection_uuid)
))
.inner_join(users_organizations::table.on(
users_organizations::org_uuid.eq(collections::org_uuid).and(
users_organizations::user_uuid.eq(user_id)
)
))
.left_join(users_collections::table.on(
users_collections::collection_uuid.eq(ciphers_collections::collection_uuid).and(
users_collections::user_uuid.eq(user_id)
)
))
.filter(users_collections::user_uuid.eq(user_id).or( // User has access to collection
users_organizations::access_all.eq(true).or( // User has access all
users_organizations::atype.le(UserOrgType::Admin as i32) // User is admin or owner
)
))
.select(ciphers_collections::all_columns)
.load::<(String, String)>(conn).unwrap_or_default()
}}
}
} }

143
src/db/models/collection.rs

@ -1,11 +1,10 @@
use serde_json::Value; use serde_json::Value;
use super::{Cipher, Organization, User, UserOrgStatus, UserOrgType, UserOrganization}; use super::{User, UserOrgStatus, UserOrgType, UserOrganization};
db_object! { db_object! {
#[derive(Identifiable, Queryable, Insertable, Associations, AsChangeset)] #[derive(Identifiable, Queryable, Insertable, AsChangeset)]
#[table_name = "collections"] #[table_name = "collections"]
#[belongs_to(Organization, foreign_key = "org_uuid")]
#[primary_key(uuid)] #[primary_key(uuid)]
pub struct Collection { pub struct Collection {
pub uuid: String, pub uuid: String,
@ -13,10 +12,8 @@ db_object! {
pub name: String, pub name: String,
} }
#[derive(Identifiable, Queryable, Insertable, Associations)] #[derive(Identifiable, Queryable, Insertable)]
#[table_name = "users_collections"] #[table_name = "users_collections"]
#[belongs_to(User, foreign_key = "user_uuid")]
#[belongs_to(Collection, foreign_key = "collection_uuid")]
#[primary_key(user_uuid, collection_uuid)] #[primary_key(user_uuid, collection_uuid)]
pub struct CollectionUser { pub struct CollectionUser {
pub user_uuid: String, pub user_uuid: String,
@ -25,10 +22,8 @@ db_object! {
pub hide_passwords: bool, pub hide_passwords: bool,
} }
#[derive(Identifiable, Queryable, Insertable, Associations)] #[derive(Identifiable, Queryable, Insertable)]
#[table_name = "ciphers_collections"] #[table_name = "ciphers_collections"]
#[belongs_to(Cipher, foreign_key = "cipher_uuid")]
#[belongs_to(Collection, foreign_key = "collection_uuid")]
#[primary_key(cipher_uuid, collection_uuid)] #[primary_key(cipher_uuid, collection_uuid)]
pub struct CollectionCipher { pub struct CollectionCipher {
pub cipher_uuid: String, pub cipher_uuid: String,
@ -57,11 +52,32 @@ impl Collection {
}) })
} }
pub fn to_json_details(&self, user_uuid: &str, conn: &DbConn) -> Value { pub async fn to_json_details(
&self,
user_uuid: &str,
cipher_sync_data: Option<&crate::api::core::CipherSyncData>,
conn: &DbConn,
) -> Value {
let (read_only, hide_passwords) = if let Some(cipher_sync_data) = cipher_sync_data {
match cipher_sync_data.user_organizations.get(&self.org_uuid) {
Some(uo) if uo.has_full_access() => (false, false),
Some(_) => {
if let Some(uc) = cipher_sync_data.user_collections.get(&self.uuid) {
(uc.read_only, uc.hide_passwords)
} else {
(false, false)
}
}
_ => (true, true),
}
} else {
(!self.is_writable_by_user(user_uuid, conn).await, self.hide_passwords_for_user(user_uuid, conn).await)
};
let mut json_object = self.to_json(); let mut json_object = self.to_json();
json_object["Object"] = json!("collectionDetails"); json_object["Object"] = json!("collectionDetails");
json_object["ReadOnly"] = json!(!self.is_writable_by_user(user_uuid, conn)); json_object["ReadOnly"] = json!(read_only);
json_object["HidePasswords"] = json!(self.hide_passwords_for_user(user_uuid, conn)); json_object["HidePasswords"] = json!(hide_passwords);
json_object json_object
} }
} }
@ -73,8 +89,8 @@ use crate::error::MapResult;
/// Database methods /// Database methods
impl Collection { impl Collection {
pub fn save(&self, conn: &DbConn) -> EmptyResult { pub async fn save(&self, conn: &DbConn) -> EmptyResult {
self.update_users_revision(conn); self.update_users_revision(conn).await;
db_run! { conn: db_run! { conn:
sqlite, mysql { sqlite, mysql {
@ -107,10 +123,10 @@ impl Collection {
} }
} }
pub fn delete(self, conn: &DbConn) -> EmptyResult { pub async fn delete(self, conn: &DbConn) -> EmptyResult {
self.update_users_revision(conn); self.update_users_revision(conn).await;
CollectionCipher::delete_all_by_collection(&self.uuid, conn)?; CollectionCipher::delete_all_by_collection(&self.uuid, conn).await?;
CollectionUser::delete_all_by_collection(&self.uuid, conn)?; CollectionUser::delete_all_by_collection(&self.uuid, conn).await?;
db_run! { conn: { db_run! { conn: {
diesel::delete(collections::table.filter(collections::uuid.eq(self.uuid))) diesel::delete(collections::table.filter(collections::uuid.eq(self.uuid)))
@ -119,20 +135,20 @@ impl Collection {
}} }}
} }
pub fn delete_all_by_organization(org_uuid: &str, conn: &DbConn) -> EmptyResult { pub async fn delete_all_by_organization(org_uuid: &str, conn: &DbConn) -> EmptyResult {
for collection in Self::find_by_organization(org_uuid, conn) { for collection in Self::find_by_organization(org_uuid, conn).await {
collection.delete(conn)?; collection.delete(conn).await?;
} }
Ok(()) Ok(())
} }
pub fn update_users_revision(&self, conn: &DbConn) { pub async fn update_users_revision(&self, conn: &DbConn) {
UserOrganization::find_by_collection_and_org(&self.uuid, &self.org_uuid, conn).iter().for_each(|user_org| { for user_org in UserOrganization::find_by_collection_and_org(&self.uuid, &self.org_uuid, conn).await.iter() {
User::update_uuid_revision(&user_org.user_uuid, conn); User::update_uuid_revision(&user_org.user_uuid, conn).await;
}); }
} }
pub fn find_by_uuid(uuid: &str, conn: &DbConn) -> Option<Self> { pub async fn find_by_uuid(uuid: &str, conn: &DbConn) -> Option<Self> {
db_run! { conn: { db_run! { conn: {
collections::table collections::table
.filter(collections::uuid.eq(uuid)) .filter(collections::uuid.eq(uuid))
@ -142,7 +158,7 @@ impl Collection {
}} }}
} }
pub fn find_by_user_uuid(user_uuid: &str, conn: &DbConn) -> Vec<Self> { pub async fn find_by_user_uuid(user_uuid: &str, conn: &DbConn) -> Vec<Self> {
db_run! { conn: { db_run! { conn: {
collections::table collections::table
.left_join(users_collections::table.on( .left_join(users_collections::table.on(
@ -167,11 +183,11 @@ impl Collection {
}} }}
} }
pub fn find_by_organization_and_user_uuid(org_uuid: &str, user_uuid: &str, conn: &DbConn) -> Vec<Self> { pub async fn find_by_organization_and_user_uuid(org_uuid: &str, user_uuid: &str, conn: &DbConn) -> Vec<Self> {
Self::find_by_user_uuid(user_uuid, conn).into_iter().filter(|c| c.org_uuid == org_uuid).collect() Self::find_by_user_uuid(user_uuid, conn).await.into_iter().filter(|c| c.org_uuid == org_uuid).collect()
} }
pub fn find_by_organization(org_uuid: &str, conn: &DbConn) -> Vec<Self> { pub async fn find_by_organization(org_uuid: &str, conn: &DbConn) -> Vec<Self> {
db_run! { conn: { db_run! { conn: {
collections::table collections::table
.filter(collections::org_uuid.eq(org_uuid)) .filter(collections::org_uuid.eq(org_uuid))
@ -181,7 +197,7 @@ impl Collection {
}} }}
} }
pub fn find_by_uuid_and_org(uuid: &str, org_uuid: &str, conn: &DbConn) -> Option<Self> { pub async fn find_by_uuid_and_org(uuid: &str, org_uuid: &str, conn: &DbConn) -> Option<Self> {
db_run! { conn: { db_run! { conn: {
collections::table collections::table
.filter(collections::uuid.eq(uuid)) .filter(collections::uuid.eq(uuid))
@ -193,7 +209,7 @@ impl Collection {
}} }}
} }
pub fn find_by_uuid_and_user(uuid: &str, user_uuid: &str, conn: &DbConn) -> Option<Self> { pub async fn find_by_uuid_and_user(uuid: &str, user_uuid: &str, conn: &DbConn) -> Option<Self> {
db_run! { conn: { db_run! { conn: {
collections::table collections::table
.left_join(users_collections::table.on( .left_join(users_collections::table.on(
@ -219,8 +235,8 @@ impl Collection {
}} }}
} }
pub fn is_writable_by_user(&self, user_uuid: &str, conn: &DbConn) -> bool { pub async fn is_writable_by_user(&self, user_uuid: &str, conn: &DbConn) -> bool {
match UserOrganization::find_by_user_and_org(user_uuid, &self.org_uuid, conn) { match UserOrganization::find_by_user_and_org(user_uuid, &self.org_uuid, conn).await {
None => false, // Not in Org None => false, // Not in Org
Some(user_org) => { Some(user_org) => {
if user_org.has_full_access() { if user_org.has_full_access() {
@ -241,8 +257,8 @@ impl Collection {
} }
} }
pub fn hide_passwords_for_user(&self, user_uuid: &str, conn: &DbConn) -> bool { pub async fn hide_passwords_for_user(&self, user_uuid: &str, conn: &DbConn) -> bool {
match UserOrganization::find_by_user_and_org(user_uuid, &self.org_uuid, conn) { match UserOrganization::find_by_user_and_org(user_uuid, &self.org_uuid, conn).await {
None => true, // Not in Org None => true, // Not in Org
Some(user_org) => { Some(user_org) => {
if user_org.has_full_access() { if user_org.has_full_access() {
@ -266,7 +282,7 @@ impl Collection {
/// Database methods /// Database methods
impl CollectionUser { impl CollectionUser {
pub fn find_by_organization_and_user_uuid(org_uuid: &str, user_uuid: &str, conn: &DbConn) -> Vec<Self> { pub async fn find_by_organization_and_user_uuid(org_uuid: &str, user_uuid: &str, conn: &DbConn) -> Vec<Self> {
db_run! { conn: { db_run! { conn: {
users_collections::table users_collections::table
.filter(users_collections::user_uuid.eq(user_uuid)) .filter(users_collections::user_uuid.eq(user_uuid))
@ -279,14 +295,14 @@ impl CollectionUser {
}} }}
} }
pub fn save( pub async fn save(
user_uuid: &str, user_uuid: &str,
collection_uuid: &str, collection_uuid: &str,
read_only: bool, read_only: bool,
hide_passwords: bool, hide_passwords: bool,
conn: &DbConn, conn: &DbConn,
) -> EmptyResult { ) -> EmptyResult {
User::update_uuid_revision(user_uuid, conn); User::update_uuid_revision(user_uuid, conn).await;
db_run! { conn: db_run! { conn:
sqlite, mysql { sqlite, mysql {
@ -337,8 +353,8 @@ impl CollectionUser {
} }
} }
pub fn delete(self, conn: &DbConn) -> EmptyResult { pub async fn delete(self, conn: &DbConn) -> EmptyResult {
User::update_uuid_revision(&self.user_uuid, conn); User::update_uuid_revision(&self.user_uuid, conn).await;
db_run! { conn: { db_run! { conn: {
diesel::delete( diesel::delete(
@ -351,7 +367,7 @@ impl CollectionUser {
}} }}
} }
pub fn find_by_collection(collection_uuid: &str, conn: &DbConn) -> Vec<Self> { pub async fn find_by_collection(collection_uuid: &str, conn: &DbConn) -> Vec<Self> {
db_run! { conn: { db_run! { conn: {
users_collections::table users_collections::table
.filter(users_collections::collection_uuid.eq(collection_uuid)) .filter(users_collections::collection_uuid.eq(collection_uuid))
@ -362,7 +378,7 @@ impl CollectionUser {
}} }}
} }
pub fn find_by_collection_and_user(collection_uuid: &str, user_uuid: &str, conn: &DbConn) -> Option<Self> { pub async fn find_by_collection_and_user(collection_uuid: &str, user_uuid: &str, conn: &DbConn) -> Option<Self> {
db_run! { conn: { db_run! { conn: {
users_collections::table users_collections::table
.filter(users_collections::collection_uuid.eq(collection_uuid)) .filter(users_collections::collection_uuid.eq(collection_uuid))
@ -374,10 +390,21 @@ impl CollectionUser {
}} }}
} }
pub fn delete_all_by_collection(collection_uuid: &str, conn: &DbConn) -> EmptyResult { pub async fn find_by_user(user_uuid: &str, conn: &DbConn) -> Vec<Self> {
CollectionUser::find_by_collection(collection_uuid, conn).iter().for_each(|collection| { db_run! { conn: {
User::update_uuid_revision(&collection.user_uuid, conn); users_collections::table
}); .filter(users_collections::user_uuid.eq(user_uuid))
.select(users_collections::all_columns)
.load::<CollectionUserDb>(conn)
.expect("Error loading users_collections")
.from_db()
}}
}
pub async fn delete_all_by_collection(collection_uuid: &str, conn: &DbConn) -> EmptyResult {
for collection in CollectionUser::find_by_collection(collection_uuid, conn).await.iter() {
User::update_uuid_revision(&collection.user_uuid, conn).await;
}
db_run! { conn: { db_run! { conn: {
diesel::delete(users_collections::table.filter(users_collections::collection_uuid.eq(collection_uuid))) diesel::delete(users_collections::table.filter(users_collections::collection_uuid.eq(collection_uuid)))
@ -386,8 +413,8 @@ impl CollectionUser {
}} }}
} }
pub fn delete_all_by_user_and_org(user_uuid: &str, org_uuid: &str, conn: &DbConn) -> EmptyResult { pub async fn delete_all_by_user_and_org(user_uuid: &str, org_uuid: &str, conn: &DbConn) -> EmptyResult {
let collectionusers = Self::find_by_organization_and_user_uuid(org_uuid, user_uuid, conn); let collectionusers = Self::find_by_organization_and_user_uuid(org_uuid, user_uuid, conn).await;
db_run! { conn: { db_run! { conn: {
for user in collectionusers { for user in collectionusers {
@ -405,8 +432,8 @@ impl CollectionUser {
/// Database methods /// Database methods
impl CollectionCipher { impl CollectionCipher {
pub fn save(cipher_uuid: &str, collection_uuid: &str, conn: &DbConn) -> EmptyResult { pub async fn save(cipher_uuid: &str, collection_uuid: &str, conn: &DbConn) -> EmptyResult {
Self::update_users_revision(collection_uuid, conn); Self::update_users_revision(collection_uuid, conn).await;
db_run! { conn: db_run! { conn:
sqlite, mysql { sqlite, mysql {
@ -435,8 +462,8 @@ impl CollectionCipher {
} }
} }
pub fn delete(cipher_uuid: &str, collection_uuid: &str, conn: &DbConn) -> EmptyResult { pub async fn delete(cipher_uuid: &str, collection_uuid: &str, conn: &DbConn) -> EmptyResult {
Self::update_users_revision(collection_uuid, conn); Self::update_users_revision(collection_uuid, conn).await;
db_run! { conn: { db_run! { conn: {
diesel::delete( diesel::delete(
@ -449,7 +476,7 @@ impl CollectionCipher {
}} }}
} }
pub fn delete_all_by_cipher(cipher_uuid: &str, conn: &DbConn) -> EmptyResult { pub async fn delete_all_by_cipher(cipher_uuid: &str, conn: &DbConn) -> EmptyResult {
db_run! { conn: { db_run! { conn: {
diesel::delete(ciphers_collections::table.filter(ciphers_collections::cipher_uuid.eq(cipher_uuid))) diesel::delete(ciphers_collections::table.filter(ciphers_collections::cipher_uuid.eq(cipher_uuid)))
.execute(conn) .execute(conn)
@ -457,7 +484,7 @@ impl CollectionCipher {
}} }}
} }
pub fn delete_all_by_collection(collection_uuid: &str, conn: &DbConn) -> EmptyResult { pub async fn delete_all_by_collection(collection_uuid: &str, conn: &DbConn) -> EmptyResult {
db_run! { conn: { db_run! { conn: {
diesel::delete(ciphers_collections::table.filter(ciphers_collections::collection_uuid.eq(collection_uuid))) diesel::delete(ciphers_collections::table.filter(ciphers_collections::collection_uuid.eq(collection_uuid)))
.execute(conn) .execute(conn)
@ -465,9 +492,9 @@ impl CollectionCipher {
}} }}
} }
pub fn update_users_revision(collection_uuid: &str, conn: &DbConn) { pub async fn update_users_revision(collection_uuid: &str, conn: &DbConn) {
if let Some(collection) = Collection::find_by_uuid(collection_uuid, conn) { if let Some(collection) = Collection::find_by_uuid(collection_uuid, conn).await {
collection.update_users_revision(conn); collection.update_users_revision(conn).await;
} }
} }
} }

50
src/db/models/device.rs

@ -1,14 +1,12 @@
use chrono::{NaiveDateTime, Utc}; use chrono::{NaiveDateTime, Utc};
use super::User;
use crate::CONFIG; use crate::CONFIG;
db_object! { db_object! {
#[derive(Identifiable, Queryable, Insertable, Associations, AsChangeset)] #[derive(Identifiable, Queryable, Insertable, AsChangeset)]
#[table_name = "devices"] #[table_name = "devices"]
#[changeset_options(treat_none_as_null="true")] #[changeset_options(treat_none_as_null="true")]
#[belongs_to(User, foreign_key = "user_uuid")] #[primary_key(uuid, user_uuid)]
#[primary_key(uuid)]
pub struct Device { pub struct Device {
pub uuid: String, pub uuid: String,
pub created_at: NaiveDateTime, pub created_at: NaiveDateTime,
@ -89,11 +87,11 @@ impl Device {
nbf: time_now.timestamp(), nbf: time_now.timestamp(),
exp: (time_now + *DEFAULT_VALIDITY).timestamp(), exp: (time_now + *DEFAULT_VALIDITY).timestamp(),
iss: JWT_LOGIN_ISSUER.to_string(), iss: JWT_LOGIN_ISSUER.to_string(),
sub: user.uuid.to_string(), sub: user.uuid.clone(),
premium: true, premium: true,
name: user.name.to_string(), name: user.name.clone(),
email: user.email.to_string(), email: user.email.clone(),
email_verified: !CONFIG.mail_enabled() || user.verified_at.is_some(), email_verified: !CONFIG.mail_enabled() || user.verified_at.is_some(),
orgowner, orgowner,
@ -101,8 +99,8 @@ impl Device {
orguser, orguser,
orgmanager, orgmanager,
sstamp: user.security_stamp.to_string(), sstamp: user.security_stamp.clone(),
device: self.uuid.to_string(), device: self.uuid.clone(),
scope, scope,
amr: vec!["Application".into()], amr: vec!["Application".into()],
}; };
@ -118,7 +116,7 @@ use crate::error::MapResult;
/// Database methods /// Database methods
impl Device { impl Device {
pub fn save(&mut self, conn: &DbConn) -> EmptyResult { pub async fn save(&mut self, conn: &DbConn) -> EmptyResult {
self.updated_at = Utc::now().naive_utc(); self.updated_at = Utc::now().naive_utc();
db_run! { conn: db_run! { conn:
@ -131,39 +129,33 @@ impl Device {
postgresql { postgresql {
let value = DeviceDb::to_db(self); let value = DeviceDb::to_db(self);
crate::util::retry( crate::util::retry(
|| diesel::insert_into(devices::table).values(&value).on_conflict(devices::uuid).do_update().set(&value).execute(conn), || diesel::insert_into(devices::table).values(&value).on_conflict((devices::uuid, devices::user_uuid)).do_update().set(&value).execute(conn),
10, 10,
).map_res("Error saving device") ).map_res("Error saving device")
} }
} }
} }
pub fn delete(self, conn: &DbConn) -> EmptyResult { pub async fn delete_all_by_user(user_uuid: &str, conn: &DbConn) -> EmptyResult {
db_run! { conn: { db_run! { conn: {
diesel::delete(devices::table.filter(devices::uuid.eq(self.uuid))) diesel::delete(devices::table.filter(devices::user_uuid.eq(user_uuid)))
.execute(conn) .execute(conn)
.map_res("Error removing device") .map_res("Error removing devices for user")
}} }}
} }
pub fn delete_all_by_user(user_uuid: &str, conn: &DbConn) -> EmptyResult { pub async fn find_by_uuid_and_user(uuid: &str, user_uuid: &str, conn: &DbConn) -> Option<Self> {
for device in Self::find_by_user(user_uuid, conn) {
device.delete(conn)?;
}
Ok(())
}
pub fn find_by_uuid(uuid: &str, conn: &DbConn) -> Option<Self> {
db_run! { conn: { db_run! { conn: {
devices::table devices::table
.filter(devices::uuid.eq(uuid)) .filter(devices::uuid.eq(uuid))
.filter(devices::user_uuid.eq(user_uuid))
.first::<DeviceDb>(conn) .first::<DeviceDb>(conn)
.ok() .ok()
.from_db() .from_db()
}} }}
} }
pub fn find_by_refresh_token(refresh_token: &str, conn: &DbConn) -> Option<Self> { pub async fn find_by_refresh_token(refresh_token: &str, conn: &DbConn) -> Option<Self> {
db_run! { conn: { db_run! { conn: {
devices::table devices::table
.filter(devices::refresh_token.eq(refresh_token)) .filter(devices::refresh_token.eq(refresh_token))
@ -173,17 +165,7 @@ impl Device {
}} }}
} }
pub fn find_by_user(user_uuid: &str, conn: &DbConn) -> Vec<Self> { pub async fn find_latest_active_by_user(user_uuid: &str, conn: &DbConn) -> Option<Self> {
db_run! { conn: {
devices::table
.filter(devices::user_uuid.eq(user_uuid))
.load::<DeviceDb>(conn)
.expect("Error loading devices")
.from_db()
}}
}
pub fn find_latest_active_by_user(user_uuid: &str, conn: &DbConn) -> Option<Self> {
db_run! { conn: { db_run! { conn: {
devices::table devices::table
.filter(devices::user_uuid.eq(user_uuid)) .filter(devices::user_uuid.eq(user_uuid))

45
src/db/models/emergency_access.rs

@ -4,10 +4,9 @@ use serde_json::Value;
use super::User; use super::User;
db_object! { db_object! {
#[derive(Debug, Identifiable, Queryable, Insertable, Associations, AsChangeset)] #[derive(Debug, Identifiable, Queryable, Insertable, AsChangeset)]
#[table_name = "emergency_access"] #[table_name = "emergency_access"]
#[changeset_options(treat_none_as_null="true")] #[changeset_options(treat_none_as_null="true")]
#[belongs_to(User, foreign_key = "grantor_uuid")]
#[primary_key(uuid)] #[primary_key(uuid)]
pub struct EmergencyAccess { pub struct EmergencyAccess {
pub uuid: String, pub uuid: String,
@ -73,8 +72,8 @@ impl EmergencyAccess {
}) })
} }
pub fn to_json_grantor_details(&self, conn: &DbConn) -> Value { pub async fn to_json_grantor_details(&self, conn: &DbConn) -> Value {
let grantor_user = User::find_by_uuid(&self.grantor_uuid, conn).expect("Grantor user not found."); let grantor_user = User::find_by_uuid(&self.grantor_uuid, conn).await.expect("Grantor user not found.");
json!({ json!({
"Id": self.uuid, "Id": self.uuid,
@ -89,11 +88,11 @@ impl EmergencyAccess {
} }
#[allow(clippy::manual_map)] #[allow(clippy::manual_map)]
pub fn to_json_grantee_details(&self, conn: &DbConn) -> Value { pub async fn to_json_grantee_details(&self, conn: &DbConn) -> Value {
let grantee_user = if let Some(grantee_uuid) = self.grantee_uuid.as_deref() { let grantee_user = if let Some(grantee_uuid) = self.grantee_uuid.as_deref() {
Some(User::find_by_uuid(grantee_uuid, conn).expect("Grantee user not found.")) Some(User::find_by_uuid(grantee_uuid, conn).await.expect("Grantee user not found."))
} else if let Some(email) = self.email.as_deref() { } else if let Some(email) = self.email.as_deref() {
Some(User::find_by_mail(email, conn).expect("Grantee user not found.")) Some(User::find_by_mail(email, conn).await.expect("Grantee user not found."))
} else { } else {
None None
}; };
@ -155,8 +154,8 @@ use crate::api::EmptyResult;
use crate::error::MapResult; use crate::error::MapResult;
impl EmergencyAccess { impl EmergencyAccess {
pub fn save(&mut self, conn: &DbConn) -> EmptyResult { pub async fn save(&mut self, conn: &DbConn) -> EmptyResult {
User::update_uuid_revision(&self.grantor_uuid, conn); User::update_uuid_revision(&self.grantor_uuid, conn).await;
self.updated_at = Utc::now().naive_utc(); self.updated_at = Utc::now().naive_utc();
db_run! { conn: db_run! { conn:
@ -190,18 +189,18 @@ impl EmergencyAccess {
} }
} }
pub fn delete_all_by_user(user_uuid: &str, conn: &DbConn) -> EmptyResult { pub async fn delete_all_by_user(user_uuid: &str, conn: &DbConn) -> EmptyResult {
for ea in Self::find_all_by_grantor_uuid(user_uuid, conn) { for ea in Self::find_all_by_grantor_uuid(user_uuid, conn).await {
ea.delete(conn)?; ea.delete(conn).await?;
} }
for ea in Self::find_all_by_grantee_uuid(user_uuid, conn) { for ea in Self::find_all_by_grantee_uuid(user_uuid, conn).await {
ea.delete(conn)?; ea.delete(conn).await?;
} }
Ok(()) Ok(())
} }
pub fn delete(self, conn: &DbConn) -> EmptyResult { pub async fn delete(self, conn: &DbConn) -> EmptyResult {
User::update_uuid_revision(&self.grantor_uuid, conn); User::update_uuid_revision(&self.grantor_uuid, conn).await;
db_run! { conn: { db_run! { conn: {
diesel::delete(emergency_access::table.filter(emergency_access::uuid.eq(self.uuid))) diesel::delete(emergency_access::table.filter(emergency_access::uuid.eq(self.uuid)))
@ -210,7 +209,7 @@ impl EmergencyAccess {
}} }}
} }
pub fn find_by_uuid(uuid: &str, conn: &DbConn) -> Option<Self> { pub async fn find_by_uuid(uuid: &str, conn: &DbConn) -> Option<Self> {
db_run! { conn: { db_run! { conn: {
emergency_access::table emergency_access::table
.filter(emergency_access::uuid.eq(uuid)) .filter(emergency_access::uuid.eq(uuid))
@ -219,7 +218,7 @@ impl EmergencyAccess {
}} }}
} }
pub fn find_by_grantor_uuid_and_grantee_uuid_or_email( pub async fn find_by_grantor_uuid_and_grantee_uuid_or_email(
grantor_uuid: &str, grantor_uuid: &str,
grantee_uuid: &str, grantee_uuid: &str,
email: &str, email: &str,
@ -234,7 +233,7 @@ impl EmergencyAccess {
}} }}
} }
pub fn find_all_recoveries(conn: &DbConn) -> Vec<Self> { pub async fn find_all_recoveries(conn: &DbConn) -> Vec<Self> {
db_run! { conn: { db_run! { conn: {
emergency_access::table emergency_access::table
.filter(emergency_access::status.eq(EmergencyAccessStatus::RecoveryInitiated as i32)) .filter(emergency_access::status.eq(EmergencyAccessStatus::RecoveryInitiated as i32))
@ -242,7 +241,7 @@ impl EmergencyAccess {
}} }}
} }
pub fn find_by_uuid_and_grantor_uuid(uuid: &str, grantor_uuid: &str, conn: &DbConn) -> Option<Self> { pub async fn find_by_uuid_and_grantor_uuid(uuid: &str, grantor_uuid: &str, conn: &DbConn) -> Option<Self> {
db_run! { conn: { db_run! { conn: {
emergency_access::table emergency_access::table
.filter(emergency_access::uuid.eq(uuid)) .filter(emergency_access::uuid.eq(uuid))
@ -252,7 +251,7 @@ impl EmergencyAccess {
}} }}
} }
pub fn find_all_by_grantee_uuid(grantee_uuid: &str, conn: &DbConn) -> Vec<Self> { pub async fn find_all_by_grantee_uuid(grantee_uuid: &str, conn: &DbConn) -> Vec<Self> {
db_run! { conn: { db_run! { conn: {
emergency_access::table emergency_access::table
.filter(emergency_access::grantee_uuid.eq(grantee_uuid)) .filter(emergency_access::grantee_uuid.eq(grantee_uuid))
@ -260,7 +259,7 @@ impl EmergencyAccess {
}} }}
} }
pub fn find_invited_by_grantee_email(grantee_email: &str, conn: &DbConn) -> Option<Self> { pub async fn find_invited_by_grantee_email(grantee_email: &str, conn: &DbConn) -> Option<Self> {
db_run! { conn: { db_run! { conn: {
emergency_access::table emergency_access::table
.filter(emergency_access::email.eq(grantee_email)) .filter(emergency_access::email.eq(grantee_email))
@ -270,7 +269,7 @@ impl EmergencyAccess {
}} }}
} }
pub fn find_all_by_grantor_uuid(grantor_uuid: &str, conn: &DbConn) -> Vec<Self> { pub async fn find_all_by_grantor_uuid(grantor_uuid: &str, conn: &DbConn) -> Vec<Self> {
db_run! { conn: { db_run! { conn: {
emergency_access::table emergency_access::table
.filter(emergency_access::grantor_uuid.eq(grantor_uuid)) .filter(emergency_access::grantor_uuid.eq(grantor_uuid))

32
src/db/models/favorite.rs

@ -1,10 +1,8 @@
use super::{Cipher, User}; use super::User;
db_object! { db_object! {
#[derive(Identifiable, Queryable, Insertable, Associations)] #[derive(Identifiable, Queryable, Insertable)]
#[table_name = "favorites"] #[table_name = "favorites"]
#[belongs_to(User, foreign_key = "user_uuid")]
#[belongs_to(Cipher, foreign_key = "cipher_uuid")]
#[primary_key(user_uuid, cipher_uuid)] #[primary_key(user_uuid, cipher_uuid)]
pub struct Favorite { pub struct Favorite {
pub user_uuid: String, pub user_uuid: String,
@ -19,7 +17,7 @@ use crate::error::MapResult;
impl Favorite { impl Favorite {
// Returns whether the specified cipher is a favorite of the specified user. // Returns whether the specified cipher is a favorite of the specified user.
pub fn is_favorite(cipher_uuid: &str, user_uuid: &str, conn: &DbConn) -> bool { pub async fn is_favorite(cipher_uuid: &str, user_uuid: &str, conn: &DbConn) -> bool {
db_run! { conn: { db_run! { conn: {
let query = favorites::table let query = favorites::table
.filter(favorites::cipher_uuid.eq(cipher_uuid)) .filter(favorites::cipher_uuid.eq(cipher_uuid))
@ -31,11 +29,11 @@ impl Favorite {
} }
// Sets whether the specified cipher is a favorite of the specified user. // Sets whether the specified cipher is a favorite of the specified user.
pub fn set_favorite(favorite: bool, cipher_uuid: &str, user_uuid: &str, conn: &DbConn) -> EmptyResult { pub async fn set_favorite(favorite: bool, cipher_uuid: &str, user_uuid: &str, conn: &DbConn) -> EmptyResult {
let (old, new) = (Self::is_favorite(cipher_uuid, user_uuid, conn), favorite); let (old, new) = (Self::is_favorite(cipher_uuid, user_uuid, conn).await, favorite);
match (old, new) { match (old, new) {
(false, true) => { (false, true) => {
User::update_uuid_revision(user_uuid, conn); User::update_uuid_revision(user_uuid, conn).await;
db_run! { conn: { db_run! { conn: {
diesel::insert_into(favorites::table) diesel::insert_into(favorites::table)
.values(( .values((
@ -47,7 +45,7 @@ impl Favorite {
}} }}
} }
(true, false) => { (true, false) => {
User::update_uuid_revision(user_uuid, conn); User::update_uuid_revision(user_uuid, conn).await;
db_run! { conn: { db_run! { conn: {
diesel::delete( diesel::delete(
favorites::table favorites::table
@ -64,7 +62,7 @@ impl Favorite {
} }
// Delete all favorite entries associated with the specified cipher. // Delete all favorite entries associated with the specified cipher.
pub fn delete_all_by_cipher(cipher_uuid: &str, conn: &DbConn) -> EmptyResult { pub async fn delete_all_by_cipher(cipher_uuid: &str, conn: &DbConn) -> EmptyResult {
db_run! { conn: { db_run! { conn: {
diesel::delete(favorites::table.filter(favorites::cipher_uuid.eq(cipher_uuid))) diesel::delete(favorites::table.filter(favorites::cipher_uuid.eq(cipher_uuid)))
.execute(conn) .execute(conn)
@ -73,11 +71,23 @@ impl Favorite {
} }
// Delete all favorite entries associated with the specified user. // Delete all favorite entries associated with the specified user.
pub fn delete_all_by_user(user_uuid: &str, conn: &DbConn) -> EmptyResult { pub async fn delete_all_by_user(user_uuid: &str, conn: &DbConn) -> EmptyResult {
db_run! { conn: { db_run! { conn: {
diesel::delete(favorites::table.filter(favorites::user_uuid.eq(user_uuid))) diesel::delete(favorites::table.filter(favorites::user_uuid.eq(user_uuid)))
.execute(conn) .execute(conn)
.map_res("Error removing favorites by user") .map_res("Error removing favorites by user")
}} }}
} }
/// Return a vec with (cipher_uuid) this will only contain favorite flagged ciphers
/// This is used during a full sync so we only need one query for all favorite cipher matches.
pub async fn get_all_cipher_uuid_by_user(user_uuid: &str, conn: &DbConn) -> Vec<String> {
db_run! { conn: {
favorites::table
.filter(favorites::user_uuid.eq(user_uuid))
.select(favorites::cipher_uuid)
.load::<String>(conn)
.unwrap_or_default()
}}
}
} }

54
src/db/models/folder.rs

@ -1,12 +1,11 @@
use chrono::{NaiveDateTime, Utc}; use chrono::{NaiveDateTime, Utc};
use serde_json::Value; use serde_json::Value;
use super::{Cipher, User}; use super::User;
db_object! { db_object! {
#[derive(Identifiable, Queryable, Insertable, Associations, AsChangeset)] #[derive(Identifiable, Queryable, Insertable, AsChangeset)]
#[table_name = "folders"] #[table_name = "folders"]
#[belongs_to(User, foreign_key = "user_uuid")]
#[primary_key(uuid)] #[primary_key(uuid)]
pub struct Folder { pub struct Folder {
pub uuid: String, pub uuid: String,
@ -16,10 +15,8 @@ db_object! {
pub name: String, pub name: String,
} }
#[derive(Identifiable, Queryable, Insertable, Associations)] #[derive(Identifiable, Queryable, Insertable)]
#[table_name = "folders_ciphers"] #[table_name = "folders_ciphers"]
#[belongs_to(Cipher, foreign_key = "cipher_uuid")]
#[belongs_to(Folder, foreign_key = "folder_uuid")]
#[primary_key(cipher_uuid, folder_uuid)] #[primary_key(cipher_uuid, folder_uuid)]
pub struct FolderCipher { pub struct FolderCipher {
pub cipher_uuid: String, pub cipher_uuid: String,
@ -70,8 +67,8 @@ use crate::error::MapResult;
/// Database methods /// Database methods
impl Folder { impl Folder {
pub fn save(&mut self, conn: &DbConn) -> EmptyResult { pub async fn save(&mut self, conn: &DbConn) -> EmptyResult {
User::update_uuid_revision(&self.user_uuid, conn); User::update_uuid_revision(&self.user_uuid, conn).await;
self.updated_at = Utc::now().naive_utc(); self.updated_at = Utc::now().naive_utc();
db_run! { conn: db_run! { conn:
@ -105,9 +102,9 @@ impl Folder {
} }
} }
pub fn delete(&self, conn: &DbConn) -> EmptyResult { pub async fn delete(&self, conn: &DbConn) -> EmptyResult {
User::update_uuid_revision(&self.user_uuid, conn); User::update_uuid_revision(&self.user_uuid, conn).await;
FolderCipher::delete_all_by_folder(&self.uuid, conn)?; FolderCipher::delete_all_by_folder(&self.uuid, conn).await?;
db_run! { conn: { db_run! { conn: {
diesel::delete(folders::table.filter(folders::uuid.eq(&self.uuid))) diesel::delete(folders::table.filter(folders::uuid.eq(&self.uuid)))
@ -116,14 +113,14 @@ impl Folder {
}} }}
} }
pub fn delete_all_by_user(user_uuid: &str, conn: &DbConn) -> EmptyResult { pub async fn delete_all_by_user(user_uuid: &str, conn: &DbConn) -> EmptyResult {
for folder in Self::find_by_user(user_uuid, conn) { for folder in Self::find_by_user(user_uuid, conn).await {
folder.delete(conn)?; folder.delete(conn).await?;
} }
Ok(()) Ok(())
} }
pub fn find_by_uuid(uuid: &str, conn: &DbConn) -> Option<Self> { pub async fn find_by_uuid(uuid: &str, conn: &DbConn) -> Option<Self> {
db_run! { conn: { db_run! { conn: {
folders::table folders::table
.filter(folders::uuid.eq(uuid)) .filter(folders::uuid.eq(uuid))
@ -133,7 +130,7 @@ impl Folder {
}} }}
} }
pub fn find_by_user(user_uuid: &str, conn: &DbConn) -> Vec<Self> { pub async fn find_by_user(user_uuid: &str, conn: &DbConn) -> Vec<Self> {
db_run! { conn: { db_run! { conn: {
folders::table folders::table
.filter(folders::user_uuid.eq(user_uuid)) .filter(folders::user_uuid.eq(user_uuid))
@ -145,7 +142,7 @@ impl Folder {
} }
impl FolderCipher { impl FolderCipher {
pub fn save(&self, conn: &DbConn) -> EmptyResult { pub async fn save(&self, conn: &DbConn) -> EmptyResult {
db_run! { conn: db_run! { conn:
sqlite, mysql { sqlite, mysql {
// Not checking for ForeignKey Constraints here. // Not checking for ForeignKey Constraints here.
@ -167,7 +164,7 @@ impl FolderCipher {
} }
} }
pub fn delete(self, conn: &DbConn) -> EmptyResult { pub async fn delete(self, conn: &DbConn) -> EmptyResult {
db_run! { conn: { db_run! { conn: {
diesel::delete( diesel::delete(
folders_ciphers::table folders_ciphers::table
@ -179,7 +176,7 @@ impl FolderCipher {
}} }}
} }
pub fn delete_all_by_cipher(cipher_uuid: &str, conn: &DbConn) -> EmptyResult { pub async fn delete_all_by_cipher(cipher_uuid: &str, conn: &DbConn) -> EmptyResult {
db_run! { conn: { db_run! { conn: {
diesel::delete(folders_ciphers::table.filter(folders_ciphers::cipher_uuid.eq(cipher_uuid))) diesel::delete(folders_ciphers::table.filter(folders_ciphers::cipher_uuid.eq(cipher_uuid)))
.execute(conn) .execute(conn)
@ -187,7 +184,7 @@ impl FolderCipher {
}} }}
} }
pub fn delete_all_by_folder(folder_uuid: &str, conn: &DbConn) -> EmptyResult { pub async fn delete_all_by_folder(folder_uuid: &str, conn: &DbConn) -> EmptyResult {
db_run! { conn: { db_run! { conn: {
diesel::delete(folders_ciphers::table.filter(folders_ciphers::folder_uuid.eq(folder_uuid))) diesel::delete(folders_ciphers::table.filter(folders_ciphers::folder_uuid.eq(folder_uuid)))
.execute(conn) .execute(conn)
@ -195,7 +192,7 @@ impl FolderCipher {
}} }}
} }
pub fn find_by_folder_and_cipher(folder_uuid: &str, cipher_uuid: &str, conn: &DbConn) -> Option<Self> { pub async fn find_by_folder_and_cipher(folder_uuid: &str, cipher_uuid: &str, conn: &DbConn) -> Option<Self> {
db_run! { conn: { db_run! { conn: {
folders_ciphers::table folders_ciphers::table
.filter(folders_ciphers::folder_uuid.eq(folder_uuid)) .filter(folders_ciphers::folder_uuid.eq(folder_uuid))
@ -206,7 +203,7 @@ impl FolderCipher {
}} }}
} }
pub fn find_by_folder(folder_uuid: &str, conn: &DbConn) -> Vec<Self> { pub async fn find_by_folder(folder_uuid: &str, conn: &DbConn) -> Vec<Self> {
db_run! { conn: { db_run! { conn: {
folders_ciphers::table folders_ciphers::table
.filter(folders_ciphers::folder_uuid.eq(folder_uuid)) .filter(folders_ciphers::folder_uuid.eq(folder_uuid))
@ -215,4 +212,17 @@ impl FolderCipher {
.from_db() .from_db()
}} }}
} }
/// Return a vec with (cipher_uuid, folder_uuid)
/// This is used during a full sync so we only need one query for all folder matches.
pub async fn find_by_user(user_uuid: &str, conn: &DbConn) -> Vec<(String, String)> {
db_run! { conn: {
folders_ciphers::table
.inner_join(folders::table)
.filter(folders::user_uuid.eq(user_uuid))
.select(folders_ciphers::all_columns)
.load::<(String, String)>(conn)
.unwrap_or_default()
}}
}
} }

41
src/db/models/org_policy.rs

@ -6,12 +6,11 @@ use crate::db::DbConn;
use crate::error::MapResult; use crate::error::MapResult;
use crate::util::UpCase; use crate::util::UpCase;
use super::{Organization, UserOrgStatus, UserOrgType, UserOrganization}; use super::{UserOrgStatus, UserOrgType, UserOrganization};
db_object! { db_object! {
#[derive(Identifiable, Queryable, Insertable, Associations, AsChangeset)] #[derive(Identifiable, Queryable, Insertable, AsChangeset)]
#[table_name = "org_policies"] #[table_name = "org_policies"]
#[belongs_to(Organization, foreign_key = "org_uuid")]
#[primary_key(uuid)] #[primary_key(uuid)]
pub struct OrgPolicy { pub struct OrgPolicy {
pub uuid: String, pub uuid: String,
@ -22,7 +21,7 @@ db_object! {
} }
} }
#[derive(Copy, Clone, PartialEq, num_derive::FromPrimitive)] #[derive(Copy, Clone, Eq, PartialEq, num_derive::FromPrimitive)]
pub enum OrgPolicyType { pub enum OrgPolicyType {
TwoFactorAuthentication = 0, TwoFactorAuthentication = 0,
MasterPassword = 1, MasterPassword = 1,
@ -72,7 +71,7 @@ impl OrgPolicy {
/// Database methods /// Database methods
impl OrgPolicy { impl OrgPolicy {
pub fn save(&self, conn: &DbConn) -> EmptyResult { pub async fn save(&self, conn: &DbConn) -> EmptyResult {
db_run! { conn: db_run! { conn:
sqlite, mysql { sqlite, mysql {
match diesel::replace_into(org_policies::table) match diesel::replace_into(org_policies::table)
@ -115,7 +114,7 @@ impl OrgPolicy {
} }
} }
pub fn delete(self, conn: &DbConn) -> EmptyResult { pub async fn delete(self, conn: &DbConn) -> EmptyResult {
db_run! { conn: { db_run! { conn: {
diesel::delete(org_policies::table.filter(org_policies::uuid.eq(self.uuid))) diesel::delete(org_policies::table.filter(org_policies::uuid.eq(self.uuid)))
.execute(conn) .execute(conn)
@ -123,7 +122,7 @@ impl OrgPolicy {
}} }}
} }
pub fn find_by_uuid(uuid: &str, conn: &DbConn) -> Option<Self> { pub async fn find_by_uuid(uuid: &str, conn: &DbConn) -> Option<Self> {
db_run! { conn: { db_run! { conn: {
org_policies::table org_policies::table
.filter(org_policies::uuid.eq(uuid)) .filter(org_policies::uuid.eq(uuid))
@ -133,7 +132,7 @@ impl OrgPolicy {
}} }}
} }
pub fn find_by_org(org_uuid: &str, conn: &DbConn) -> Vec<Self> { pub async fn find_by_org(org_uuid: &str, conn: &DbConn) -> Vec<Self> {
db_run! { conn: { db_run! { conn: {
org_policies::table org_policies::table
.filter(org_policies::org_uuid.eq(org_uuid)) .filter(org_policies::org_uuid.eq(org_uuid))
@ -143,7 +142,7 @@ impl OrgPolicy {
}} }}
} }
pub fn find_confirmed_by_user(user_uuid: &str, conn: &DbConn) -> Vec<Self> { pub async fn find_confirmed_by_user(user_uuid: &str, conn: &DbConn) -> Vec<Self> {
db_run! { conn: { db_run! { conn: {
org_policies::table org_policies::table
.inner_join( .inner_join(
@ -161,7 +160,7 @@ impl OrgPolicy {
}} }}
} }
pub fn find_by_org_and_type(org_uuid: &str, atype: i32, conn: &DbConn) -> Option<Self> { pub async fn find_by_org_and_type(org_uuid: &str, atype: i32, conn: &DbConn) -> Option<Self> {
db_run! { conn: { db_run! { conn: {
org_policies::table org_policies::table
.filter(org_policies::org_uuid.eq(org_uuid)) .filter(org_policies::org_uuid.eq(org_uuid))
@ -172,7 +171,7 @@ impl OrgPolicy {
}} }}
} }
pub fn delete_all_by_organization(org_uuid: &str, conn: &DbConn) -> EmptyResult { pub async fn delete_all_by_organization(org_uuid: &str, conn: &DbConn) -> EmptyResult {
db_run! { conn: { db_run! { conn: {
diesel::delete(org_policies::table.filter(org_policies::org_uuid.eq(org_uuid))) diesel::delete(org_policies::table.filter(org_policies::org_uuid.eq(org_uuid)))
.execute(conn) .execute(conn)
@ -183,12 +182,12 @@ impl OrgPolicy {
/// Returns true if the user belongs to an org that has enabled the specified policy type, /// Returns true if the user belongs to an org that has enabled the specified policy type,
/// and the user is not an owner or admin of that org. This is only useful for checking /// and the user is not an owner or admin of that org. This is only useful for checking
/// applicability of policy types that have these particular semantics. /// applicability of policy types that have these particular semantics.
pub fn is_applicable_to_user(user_uuid: &str, policy_type: OrgPolicyType, conn: &DbConn) -> bool { pub async fn is_applicable_to_user(user_uuid: &str, policy_type: OrgPolicyType, conn: &DbConn) -> bool {
// TODO: Should check confirmed and accepted users // TODO: Should check confirmed and accepted users
for policy in OrgPolicy::find_confirmed_by_user(user_uuid, conn) { for policy in OrgPolicy::find_confirmed_by_user(user_uuid, conn).await {
if policy.enabled && policy.has_type(policy_type) { if policy.enabled && policy.has_type(policy_type) {
let org_uuid = &policy.org_uuid; let org_uuid = &policy.org_uuid;
if let Some(user) = UserOrganization::find_by_user_and_org(user_uuid, org_uuid, conn) { if let Some(user) = UserOrganization::find_by_user_and_org(user_uuid, org_uuid, conn).await {
if user.atype < UserOrgType::Admin { if user.atype < UserOrgType::Admin {
return true; return true;
} }
@ -200,11 +199,11 @@ impl OrgPolicy {
/// Returns true if the user belongs to an org that has enabled the `DisableHideEmail` /// Returns true if the user belongs to an org that has enabled the `DisableHideEmail`
/// option of the `Send Options` policy, and the user is not an owner or admin of that org. /// option of the `Send Options` policy, and the user is not an owner or admin of that org.
pub fn is_hide_email_disabled(user_uuid: &str, conn: &DbConn) -> bool { pub async fn is_hide_email_disabled(user_uuid: &str, conn: &DbConn) -> bool {
for policy in OrgPolicy::find_confirmed_by_user(user_uuid, conn) { for policy in OrgPolicy::find_confirmed_by_user(user_uuid, conn).await {
if policy.enabled && policy.has_type(OrgPolicyType::SendOptions) { if policy.enabled && policy.has_type(OrgPolicyType::SendOptions) {
let org_uuid = &policy.org_uuid; let org_uuid = &policy.org_uuid;
if let Some(user) = UserOrganization::find_by_user_and_org(user_uuid, org_uuid, conn) { if let Some(user) = UserOrganization::find_by_user_and_org(user_uuid, org_uuid, conn).await {
if user.atype < UserOrgType::Admin { if user.atype < UserOrgType::Admin {
match serde_json::from_str::<UpCase<SendOptionsPolicyData>>(&policy.data) { match serde_json::from_str::<UpCase<SendOptionsPolicyData>>(&policy.data) {
Ok(opts) => { Ok(opts) => {
@ -220,12 +219,4 @@ impl OrgPolicy {
} }
false false
} }
/*pub fn delete_all_by_user(user_uuid: &str, conn: &DbConn) -> EmptyResult {
db_run! { conn: {
diesel::delete(twofactor::table.filter(twofactor::user_uuid.eq(user_uuid)))
.execute(conn)
.map_res("Error deleting twofactors")
}}
}*/
} }

96
src/db/models/organization.rs

@ -193,10 +193,10 @@ use crate::error::MapResult;
/// Database methods /// Database methods
impl Organization { impl Organization {
pub fn save(&self, conn: &DbConn) -> EmptyResult { pub async fn save(&self, conn: &DbConn) -> EmptyResult {
UserOrganization::find_by_org(&self.uuid, conn).iter().for_each(|user_org| { for user_org in UserOrganization::find_by_org(&self.uuid, conn).await.iter() {
User::update_uuid_revision(&user_org.user_uuid, conn); User::update_uuid_revision(&user_org.user_uuid, conn).await;
}); }
db_run! { conn: db_run! { conn:
sqlite, mysql { sqlite, mysql {
@ -230,13 +230,13 @@ impl Organization {
} }
} }
pub fn delete(self, conn: &DbConn) -> EmptyResult { pub async fn delete(self, conn: &DbConn) -> EmptyResult {
use super::{Cipher, Collection}; use super::{Cipher, Collection};
Cipher::delete_all_by_organization(&self.uuid, conn)?; Cipher::delete_all_by_organization(&self.uuid, conn).await?;
Collection::delete_all_by_organization(&self.uuid, conn)?; Collection::delete_all_by_organization(&self.uuid, conn).await?;
UserOrganization::delete_all_by_organization(&self.uuid, conn)?; UserOrganization::delete_all_by_organization(&self.uuid, conn).await?;
OrgPolicy::delete_all_by_organization(&self.uuid, conn)?; OrgPolicy::delete_all_by_organization(&self.uuid, conn).await?;
db_run! { conn: { db_run! { conn: {
diesel::delete(organizations::table.filter(organizations::uuid.eq(self.uuid))) diesel::delete(organizations::table.filter(organizations::uuid.eq(self.uuid)))
@ -245,7 +245,7 @@ impl Organization {
}} }}
} }
pub fn find_by_uuid(uuid: &str, conn: &DbConn) -> Option<Self> { pub async fn find_by_uuid(uuid: &str, conn: &DbConn) -> Option<Self> {
db_run! { conn: { db_run! { conn: {
organizations::table organizations::table
.filter(organizations::uuid.eq(uuid)) .filter(organizations::uuid.eq(uuid))
@ -254,7 +254,7 @@ impl Organization {
}} }}
} }
pub fn get_all(conn: &DbConn) -> Vec<Self> { pub async fn get_all(conn: &DbConn) -> Vec<Self> {
db_run! { conn: { db_run! { conn: {
organizations::table.load::<OrganizationDb>(conn).expect("Error loading organizations").from_db() organizations::table.load::<OrganizationDb>(conn).expect("Error loading organizations").from_db()
}} }}
@ -262,8 +262,8 @@ impl Organization {
} }
impl UserOrganization { impl UserOrganization {
pub fn to_json(&self, conn: &DbConn) -> Value { pub async fn to_json(&self, conn: &DbConn) -> Value {
let org = Organization::find_by_uuid(&self.org_uuid, conn).unwrap(); let org = Organization::find_by_uuid(&self.org_uuid, conn).await.unwrap();
json!({ json!({
"Id": self.org_uuid, "Id": self.org_uuid,
@ -322,8 +322,8 @@ impl UserOrganization {
}) })
} }
pub fn to_json_user_details(&self, conn: &DbConn) -> Value { pub async fn to_json_user_details(&self, conn: &DbConn) -> Value {
let user = User::find_by_uuid(&self.user_uuid, conn).unwrap(); let user = User::find_by_uuid(&self.user_uuid, conn).await.unwrap();
json!({ json!({
"Id": self.uuid, "Id": self.uuid,
@ -347,11 +347,12 @@ impl UserOrganization {
}) })
} }
pub fn to_json_details(&self, conn: &DbConn) -> Value { pub async fn to_json_details(&self, conn: &DbConn) -> Value {
let coll_uuids = if self.access_all { let coll_uuids = if self.access_all {
vec![] // If we have complete access, no need to fill the array vec![] // If we have complete access, no need to fill the array
} else { } else {
let collections = CollectionUser::find_by_organization_and_user_uuid(&self.org_uuid, &self.user_uuid, conn); let collections =
CollectionUser::find_by_organization_and_user_uuid(&self.org_uuid, &self.user_uuid, conn).await;
collections collections
.iter() .iter()
.map(|c| { .map(|c| {
@ -376,8 +377,8 @@ impl UserOrganization {
"Object": "organizationUserDetails", "Object": "organizationUserDetails",
}) })
} }
pub fn save(&self, conn: &DbConn) -> EmptyResult { pub async fn save(&self, conn: &DbConn) -> EmptyResult {
User::update_uuid_revision(&self.user_uuid, conn); User::update_uuid_revision(&self.user_uuid, conn).await;
db_run! { conn: db_run! { conn:
sqlite, mysql { sqlite, mysql {
@ -410,10 +411,10 @@ impl UserOrganization {
} }
} }
pub fn delete(self, conn: &DbConn) -> EmptyResult { pub async fn delete(self, conn: &DbConn) -> EmptyResult {
User::update_uuid_revision(&self.user_uuid, conn); User::update_uuid_revision(&self.user_uuid, conn).await;
CollectionUser::delete_all_by_user_and_org(&self.user_uuid, &self.org_uuid, conn)?; CollectionUser::delete_all_by_user_and_org(&self.user_uuid, &self.org_uuid, conn).await?;
db_run! { conn: { db_run! { conn: {
diesel::delete(users_organizations::table.filter(users_organizations::uuid.eq(self.uuid))) diesel::delete(users_organizations::table.filter(users_organizations::uuid.eq(self.uuid)))
@ -422,23 +423,23 @@ impl UserOrganization {
}} }}
} }
pub fn delete_all_by_organization(org_uuid: &str, conn: &DbConn) -> EmptyResult { pub async fn delete_all_by_organization(org_uuid: &str, conn: &DbConn) -> EmptyResult {
for user_org in Self::find_by_org(org_uuid, conn) { for user_org in Self::find_by_org(org_uuid, conn).await {
user_org.delete(conn)?; user_org.delete(conn).await?;
} }
Ok(()) Ok(())
} }
pub fn delete_all_by_user(user_uuid: &str, conn: &DbConn) -> EmptyResult { pub async fn delete_all_by_user(user_uuid: &str, conn: &DbConn) -> EmptyResult {
for user_org in Self::find_any_state_by_user(user_uuid, conn) { for user_org in Self::find_any_state_by_user(user_uuid, conn).await {
user_org.delete(conn)?; user_org.delete(conn).await?;
} }
Ok(()) Ok(())
} }
pub fn find_by_email_and_org(email: &str, org_id: &str, conn: &DbConn) -> Option<UserOrganization> { pub async fn find_by_email_and_org(email: &str, org_id: &str, conn: &DbConn) -> Option<UserOrganization> {
if let Some(user) = super::User::find_by_mail(email, conn) { if let Some(user) = super::User::find_by_mail(email, conn).await {
if let Some(user_org) = UserOrganization::find_by_user_and_org(&user.uuid, org_id, conn) { if let Some(user_org) = UserOrganization::find_by_user_and_org(&user.uuid, org_id, conn).await {
return Some(user_org); return Some(user_org);
} }
} }
@ -458,7 +459,7 @@ impl UserOrganization {
(self.access_all || self.atype >= UserOrgType::Admin) && self.has_status(UserOrgStatus::Confirmed) (self.access_all || self.atype >= UserOrgType::Admin) && self.has_status(UserOrgStatus::Confirmed)
} }
pub fn find_by_uuid(uuid: &str, conn: &DbConn) -> Option<Self> { pub async fn find_by_uuid(uuid: &str, conn: &DbConn) -> Option<Self> {
db_run! { conn: { db_run! { conn: {
users_organizations::table users_organizations::table
.filter(users_organizations::uuid.eq(uuid)) .filter(users_organizations::uuid.eq(uuid))
@ -467,7 +468,7 @@ impl UserOrganization {
}} }}
} }
pub fn find_by_uuid_and_org(uuid: &str, org_uuid: &str, conn: &DbConn) -> Option<Self> { pub async fn find_by_uuid_and_org(uuid: &str, org_uuid: &str, conn: &DbConn) -> Option<Self> {
db_run! { conn: { db_run! { conn: {
users_organizations::table users_organizations::table
.filter(users_organizations::uuid.eq(uuid)) .filter(users_organizations::uuid.eq(uuid))
@ -477,7 +478,7 @@ impl UserOrganization {
}} }}
} }
pub fn find_confirmed_by_user(user_uuid: &str, conn: &DbConn) -> Vec<Self> { pub async fn find_confirmed_by_user(user_uuid: &str, conn: &DbConn) -> Vec<Self> {
db_run! { conn: { db_run! { conn: {
users_organizations::table users_organizations::table
.filter(users_organizations::user_uuid.eq(user_uuid)) .filter(users_organizations::user_uuid.eq(user_uuid))
@ -487,7 +488,7 @@ impl UserOrganization {
}} }}
} }
pub fn find_invited_by_user(user_uuid: &str, conn: &DbConn) -> Vec<Self> { pub async fn find_invited_by_user(user_uuid: &str, conn: &DbConn) -> Vec<Self> {
db_run! { conn: { db_run! { conn: {
users_organizations::table users_organizations::table
.filter(users_organizations::user_uuid.eq(user_uuid)) .filter(users_organizations::user_uuid.eq(user_uuid))
@ -497,7 +498,7 @@ impl UserOrganization {
}} }}
} }
pub fn find_any_state_by_user(user_uuid: &str, conn: &DbConn) -> Vec<Self> { pub async fn find_any_state_by_user(user_uuid: &str, conn: &DbConn) -> Vec<Self> {
db_run! { conn: { db_run! { conn: {
users_organizations::table users_organizations::table
.filter(users_organizations::user_uuid.eq(user_uuid)) .filter(users_organizations::user_uuid.eq(user_uuid))
@ -506,7 +507,7 @@ impl UserOrganization {
}} }}
} }
pub fn find_by_org(org_uuid: &str, conn: &DbConn) -> Vec<Self> { pub async fn find_by_org(org_uuid: &str, conn: &DbConn) -> Vec<Self> {
db_run! { conn: { db_run! { conn: {
users_organizations::table users_organizations::table
.filter(users_organizations::org_uuid.eq(org_uuid)) .filter(users_organizations::org_uuid.eq(org_uuid))
@ -515,7 +516,7 @@ impl UserOrganization {
}} }}
} }
pub fn count_by_org(org_uuid: &str, conn: &DbConn) -> i64 { pub async fn count_by_org(org_uuid: &str, conn: &DbConn) -> i64 {
db_run! { conn: { db_run! { conn: {
users_organizations::table users_organizations::table
.filter(users_organizations::org_uuid.eq(org_uuid)) .filter(users_organizations::org_uuid.eq(org_uuid))
@ -526,7 +527,7 @@ impl UserOrganization {
}} }}
} }
pub fn find_by_org_and_type(org_uuid: &str, atype: i32, conn: &DbConn) -> Vec<Self> { pub async fn find_by_org_and_type(org_uuid: &str, atype: i32, conn: &DbConn) -> Vec<Self> {
db_run! { conn: { db_run! { conn: {
users_organizations::table users_organizations::table
.filter(users_organizations::org_uuid.eq(org_uuid)) .filter(users_organizations::org_uuid.eq(org_uuid))
@ -536,7 +537,7 @@ impl UserOrganization {
}} }}
} }
pub fn find_by_user_and_org(user_uuid: &str, org_uuid: &str, conn: &DbConn) -> Option<Self> { pub async fn find_by_user_and_org(user_uuid: &str, org_uuid: &str, conn: &DbConn) -> Option<Self> {
db_run! { conn: { db_run! { conn: {
users_organizations::table users_organizations::table
.filter(users_organizations::user_uuid.eq(user_uuid)) .filter(users_organizations::user_uuid.eq(user_uuid))
@ -546,7 +547,16 @@ impl UserOrganization {
}} }}
} }
pub fn find_by_user_and_policy(user_uuid: &str, policy_type: OrgPolicyType, conn: &DbConn) -> Vec<Self> { pub async fn find_by_user(user_uuid: &str, conn: &DbConn) -> Vec<Self> {
db_run! { conn: {
users_organizations::table
.filter(users_organizations::user_uuid.eq(user_uuid))
.load::<UserOrganizationDb>(conn)
.expect("Error loading user organizations").from_db()
}}
}
pub async fn find_by_user_and_policy(user_uuid: &str, policy_type: OrgPolicyType, conn: &DbConn) -> Vec<Self> {
db_run! { conn: { db_run! { conn: {
users_organizations::table users_organizations::table
.inner_join( .inner_join(
@ -565,7 +575,7 @@ impl UserOrganization {
}} }}
} }
pub fn find_by_cipher_and_org(cipher_uuid: &str, org_uuid: &str, conn: &DbConn) -> Vec<Self> { pub async fn find_by_cipher_and_org(cipher_uuid: &str, org_uuid: &str, conn: &DbConn) -> Vec<Self> {
db_run! { conn: { db_run! { conn: {
users_organizations::table users_organizations::table
.filter(users_organizations::org_uuid.eq(org_uuid)) .filter(users_organizations::org_uuid.eq(org_uuid))
@ -587,7 +597,7 @@ impl UserOrganization {
}} }}
} }
pub fn find_by_collection_and_org(collection_uuid: &str, org_uuid: &str, conn: &DbConn) -> Vec<Self> { pub async fn find_by_collection_and_org(collection_uuid: &str, org_uuid: &str, conn: &DbConn) -> Vec<Self> {
db_run! { conn: { db_run! { conn: {
users_organizations::table users_organizations::table
.filter(users_organizations::org_uuid.eq(org_uuid)) .filter(users_organizations::org_uuid.eq(org_uuid))

50
src/db/models/send.rs

@ -1,14 +1,12 @@
use chrono::{NaiveDateTime, Utc}; use chrono::{NaiveDateTime, Utc};
use serde_json::Value; use serde_json::Value;
use super::{Organization, User}; use super::User;
db_object! { db_object! {
#[derive(Identifiable, Queryable, Insertable, Associations, AsChangeset)] #[derive(Identifiable, Queryable, Insertable, AsChangeset)]
#[table_name = "sends"] #[table_name = "sends"]
#[changeset_options(treat_none_as_null="true")] #[changeset_options(treat_none_as_null="true")]
#[belongs_to(User, foreign_key = "user_uuid")]
#[belongs_to(Organization, foreign_key = "organization_uuid")]
#[primary_key(uuid)] #[primary_key(uuid)]
pub struct Send { pub struct Send {
pub uuid: String, pub uuid: String,
@ -103,7 +101,7 @@ impl Send {
} }
} }
pub fn creator_identifier(&self, conn: &DbConn) -> Option<String> { pub async fn creator_identifier(&self, conn: &DbConn) -> Option<String> {
if let Some(hide_email) = self.hide_email { if let Some(hide_email) = self.hide_email {
if hide_email { if hide_email {
return None; return None;
@ -111,7 +109,7 @@ impl Send {
} }
if let Some(user_uuid) = &self.user_uuid { if let Some(user_uuid) = &self.user_uuid {
if let Some(user) = User::find_by_uuid(user_uuid, conn) { if let Some(user) = User::find_by_uuid(user_uuid, conn).await {
return Some(user.email); return Some(user.email);
} }
} }
@ -150,7 +148,7 @@ impl Send {
}) })
} }
pub fn to_json_access(&self, conn: &DbConn) -> Value { pub async fn to_json_access(&self, conn: &DbConn) -> Value {
use crate::util::format_date; use crate::util::format_date;
let data: Value = serde_json::from_str(&self.data).unwrap_or_default(); let data: Value = serde_json::from_str(&self.data).unwrap_or_default();
@ -164,7 +162,7 @@ impl Send {
"File": if self.atype == SendType::File as i32 { Some(&data) } else { None }, "File": if self.atype == SendType::File as i32 { Some(&data) } else { None },
"ExpirationDate": self.expiration_date.as_ref().map(format_date), "ExpirationDate": self.expiration_date.as_ref().map(format_date),
"CreatorIdentifier": self.creator_identifier(conn), "CreatorIdentifier": self.creator_identifier(conn).await,
"Object": "send-access", "Object": "send-access",
}) })
} }
@ -176,8 +174,8 @@ use crate::api::EmptyResult;
use crate::error::MapResult; use crate::error::MapResult;
impl Send { impl Send {
pub fn save(&mut self, conn: &DbConn) -> EmptyResult { pub async fn save(&mut self, conn: &DbConn) -> EmptyResult {
self.update_users_revision(conn); self.update_users_revision(conn).await;
self.revision_date = Utc::now().naive_utc(); self.revision_date = Utc::now().naive_utc();
db_run! { conn: db_run! { conn:
@ -211,8 +209,8 @@ impl Send {
} }
} }
pub fn delete(&self, conn: &DbConn) -> EmptyResult { pub async fn delete(&self, conn: &DbConn) -> EmptyResult {
self.update_users_revision(conn); self.update_users_revision(conn).await;
if self.atype == SendType::File as i32 { if self.atype == SendType::File as i32 {
std::fs::remove_dir_all(std::path::Path::new(&crate::CONFIG.sends_folder()).join(&self.uuid)).ok(); std::fs::remove_dir_all(std::path::Path::new(&crate::CONFIG.sends_folder()).join(&self.uuid)).ok();
@ -226,17 +224,17 @@ impl Send {
} }
/// Purge all sends that are past their deletion date. /// Purge all sends that are past their deletion date.
pub fn purge(conn: &DbConn) { pub async fn purge(conn: &DbConn) {
for send in Self::find_by_past_deletion_date(conn) { for send in Self::find_by_past_deletion_date(conn).await {
send.delete(conn).ok(); send.delete(conn).await.ok();
} }
} }
pub fn update_users_revision(&self, conn: &DbConn) -> Vec<String> { pub async fn update_users_revision(&self, conn: &DbConn) -> Vec<String> {
let mut user_uuids = Vec::new(); let mut user_uuids = Vec::new();
match &self.user_uuid { match &self.user_uuid {
Some(user_uuid) => { Some(user_uuid) => {
User::update_uuid_revision(user_uuid, conn); User::update_uuid_revision(user_uuid, conn).await;
user_uuids.push(user_uuid.clone()) user_uuids.push(user_uuid.clone())
} }
None => { None => {
@ -246,14 +244,14 @@ impl Send {
user_uuids user_uuids
} }
pub fn delete_all_by_user(user_uuid: &str, conn: &DbConn) -> EmptyResult { pub async fn delete_all_by_user(user_uuid: &str, conn: &DbConn) -> EmptyResult {
for send in Self::find_by_user(user_uuid, conn) { for send in Self::find_by_user(user_uuid, conn).await {
send.delete(conn)?; send.delete(conn).await?;
} }
Ok(()) Ok(())
} }
pub fn find_by_access_id(access_id: &str, conn: &DbConn) -> Option<Self> { pub async fn find_by_access_id(access_id: &str, conn: &DbConn) -> Option<Self> {
use data_encoding::BASE64URL_NOPAD; use data_encoding::BASE64URL_NOPAD;
use uuid::Uuid; use uuid::Uuid;
@ -267,10 +265,10 @@ impl Send {
Err(_) => return None, Err(_) => return None,
}; };
Self::find_by_uuid(&uuid, conn) Self::find_by_uuid(&uuid, conn).await
} }
pub fn find_by_uuid(uuid: &str, conn: &DbConn) -> Option<Self> { pub async fn find_by_uuid(uuid: &str, conn: &DbConn) -> Option<Self> {
db_run! {conn: { db_run! {conn: {
sends::table sends::table
.filter(sends::uuid.eq(uuid)) .filter(sends::uuid.eq(uuid))
@ -280,7 +278,7 @@ impl Send {
}} }}
} }
pub fn find_by_user(user_uuid: &str, conn: &DbConn) -> Vec<Self> { pub async fn find_by_user(user_uuid: &str, conn: &DbConn) -> Vec<Self> {
db_run! {conn: { db_run! {conn: {
sends::table sends::table
.filter(sends::user_uuid.eq(user_uuid)) .filter(sends::user_uuid.eq(user_uuid))
@ -288,7 +286,7 @@ impl Send {
}} }}
} }
pub fn find_by_org(org_uuid: &str, conn: &DbConn) -> Vec<Self> { pub async fn find_by_org(org_uuid: &str, conn: &DbConn) -> Vec<Self> {
db_run! {conn: { db_run! {conn: {
sends::table sends::table
.filter(sends::organization_uuid.eq(org_uuid)) .filter(sends::organization_uuid.eq(org_uuid))
@ -296,7 +294,7 @@ impl Send {
}} }}
} }
pub fn find_by_past_deletion_date(conn: &DbConn) -> Vec<Self> { pub async fn find_by_past_deletion_date(conn: &DbConn) -> Vec<Self> {
let now = Utc::now().naive_utc(); let now = Utc::now().naive_utc();
db_run! {conn: { db_run! {conn: {
sends::table sends::table

26
src/db/models/two_factor.rs

@ -2,12 +2,9 @@ use serde_json::Value;
use crate::{api::EmptyResult, db::DbConn, error::MapResult}; use crate::{api::EmptyResult, db::DbConn, error::MapResult};
use super::User;
db_object! { db_object! {
#[derive(Identifiable, Queryable, Insertable, Associations, AsChangeset)] #[derive(Identifiable, Queryable, Insertable, AsChangeset)]
#[table_name = "twofactor"] #[table_name = "twofactor"]
#[belongs_to(User, foreign_key = "user_uuid")]
#[primary_key(uuid)] #[primary_key(uuid)]
pub struct TwoFactor { pub struct TwoFactor {
pub uuid: String, pub uuid: String,
@ -71,7 +68,7 @@ impl TwoFactor {
/// Database methods /// Database methods
impl TwoFactor { impl TwoFactor {
pub fn save(&self, conn: &DbConn) -> EmptyResult { pub async fn save(&self, conn: &DbConn) -> EmptyResult {
db_run! { conn: db_run! { conn:
sqlite, mysql { sqlite, mysql {
match diesel::replace_into(twofactor::table) match diesel::replace_into(twofactor::table)
@ -110,7 +107,7 @@ impl TwoFactor {
} }
} }
pub fn delete(self, conn: &DbConn) -> EmptyResult { pub async fn delete(self, conn: &DbConn) -> EmptyResult {
db_run! { conn: { db_run! { conn: {
diesel::delete(twofactor::table.filter(twofactor::uuid.eq(self.uuid))) diesel::delete(twofactor::table.filter(twofactor::uuid.eq(self.uuid)))
.execute(conn) .execute(conn)
@ -118,7 +115,7 @@ impl TwoFactor {
}} }}
} }
pub fn find_by_user(user_uuid: &str, conn: &DbConn) -> Vec<Self> { pub async fn find_by_user(user_uuid: &str, conn: &DbConn) -> Vec<Self> {
db_run! { conn: { db_run! { conn: {
twofactor::table twofactor::table
.filter(twofactor::user_uuid.eq(user_uuid)) .filter(twofactor::user_uuid.eq(user_uuid))
@ -129,7 +126,7 @@ impl TwoFactor {
}} }}
} }
pub fn find_by_user_and_type(user_uuid: &str, atype: i32, conn: &DbConn) -> Option<Self> { pub async fn find_by_user_and_type(user_uuid: &str, atype: i32, conn: &DbConn) -> Option<Self> {
db_run! { conn: { db_run! { conn: {
twofactor::table twofactor::table
.filter(twofactor::user_uuid.eq(user_uuid)) .filter(twofactor::user_uuid.eq(user_uuid))
@ -140,7 +137,7 @@ impl TwoFactor {
}} }}
} }
pub fn delete_all_by_user(user_uuid: &str, conn: &DbConn) -> EmptyResult { pub async fn delete_all_by_user(user_uuid: &str, conn: &DbConn) -> EmptyResult {
db_run! { conn: { db_run! { conn: {
diesel::delete(twofactor::table.filter(twofactor::user_uuid.eq(user_uuid))) diesel::delete(twofactor::table.filter(twofactor::user_uuid.eq(user_uuid)))
.execute(conn) .execute(conn)
@ -148,7 +145,7 @@ impl TwoFactor {
}} }}
} }
pub fn migrate_u2f_to_webauthn(conn: &DbConn) -> EmptyResult { pub async fn migrate_u2f_to_webauthn(conn: &DbConn) -> EmptyResult {
let u2f_factors = db_run! { conn: { let u2f_factors = db_run! { conn: {
twofactor::table twofactor::table
.filter(twofactor::atype.eq(TwoFactorType::U2f as i32)) .filter(twofactor::atype.eq(TwoFactorType::U2f as i32))
@ -157,7 +154,7 @@ impl TwoFactor {
.from_db() .from_db()
}}; }};
use crate::api::core::two_factor::u2f::U2FRegistration; use crate::api::core::two_factor::webauthn::U2FRegistration;
use crate::api::core::two_factor::webauthn::{get_webauthn_registrations, WebauthnRegistration}; use crate::api::core::two_factor::webauthn::{get_webauthn_registrations, WebauthnRegistration};
use webauthn_rs::proto::*; use webauthn_rs::proto::*;
@ -168,7 +165,7 @@ impl TwoFactor {
continue; continue;
} }
let (_, mut webauthn_regs) = get_webauthn_registrations(&u2f.user_uuid, conn)?; let (_, mut webauthn_regs) = get_webauthn_registrations(&u2f.user_uuid, conn).await?;
// If the user already has webauthn registrations saved, don't overwrite them // If the user already has webauthn registrations saved, don't overwrite them
if !webauthn_regs.is_empty() { if !webauthn_regs.is_empty() {
@ -207,10 +204,11 @@ impl TwoFactor {
} }
u2f.data = serde_json::to_string(&regs)?; u2f.data = serde_json::to_string(&regs)?;
u2f.save(conn)?; u2f.save(conn).await?;
TwoFactor::new(u2f.user_uuid.clone(), TwoFactorType::Webauthn, serde_json::to_string(&webauthn_regs)?) TwoFactor::new(u2f.user_uuid.clone(), TwoFactorType::Webauthn, serde_json::to_string(&webauthn_regs)?)
.save(conn)?; .save(conn)
.await?;
} }
Ok(()) Ok(())

25
src/db/models/two_factor_incomplete.rs

@ -2,12 +2,9 @@ use chrono::{NaiveDateTime, Utc};
use crate::{api::EmptyResult, auth::ClientIp, db::DbConn, error::MapResult, CONFIG}; use crate::{api::EmptyResult, auth::ClientIp, db::DbConn, error::MapResult, CONFIG};
use super::User;
db_object! { db_object! {
#[derive(Identifiable, Queryable, Insertable, Associations, AsChangeset)] #[derive(Identifiable, Queryable, Insertable, AsChangeset)]
#[table_name = "twofactor_incomplete"] #[table_name = "twofactor_incomplete"]
#[belongs_to(User, foreign_key = "user_uuid")]
#[primary_key(user_uuid, device_uuid)] #[primary_key(user_uuid, device_uuid)]
pub struct TwoFactorIncomplete { pub struct TwoFactorIncomplete {
pub user_uuid: String, pub user_uuid: String,
@ -22,7 +19,7 @@ db_object! {
} }
impl TwoFactorIncomplete { impl TwoFactorIncomplete {
pub fn mark_incomplete( pub async fn mark_incomplete(
user_uuid: &str, user_uuid: &str,
device_uuid: &str, device_uuid: &str,
device_name: &str, device_name: &str,
@ -36,7 +33,7 @@ impl TwoFactorIncomplete {
// Don't update the data for an existing user/device pair, since that // Don't update the data for an existing user/device pair, since that
// would allow an attacker to arbitrarily delay notifications by // would allow an attacker to arbitrarily delay notifications by
// sending repeated 2FA attempts to reset the timer. // sending repeated 2FA attempts to reset the timer.
let existing = Self::find_by_user_and_device(user_uuid, device_uuid, conn); let existing = Self::find_by_user_and_device(user_uuid, device_uuid, conn).await;
if existing.is_some() { if existing.is_some() {
return Ok(()); return Ok(());
} }
@ -55,15 +52,15 @@ impl TwoFactorIncomplete {
}} }}
} }
pub fn mark_complete(user_uuid: &str, device_uuid: &str, conn: &DbConn) -> EmptyResult { pub async fn mark_complete(user_uuid: &str, device_uuid: &str, conn: &DbConn) -> EmptyResult {
if CONFIG.incomplete_2fa_time_limit() <= 0 || !CONFIG.mail_enabled() { if CONFIG.incomplete_2fa_time_limit() <= 0 || !CONFIG.mail_enabled() {
return Ok(()); return Ok(());
} }
Self::delete_by_user_and_device(user_uuid, device_uuid, conn) Self::delete_by_user_and_device(user_uuid, device_uuid, conn).await
} }
pub fn find_by_user_and_device(user_uuid: &str, device_uuid: &str, conn: &DbConn) -> Option<Self> { pub async fn find_by_user_and_device(user_uuid: &str, device_uuid: &str, conn: &DbConn) -> Option<Self> {
db_run! { conn: { db_run! { conn: {
twofactor_incomplete::table twofactor_incomplete::table
.filter(twofactor_incomplete::user_uuid.eq(user_uuid)) .filter(twofactor_incomplete::user_uuid.eq(user_uuid))
@ -74,7 +71,7 @@ impl TwoFactorIncomplete {
}} }}
} }
pub fn find_logins_before(dt: &NaiveDateTime, conn: &DbConn) -> Vec<Self> { pub async fn find_logins_before(dt: &NaiveDateTime, conn: &DbConn) -> Vec<Self> {
db_run! {conn: { db_run! {conn: {
twofactor_incomplete::table twofactor_incomplete::table
.filter(twofactor_incomplete::login_time.lt(dt)) .filter(twofactor_incomplete::login_time.lt(dt))
@ -84,11 +81,11 @@ impl TwoFactorIncomplete {
}} }}
} }
pub fn delete(self, conn: &DbConn) -> EmptyResult { pub async fn delete(self, conn: &DbConn) -> EmptyResult {
Self::delete_by_user_and_device(&self.user_uuid, &self.device_uuid, conn) Self::delete_by_user_and_device(&self.user_uuid, &self.device_uuid, conn).await
} }
pub fn delete_by_user_and_device(user_uuid: &str, device_uuid: &str, conn: &DbConn) -> EmptyResult { pub async fn delete_by_user_and_device(user_uuid: &str, device_uuid: &str, conn: &DbConn) -> EmptyResult {
db_run! { conn: { db_run! { conn: {
diesel::delete(twofactor_incomplete::table diesel::delete(twofactor_incomplete::table
.filter(twofactor_incomplete::user_uuid.eq(user_uuid)) .filter(twofactor_incomplete::user_uuid.eq(user_uuid))
@ -98,7 +95,7 @@ impl TwoFactorIncomplete {
}} }}
} }
pub fn delete_all_by_user(user_uuid: &str, conn: &DbConn) -> EmptyResult { pub async fn delete_all_by_user(user_uuid: &str, conn: &DbConn) -> EmptyResult {
db_run! { conn: { db_run! { conn: {
diesel::delete(twofactor_incomplete::table.filter(twofactor_incomplete::user_uuid.eq(user_uuid))) diesel::delete(twofactor_incomplete::table.filter(twofactor_incomplete::user_uuid.eq(user_uuid)))
.execute(conn) .execute(conn)

80
src/db/models/user.rs

@ -171,7 +171,7 @@ impl User {
pub fn set_stamp_exception(&mut self, route_exception: Vec<String>) { pub fn set_stamp_exception(&mut self, route_exception: Vec<String>) {
let stamp_exception = UserStampException { let stamp_exception = UserStampException {
routes: route_exception, routes: route_exception,
security_stamp: self.security_stamp.to_string(), security_stamp: self.security_stamp.clone(),
expire: (Utc::now().naive_utc() + Duration::minutes(2)).timestamp(), expire: (Utc::now().naive_utc() + Duration::minutes(2)).timestamp(),
}; };
self.stamp_exception = Some(serde_json::to_string(&stamp_exception).unwrap_or_default()); self.stamp_exception = Some(serde_json::to_string(&stamp_exception).unwrap_or_default());
@ -192,12 +192,20 @@ use crate::db::DbConn;
use crate::api::EmptyResult; use crate::api::EmptyResult;
use crate::error::MapResult; use crate::error::MapResult;
use futures::{stream, stream::StreamExt};
/// Database methods /// Database methods
impl User { impl User {
pub fn to_json(&self, conn: &DbConn) -> Value { pub async fn to_json(&self, conn: &DbConn) -> Value {
let orgs = UserOrganization::find_confirmed_by_user(&self.uuid, conn); let orgs_json = stream::iter(UserOrganization::find_confirmed_by_user(&self.uuid, conn).await)
let orgs_json: Vec<Value> = orgs.iter().map(|c| c.to_json(conn)).collect(); .then(|c| async {
let twofactor_enabled = !TwoFactor::find_by_user(&self.uuid, conn).is_empty(); let c = c; // Move out this single variable
c.to_json(conn).await
})
.collect::<Vec<Value>>()
.await;
let twofactor_enabled = !TwoFactor::find_by_user(&self.uuid, conn).await.is_empty();
// TODO: Might want to save the status field in the DB // TODO: Might want to save the status field in the DB
let status = if self.password_hash.is_empty() { let status = if self.password_hash.is_empty() {
@ -227,7 +235,7 @@ impl User {
}) })
} }
pub fn save(&mut self, conn: &DbConn) -> EmptyResult { pub async fn save(&mut self, conn: &DbConn) -> EmptyResult {
if self.email.trim().is_empty() { if self.email.trim().is_empty() {
err!("User email can't be empty") err!("User email can't be empty")
} }
@ -265,26 +273,26 @@ impl User {
} }
} }
pub fn delete(self, conn: &DbConn) -> EmptyResult { pub async fn delete(self, conn: &DbConn) -> EmptyResult {
for user_org in UserOrganization::find_confirmed_by_user(&self.uuid, conn) { for user_org in UserOrganization::find_confirmed_by_user(&self.uuid, conn).await {
if user_org.atype == UserOrgType::Owner { if user_org.atype == UserOrgType::Owner {
let owner_type = UserOrgType::Owner as i32; let owner_type = UserOrgType::Owner as i32;
if UserOrganization::find_by_org_and_type(&user_org.org_uuid, owner_type, conn).len() <= 1 { if UserOrganization::find_by_org_and_type(&user_org.org_uuid, owner_type, conn).await.len() <= 1 {
err!("Can't delete last owner") err!("Can't delete last owner")
} }
} }
} }
Send::delete_all_by_user(&self.uuid, conn)?; Send::delete_all_by_user(&self.uuid, conn).await?;
EmergencyAccess::delete_all_by_user(&self.uuid, conn)?; EmergencyAccess::delete_all_by_user(&self.uuid, conn).await?;
UserOrganization::delete_all_by_user(&self.uuid, conn)?; UserOrganization::delete_all_by_user(&self.uuid, conn).await?;
Cipher::delete_all_by_user(&self.uuid, conn)?; Cipher::delete_all_by_user(&self.uuid, conn).await?;
Favorite::delete_all_by_user(&self.uuid, conn)?; Favorite::delete_all_by_user(&self.uuid, conn).await?;
Folder::delete_all_by_user(&self.uuid, conn)?; Folder::delete_all_by_user(&self.uuid, conn).await?;
Device::delete_all_by_user(&self.uuid, conn)?; Device::delete_all_by_user(&self.uuid, conn).await?;
TwoFactor::delete_all_by_user(&self.uuid, conn)?; TwoFactor::delete_all_by_user(&self.uuid, conn).await?;
TwoFactorIncomplete::delete_all_by_user(&self.uuid, conn)?; TwoFactorIncomplete::delete_all_by_user(&self.uuid, conn).await?;
Invitation::take(&self.email, conn); // Delete invitation if any Invitation::take(&self.email, conn).await; // Delete invitation if any
db_run! {conn: { db_run! {conn: {
diesel::delete(users::table.filter(users::uuid.eq(self.uuid))) diesel::delete(users::table.filter(users::uuid.eq(self.uuid)))
@ -293,13 +301,13 @@ impl User {
}} }}
} }
pub fn update_uuid_revision(uuid: &str, conn: &DbConn) { pub async fn update_uuid_revision(uuid: &str, conn: &DbConn) {
if let Err(e) = Self::_update_revision(uuid, &Utc::now().naive_utc(), conn) { if let Err(e) = Self::_update_revision(uuid, &Utc::now().naive_utc(), conn).await {
warn!("Failed to update revision for {}: {:#?}", uuid, e); warn!("Failed to update revision for {}: {:#?}", uuid, e);
} }
} }
pub fn update_all_revisions(conn: &DbConn) -> EmptyResult { pub async fn update_all_revisions(conn: &DbConn) -> EmptyResult {
let updated_at = Utc::now().naive_utc(); let updated_at = Utc::now().naive_utc();
db_run! {conn: { db_run! {conn: {
@ -312,13 +320,13 @@ impl User {
}} }}
} }
pub fn update_revision(&mut self, conn: &DbConn) -> EmptyResult { pub async fn update_revision(&mut self, conn: &DbConn) -> EmptyResult {
self.updated_at = Utc::now().naive_utc(); self.updated_at = Utc::now().naive_utc();
Self::_update_revision(&self.uuid, &self.updated_at, conn) Self::_update_revision(&self.uuid, &self.updated_at, conn).await
} }
fn _update_revision(uuid: &str, date: &NaiveDateTime, conn: &DbConn) -> EmptyResult { async fn _update_revision(uuid: &str, date: &NaiveDateTime, conn: &DbConn) -> EmptyResult {
db_run! {conn: { db_run! {conn: {
crate::util::retry(|| { crate::util::retry(|| {
diesel::update(users::table.filter(users::uuid.eq(uuid))) diesel::update(users::table.filter(users::uuid.eq(uuid)))
@ -329,7 +337,7 @@ impl User {
}} }}
} }
pub fn find_by_mail(mail: &str, conn: &DbConn) -> Option<Self> { pub async fn find_by_mail(mail: &str, conn: &DbConn) -> Option<Self> {
let lower_mail = mail.to_lowercase(); let lower_mail = mail.to_lowercase();
db_run! {conn: { db_run! {conn: {
users::table users::table
@ -340,20 +348,20 @@ impl User {
}} }}
} }
pub fn find_by_uuid(uuid: &str, conn: &DbConn) -> Option<Self> { pub async fn find_by_uuid(uuid: &str, conn: &DbConn) -> Option<Self> {
db_run! {conn: { db_run! {conn: {
users::table.filter(users::uuid.eq(uuid)).first::<UserDb>(conn).ok().from_db() users::table.filter(users::uuid.eq(uuid)).first::<UserDb>(conn).ok().from_db()
}} }}
} }
pub fn get_all(conn: &DbConn) -> Vec<Self> { pub async fn get_all(conn: &DbConn) -> Vec<Self> {
db_run! {conn: { db_run! {conn: {
users::table.load::<UserDb>(conn).expect("Error loading users").from_db() users::table.load::<UserDb>(conn).expect("Error loading users").from_db()
}} }}
} }
pub fn last_active(&self, conn: &DbConn) -> Option<NaiveDateTime> { pub async fn last_active(&self, conn: &DbConn) -> Option<NaiveDateTime> {
match Device::find_latest_active_by_user(&self.uuid, conn) { match Device::find_latest_active_by_user(&self.uuid, conn).await {
Some(device) => Some(device.updated_at), Some(device) => Some(device.updated_at),
None => None, None => None,
} }
@ -368,7 +376,7 @@ impl Invitation {
} }
} }
pub fn save(&self, conn: &DbConn) -> EmptyResult { pub async fn save(&self, conn: &DbConn) -> EmptyResult {
if self.email.trim().is_empty() { if self.email.trim().is_empty() {
err!("Invitation email can't be empty") err!("Invitation email can't be empty")
} }
@ -393,7 +401,7 @@ impl Invitation {
} }
} }
pub fn delete(self, conn: &DbConn) -> EmptyResult { pub async fn delete(self, conn: &DbConn) -> EmptyResult {
db_run! {conn: { db_run! {conn: {
diesel::delete(invitations::table.filter(invitations::email.eq(self.email))) diesel::delete(invitations::table.filter(invitations::email.eq(self.email)))
.execute(conn) .execute(conn)
@ -401,7 +409,7 @@ impl Invitation {
}} }}
} }
pub fn find_by_mail(mail: &str, conn: &DbConn) -> Option<Self> { pub async fn find_by_mail(mail: &str, conn: &DbConn) -> Option<Self> {
let lower_mail = mail.to_lowercase(); let lower_mail = mail.to_lowercase();
db_run! {conn: { db_run! {conn: {
invitations::table invitations::table
@ -412,9 +420,9 @@ impl Invitation {
}} }}
} }
pub fn take(mail: &str, conn: &DbConn) -> bool { pub async fn take(mail: &str, conn: &DbConn) -> bool {
match Self::find_by_mail(mail, conn) { match Self::find_by_mail(mail, conn).await {
Some(invitation) => invitation.delete(conn).is_ok(), Some(invitation) => invitation.delete(conn).await.is_ok(),
None => false, None => false,
} }
} }

2
src/db/schemas/mysql/schema.rs

@ -42,7 +42,7 @@ table! {
} }
table! { table! {
devices (uuid) { devices (uuid, user_uuid) {
uuid -> Text, uuid -> Text,
created_at -> Datetime, created_at -> Datetime,
updated_at -> Datetime, updated_at -> Datetime,

2
src/db/schemas/postgresql/schema.rs

@ -42,7 +42,7 @@ table! {
} }
table! { table! {
devices (uuid) { devices (uuid, user_uuid) {
uuid -> Text, uuid -> Text,
created_at -> Timestamp, created_at -> Timestamp,
updated_at -> Timestamp, updated_at -> Timestamp,

2
src/db/schemas/sqlite/schema.rs

@ -42,7 +42,7 @@ table! {
} }
table! { table! {
devices (uuid) { devices (uuid, user_uuid) {
uuid -> Text, uuid -> Text,
created_at -> Timestamp, created_at -> Timestamp,
updated_at -> Timestamp, updated_at -> Timestamp,

34
src/error.rs

@ -24,7 +24,7 @@ macro_rules! make_error {
} }
} }
impl std::fmt::Display for Error { impl std::fmt::Display for Error {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match &self.error {$( match &self.error {$(
ErrorKind::$name(e) => f.write_str(&$usr_msg_fun(e, &self.message)), ErrorKind::$name(e) => f.write_str(&$usr_msg_fun(e, &self.message)),
)+} )+}
@ -45,10 +45,11 @@ use lettre::transport::smtp::Error as SmtpErr;
use openssl::error::ErrorStack as SSLErr; use openssl::error::ErrorStack as SSLErr;
use regex::Error as RegexErr; use regex::Error as RegexErr;
use reqwest::Error as ReqErr; use reqwest::Error as ReqErr;
use rocket::error::Error as RocketErr;
use serde_json::{Error as SerdeErr, Value}; use serde_json::{Error as SerdeErr, Value};
use std::io::Error as IoErr; use std::io::Error as IoErr;
use std::time::SystemTimeError as TimeErr; use std::time::SystemTimeError as TimeErr;
use u2f::u2ferror::U2fError as U2fErr; use tokio_tungstenite::tungstenite::Error as TungstError;
use webauthn_rs::error::WebauthnError as WebauthnErr; use webauthn_rs::error::WebauthnError as WebauthnErr;
use yubico::yubicoerror::YubicoError as YubiErr; use yubico::yubicoerror::YubicoError as YubiErr;
@ -69,7 +70,6 @@ make_error! {
Json(Value): _no_source, _serialize, Json(Value): _no_source, _serialize,
Db(DieselErr): _has_source, _api_error, Db(DieselErr): _has_source, _api_error,
R2d2(R2d2Err): _has_source, _api_error, R2d2(R2d2Err): _has_source, _api_error,
U2f(U2fErr): _has_source, _api_error,
Serde(SerdeErr): _has_source, _api_error, Serde(SerdeErr): _has_source, _api_error,
JWt(JwtErr): _has_source, _api_error, JWt(JwtErr): _has_source, _api_error,
Handlebars(HbErr): _has_source, _api_error, Handlebars(HbErr): _has_source, _api_error,
@ -84,14 +84,16 @@ make_error! {
Address(AddrErr): _has_source, _api_error, Address(AddrErr): _has_source, _api_error,
Smtp(SmtpErr): _has_source, _api_error, Smtp(SmtpErr): _has_source, _api_error,
OpenSSL(SSLErr): _has_source, _api_error, OpenSSL(SSLErr): _has_source, _api_error,
Rocket(RocketErr): _has_source, _api_error,
DieselCon(DieselConErr): _has_source, _api_error, DieselCon(DieselConErr): _has_source, _api_error,
DieselMig(DieselMigErr): _has_source, _api_error, DieselMig(DieselMigErr): _has_source, _api_error,
Webauthn(WebauthnErr): _has_source, _api_error, Webauthn(WebauthnErr): _has_source, _api_error,
WebSocket(TungstError): _has_source, _api_error,
} }
impl std::fmt::Debug for Error { impl std::fmt::Debug for Error {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self.source() { match self.source() {
Some(e) => write!(f, "{}.\n[CAUSE] {:#?}", self.message, e), Some(e) => write!(f, "{}.\n[CAUSE] {:#?}", self.message, e),
None => match self.error { None => match self.error {
@ -193,8 +195,8 @@ use rocket::http::{ContentType, Status};
use rocket::request::Request; use rocket::request::Request;
use rocket::response::{self, Responder, Response}; use rocket::response::{self, Responder, Response};
impl<'r> Responder<'r> for Error { impl<'r> Responder<'r, 'static> for Error {
fn respond_to(self, _: &Request) -> response::Result<'r> { fn respond_to(self, _: &Request<'_>) -> response::Result<'static> {
match self.error { match self.error {
ErrorKind::Empty(_) => {} // Don't print the error in this situation ErrorKind::Empty(_) => {} // Don't print the error in this situation
ErrorKind::Simple(_) => {} // Don't print the error in this situation ErrorKind::Simple(_) => {} // Don't print the error in this situation
@ -202,8 +204,8 @@ impl<'r> Responder<'r> for Error {
}; };
let code = Status::from_code(self.error_code).unwrap_or(Status::BadRequest); let code = Status::from_code(self.error_code).unwrap_or(Status::BadRequest);
let body = self.to_string();
Response::build().status(code).header(ContentType::JSON).sized_body(Cursor::new(format!("{}", self))).ok() Response::build().status(code).header(ContentType::JSON).sized_body(Some(body.len()), Cursor::new(body)).ok()
} }
} }
@ -214,20 +216,20 @@ impl<'r> Responder<'r> for Error {
macro_rules! err { macro_rules! err {
($msg:expr) => {{ ($msg:expr) => {{
error!("{}", $msg); error!("{}", $msg);
return Err(crate::error::Error::new($msg, $msg)); return Err($crate::error::Error::new($msg, $msg));
}}; }};
($usr_msg:expr, $log_value:expr) => {{ ($usr_msg:expr, $log_value:expr) => {{
error!("{}. {}", $usr_msg, $log_value); error!("{}. {}", $usr_msg, $log_value);
return Err(crate::error::Error::new($usr_msg, $log_value)); return Err($crate::error::Error::new($usr_msg, $log_value));
}}; }};
} }
macro_rules! err_silent { macro_rules! err_silent {
($msg:expr) => {{ ($msg:expr) => {{
return Err(crate::error::Error::new($msg, $msg)); return Err($crate::error::Error::new($msg, $msg));
}}; }};
($usr_msg:expr, $log_value:expr) => {{ ($usr_msg:expr, $log_value:expr) => {{
return Err(crate::error::Error::new($usr_msg, $log_value)); return Err($crate::error::Error::new($usr_msg, $log_value));
}}; }};
} }
@ -235,11 +237,11 @@ macro_rules! err_silent {
macro_rules! err_code { macro_rules! err_code {
($msg:expr, $err_code: expr) => {{ ($msg:expr, $err_code: expr) => {{
error!("{}", $msg); error!("{}", $msg);
return Err(crate::error::Error::new($msg, $msg).with_code($err_code)); return Err($crate::error::Error::new($msg, $msg).with_code($err_code));
}}; }};
($usr_msg:expr, $log_value:expr, $err_code: expr) => {{ ($usr_msg:expr, $log_value:expr, $err_code: expr) => {{
error!("{}. {}", $usr_msg, $log_value); error!("{}. {}", $usr_msg, $log_value);
return Err(crate::error::Error::new($usr_msg, $log_value).with_code($err_code)); return Err($crate::error::Error::new($usr_msg, $log_value).with_code($err_code));
}}; }};
} }
@ -247,11 +249,11 @@ macro_rules! err_code {
macro_rules! err_discard { macro_rules! err_discard {
($msg:expr, $data:expr) => {{ ($msg:expr, $data:expr) => {{
std::io::copy(&mut $data.open(), &mut std::io::sink()).ok(); std::io::copy(&mut $data.open(), &mut std::io::sink()).ok();
return Err(crate::error::Error::new($msg, $msg)); return Err($crate::error::Error::new($msg, $msg));
}}; }};
($usr_msg:expr, $log_value:expr, $data:expr) => {{ ($usr_msg:expr, $log_value:expr, $data:expr) => {{
std::io::copy(&mut $data.open(), &mut std::io::sink()).ok(); std::io::copy(&mut $data.open(), &mut std::io::sink()).ok();
return Err(crate::error::Error::new($usr_msg, $log_value)); return Err($crate::error::Error::new($usr_msg, $log_value));
}}; }};
} }

136
src/mail.rs

@ -4,11 +4,11 @@ use chrono::NaiveDateTime;
use percent_encoding::{percent_encode, NON_ALPHANUMERIC}; use percent_encoding::{percent_encode, NON_ALPHANUMERIC};
use lettre::{ use lettre::{
message::{header, Mailbox, Message, MultiPart, SinglePart}, message::{Mailbox, Message, MultiPart},
transport::smtp::authentication::{Credentials, Mechanism as SmtpAuthMechanism}, transport::smtp::authentication::{Credentials, Mechanism as SmtpAuthMechanism},
transport::smtp::client::{Tls, TlsParameters}, transport::smtp::client::{Tls, TlsParameters},
transport::smtp::extension::ClientId, transport::smtp::extension::ClientId,
Address, SmtpTransport, Transport, Address, AsyncSmtpTransport, AsyncTransport, Tokio1Executor,
}; };
use crate::{ use crate::{
@ -21,16 +21,16 @@ use crate::{
CONFIG, CONFIG,
}; };
fn mailer() -> SmtpTransport { fn mailer() -> AsyncSmtpTransport<Tokio1Executor> {
use std::time::Duration; use std::time::Duration;
let host = CONFIG.smtp_host().unwrap(); let host = CONFIG.smtp_host().unwrap();
let smtp_client = SmtpTransport::builder_dangerous(host.as_str()) let smtp_client = AsyncSmtpTransport::<Tokio1Executor>::builder_dangerous(host.as_str())
.port(CONFIG.smtp_port()) .port(CONFIG.smtp_port())
.timeout(Some(Duration::from_secs(CONFIG.smtp_timeout()))); .timeout(Some(Duration::from_secs(CONFIG.smtp_timeout())));
// Determine security // Determine security
let smtp_client = if CONFIG.smtp_ssl() || CONFIG.smtp_explicit_tls() { let smtp_client = if CONFIG.smtp_security() != *"off" {
let mut tls_parameters = TlsParameters::builder(host); let mut tls_parameters = TlsParameters::builder(host);
if CONFIG.smtp_accept_invalid_hostnames() { if CONFIG.smtp_accept_invalid_hostnames() {
tls_parameters = tls_parameters.dangerous_accept_invalid_hostnames(true); tls_parameters = tls_parameters.dangerous_accept_invalid_hostnames(true);
@ -40,7 +40,7 @@ fn mailer() -> SmtpTransport {
} }
let tls_parameters = tls_parameters.build().unwrap(); let tls_parameters = tls_parameters.build().unwrap();
if CONFIG.smtp_explicit_tls() { if CONFIG.smtp_security() == *"force_tls" {
smtp_client.tls(Tls::Wrapper(tls_parameters)) smtp_client.tls(Tls::Wrapper(tls_parameters))
} else { } else {
smtp_client.tls(Tls::Required(tls_parameters)) smtp_client.tls(Tls::Required(tls_parameters))
@ -110,7 +110,7 @@ fn get_template(template_name: &str, data: &serde_json::Value) -> Result<(String
Ok((subject, body)) Ok((subject, body))
} }
pub fn send_password_hint(address: &str, hint: Option<String>) -> EmptyResult { pub async fn send_password_hint(address: &str, hint: Option<String>) -> EmptyResult {
let template_name = if hint.is_some() { let template_name = if hint.is_some() {
"email/pw_hint_some" "email/pw_hint_some"
} else { } else {
@ -119,10 +119,10 @@ pub fn send_password_hint(address: &str, hint: Option<String>) -> EmptyResult {
let (subject, body_html, body_text) = get_text(template_name, json!({ "hint": hint, "url": CONFIG.domain() }))?; let (subject, body_html, body_text) = get_text(template_name, json!({ "hint": hint, "url": CONFIG.domain() }))?;
send_email(address, &subject, body_html, body_text) send_email(address, &subject, body_html, body_text).await
} }
pub fn send_delete_account(address: &str, uuid: &str) -> EmptyResult { pub async fn send_delete_account(address: &str, uuid: &str) -> EmptyResult {
let claims = generate_delete_claims(uuid.to_string()); let claims = generate_delete_claims(uuid.to_string());
let delete_token = encode_jwt(&claims); let delete_token = encode_jwt(&claims);
@ -136,10 +136,10 @@ pub fn send_delete_account(address: &str, uuid: &str) -> EmptyResult {
}), }),
)?; )?;
send_email(address, &subject, body_html, body_text) send_email(address, &subject, body_html, body_text).await
} }
pub fn send_verify_email(address: &str, uuid: &str) -> EmptyResult { pub async fn send_verify_email(address: &str, uuid: &str) -> EmptyResult {
let claims = generate_verify_email_claims(uuid.to_string()); let claims = generate_verify_email_claims(uuid.to_string());
let verify_email_token = encode_jwt(&claims); let verify_email_token = encode_jwt(&claims);
@ -153,10 +153,10 @@ pub fn send_verify_email(address: &str, uuid: &str) -> EmptyResult {
}), }),
)?; )?;
send_email(address, &subject, body_html, body_text) send_email(address, &subject, body_html, body_text).await
} }
pub fn send_welcome(address: &str) -> EmptyResult { pub async fn send_welcome(address: &str) -> EmptyResult {
let (subject, body_html, body_text) = get_text( let (subject, body_html, body_text) = get_text(
"email/welcome", "email/welcome",
json!({ json!({
@ -164,10 +164,10 @@ pub fn send_welcome(address: &str) -> EmptyResult {
}), }),
)?; )?;
send_email(address, &subject, body_html, body_text) send_email(address, &subject, body_html, body_text).await
} }
pub fn send_welcome_must_verify(address: &str, uuid: &str) -> EmptyResult { pub async fn send_welcome_must_verify(address: &str, uuid: &str) -> EmptyResult {
let claims = generate_verify_email_claims(uuid.to_string()); let claims = generate_verify_email_claims(uuid.to_string());
let verify_email_token = encode_jwt(&claims); let verify_email_token = encode_jwt(&claims);
@ -180,10 +180,10 @@ pub fn send_welcome_must_verify(address: &str, uuid: &str) -> EmptyResult {
}), }),
)?; )?;
send_email(address, &subject, body_html, body_text) send_email(address, &subject, body_html, body_text).await
} }
pub fn send_2fa_removed_from_org(address: &str, org_name: &str) -> EmptyResult { pub async fn send_2fa_removed_from_org(address: &str, org_name: &str) -> EmptyResult {
let (subject, body_html, body_text) = get_text( let (subject, body_html, body_text) = get_text(
"email/send_2fa_removed_from_org", "email/send_2fa_removed_from_org",
json!({ json!({
@ -192,10 +192,10 @@ pub fn send_2fa_removed_from_org(address: &str, org_name: &str) -> EmptyResult {
}), }),
)?; )?;
send_email(address, &subject, body_html, body_text) send_email(address, &subject, body_html, body_text).await
} }
pub fn send_single_org_removed_from_org(address: &str, org_name: &str) -> EmptyResult { pub async fn send_single_org_removed_from_org(address: &str, org_name: &str) -> EmptyResult {
let (subject, body_html, body_text) = get_text( let (subject, body_html, body_text) = get_text(
"email/send_single_org_removed_from_org", "email/send_single_org_removed_from_org",
json!({ json!({
@ -204,10 +204,10 @@ pub fn send_single_org_removed_from_org(address: &str, org_name: &str) -> EmptyR
}), }),
)?; )?;
send_email(address, &subject, body_html, body_text) send_email(address, &subject, body_html, body_text).await
} }
pub fn send_invite( pub async fn send_invite(
address: &str, address: &str,
uuid: &str, uuid: &str,
org_id: Option<String>, org_id: Option<String>,
@ -236,10 +236,10 @@ pub fn send_invite(
}), }),
)?; )?;
send_email(address, &subject, body_html, body_text) send_email(address, &subject, body_html, body_text).await
} }
pub fn send_emergency_access_invite( pub async fn send_emergency_access_invite(
address: &str, address: &str,
uuid: &str, uuid: &str,
emer_id: Option<String>, emer_id: Option<String>,
@ -267,10 +267,10 @@ pub fn send_emergency_access_invite(
}), }),
)?; )?;
send_email(address, &subject, body_html, body_text) send_email(address, &subject, body_html, body_text).await
} }
pub fn send_emergency_access_invite_accepted(address: &str, grantee_email: &str) -> EmptyResult { pub async fn send_emergency_access_invite_accepted(address: &str, grantee_email: &str) -> EmptyResult {
let (subject, body_html, body_text) = get_text( let (subject, body_html, body_text) = get_text(
"email/emergency_access_invite_accepted", "email/emergency_access_invite_accepted",
json!({ json!({
@ -279,10 +279,10 @@ pub fn send_emergency_access_invite_accepted(address: &str, grantee_email: &str)
}), }),
)?; )?;
send_email(address, &subject, body_html, body_text) send_email(address, &subject, body_html, body_text).await
} }
pub fn send_emergency_access_invite_confirmed(address: &str, grantor_name: &str) -> EmptyResult { pub async fn send_emergency_access_invite_confirmed(address: &str, grantor_name: &str) -> EmptyResult {
let (subject, body_html, body_text) = get_text( let (subject, body_html, body_text) = get_text(
"email/emergency_access_invite_confirmed", "email/emergency_access_invite_confirmed",
json!({ json!({
@ -291,10 +291,10 @@ pub fn send_emergency_access_invite_confirmed(address: &str, grantor_name: &str)
}), }),
)?; )?;
send_email(address, &subject, body_html, body_text) send_email(address, &subject, body_html, body_text).await
} }
pub fn send_emergency_access_recovery_approved(address: &str, grantor_name: &str) -> EmptyResult { pub async fn send_emergency_access_recovery_approved(address: &str, grantor_name: &str) -> EmptyResult {
let (subject, body_html, body_text) = get_text( let (subject, body_html, body_text) = get_text(
"email/emergency_access_recovery_approved", "email/emergency_access_recovery_approved",
json!({ json!({
@ -303,10 +303,10 @@ pub fn send_emergency_access_recovery_approved(address: &str, grantor_name: &str
}), }),
)?; )?;
send_email(address, &subject, body_html, body_text) send_email(address, &subject, body_html, body_text).await
} }
pub fn send_emergency_access_recovery_initiated( pub async fn send_emergency_access_recovery_initiated(
address: &str, address: &str,
grantee_name: &str, grantee_name: &str,
atype: &str, atype: &str,
@ -322,10 +322,10 @@ pub fn send_emergency_access_recovery_initiated(
}), }),
)?; )?;
send_email(address, &subject, body_html, body_text) send_email(address, &subject, body_html, body_text).await
} }
pub fn send_emergency_access_recovery_reminder( pub async fn send_emergency_access_recovery_reminder(
address: &str, address: &str,
grantee_name: &str, grantee_name: &str,
atype: &str, atype: &str,
@ -341,10 +341,10 @@ pub fn send_emergency_access_recovery_reminder(
}), }),
)?; )?;
send_email(address, &subject, body_html, body_text) send_email(address, &subject, body_html, body_text).await
} }
pub fn send_emergency_access_recovery_rejected(address: &str, grantor_name: &str) -> EmptyResult { pub async fn send_emergency_access_recovery_rejected(address: &str, grantor_name: &str) -> EmptyResult {
let (subject, body_html, body_text) = get_text( let (subject, body_html, body_text) = get_text(
"email/emergency_access_recovery_rejected", "email/emergency_access_recovery_rejected",
json!({ json!({
@ -353,10 +353,10 @@ pub fn send_emergency_access_recovery_rejected(address: &str, grantor_name: &str
}), }),
)?; )?;
send_email(address, &subject, body_html, body_text) send_email(address, &subject, body_html, body_text).await
} }
pub fn send_emergency_access_recovery_timed_out(address: &str, grantee_name: &str, atype: &str) -> EmptyResult { pub async fn send_emergency_access_recovery_timed_out(address: &str, grantee_name: &str, atype: &str) -> EmptyResult {
let (subject, body_html, body_text) = get_text( let (subject, body_html, body_text) = get_text(
"email/emergency_access_recovery_timed_out", "email/emergency_access_recovery_timed_out",
json!({ json!({
@ -366,10 +366,10 @@ pub fn send_emergency_access_recovery_timed_out(address: &str, grantee_name: &st
}), }),
)?; )?;
send_email(address, &subject, body_html, body_text) send_email(address, &subject, body_html, body_text).await
} }
pub fn send_invite_accepted(new_user_email: &str, address: &str, org_name: &str) -> EmptyResult { pub async fn send_invite_accepted(new_user_email: &str, address: &str, org_name: &str) -> EmptyResult {
let (subject, body_html, body_text) = get_text( let (subject, body_html, body_text) = get_text(
"email/invite_accepted", "email/invite_accepted",
json!({ json!({
@ -379,10 +379,10 @@ pub fn send_invite_accepted(new_user_email: &str, address: &str, org_name: &str)
}), }),
)?; )?;
send_email(address, &subject, body_html, body_text) send_email(address, &subject, body_html, body_text).await
} }
pub fn send_invite_confirmed(address: &str, org_name: &str) -> EmptyResult { pub async fn send_invite_confirmed(address: &str, org_name: &str) -> EmptyResult {
let (subject, body_html, body_text) = get_text( let (subject, body_html, body_text) = get_text(
"email/invite_confirmed", "email/invite_confirmed",
json!({ json!({
@ -391,10 +391,10 @@ pub fn send_invite_confirmed(address: &str, org_name: &str) -> EmptyResult {
}), }),
)?; )?;
send_email(address, &subject, body_html, body_text) send_email(address, &subject, body_html, body_text).await
} }
pub fn send_new_device_logged_in(address: &str, ip: &str, dt: &NaiveDateTime, device: &str) -> EmptyResult { pub async fn send_new_device_logged_in(address: &str, ip: &str, dt: &NaiveDateTime, device: &str) -> EmptyResult {
use crate::util::upcase_first; use crate::util::upcase_first;
let device = upcase_first(device); let device = upcase_first(device);
@ -409,10 +409,10 @@ pub fn send_new_device_logged_in(address: &str, ip: &str, dt: &NaiveDateTime, de
}), }),
)?; )?;
send_email(address, &subject, body_html, body_text) send_email(address, &subject, body_html, body_text).await
} }
pub fn send_incomplete_2fa_login(address: &str, ip: &str, dt: &NaiveDateTime, device: &str) -> EmptyResult { pub async fn send_incomplete_2fa_login(address: &str, ip: &str, dt: &NaiveDateTime, device: &str) -> EmptyResult {
use crate::util::upcase_first; use crate::util::upcase_first;
let device = upcase_first(device); let device = upcase_first(device);
@ -428,10 +428,10 @@ pub fn send_incomplete_2fa_login(address: &str, ip: &str, dt: &NaiveDateTime, de
}), }),
)?; )?;
send_email(address, &subject, body_html, body_text) send_email(address, &subject, body_html, body_text).await
} }
pub fn send_token(address: &str, token: &str) -> EmptyResult { pub async fn send_token(address: &str, token: &str) -> EmptyResult {
let (subject, body_html, body_text) = get_text( let (subject, body_html, body_text) = get_text(
"email/twofactor_email", "email/twofactor_email",
json!({ json!({
@ -440,10 +440,10 @@ pub fn send_token(address: &str, token: &str) -> EmptyResult {
}), }),
)?; )?;
send_email(address, &subject, body_html, body_text) send_email(address, &subject, body_html, body_text).await
} }
pub fn send_change_email(address: &str, token: &str) -> EmptyResult { pub async fn send_change_email(address: &str, token: &str) -> EmptyResult {
let (subject, body_html, body_text) = get_text( let (subject, body_html, body_text) = get_text(
"email/change_email", "email/change_email",
json!({ json!({
@ -452,10 +452,10 @@ pub fn send_change_email(address: &str, token: &str) -> EmptyResult {
}), }),
)?; )?;
send_email(address, &subject, body_html, body_text) send_email(address, &subject, body_html, body_text).await
} }
pub fn send_test(address: &str) -> EmptyResult { pub async fn send_test(address: &str) -> EmptyResult {
let (subject, body_html, body_text) = get_text( let (subject, body_html, body_text) = get_text(
"email/smtp_test", "email/smtp_test",
json!({ json!({
@ -463,43 +463,19 @@ pub fn send_test(address: &str) -> EmptyResult {
}), }),
)?; )?;
send_email(address, &subject, body_html, body_text) send_email(address, &subject, body_html, body_text).await
} }
fn send_email(address: &str, subject: &str, body_html: String, body_text: String) -> EmptyResult { async fn send_email(address: &str, subject: &str, body_html: String, body_text: String) -> EmptyResult {
let address_split: Vec<&str> = address.rsplitn(2, '@').collect();
if address_split.len() != 2 {
err!("Invalid email address (no @)");
}
let domain_puny = match idna::domain_to_ascii_strict(address_split[0]) {
Ok(d) => d,
Err(_) => err!("Can't convert email domain to ASCII representation"),
};
let address = format!("{}@{}", address_split[1], domain_puny);
let html = SinglePart::builder()
// We force Base64 encoding because in the past we had issues with different encodings.
.header(header::ContentTransferEncoding::Base64)
.header(header::ContentType::TEXT_HTML)
.body(body_html);
let text = SinglePart::builder()
// We force Base64 encoding because in the past we had issues with different encodings.
.header(header::ContentTransferEncoding::Base64)
.header(header::ContentType::TEXT_PLAIN)
.body(body_text);
let smtp_from = &CONFIG.smtp_from(); let smtp_from = &CONFIG.smtp_from();
let email = Message::builder() let email = Message::builder()
.message_id(Some(format!("<{}@{}>", crate::util::get_uuid(), smtp_from.split('@').collect::<Vec<&str>>()[1]))) .message_id(Some(format!("<{}@{}>", crate::util::get_uuid(), smtp_from.split('@').collect::<Vec<&str>>()[1])))
.to(Mailbox::new(None, Address::from_str(&address)?)) .to(Mailbox::new(None, Address::from_str(address)?))
.from(Mailbox::new(Some(CONFIG.smtp_from_name()), Address::from_str(smtp_from)?)) .from(Mailbox::new(Some(CONFIG.smtp_from_name()), Address::from_str(smtp_from)?))
.subject(subject) .subject(subject)
.multipart(MultiPart::alternative().singlepart(text).singlepart(html))?; .multipart(MultiPart::alternative_plain_html(body_text, body_html))?;
match mailer().send(&email) { match mailer().send(email).await {
Ok(_) => Ok(()), Ok(_) => Ok(()),
// Match some common errors and make them more user friendly // Match some common errors and make them more user friendly
Err(e) => { Err(e) => {

206
src/main.rs

@ -1,4 +1,30 @@
#![forbid(unsafe_code)] #![forbid(unsafe_code, non_ascii_idents)]
#![deny(
rust_2018_idioms,
rust_2021_compatibility,
noop_method_call,
pointer_structural_match,
trivial_casts,
trivial_numeric_casts,
unused_import_braces,
clippy::cast_lossless,
clippy::clone_on_ref_ptr,
clippy::equatable_if_let,
clippy::float_cmp_const,
clippy::inefficient_to_string,
clippy::linkedlist,
clippy::macro_use_imports,
clippy::manual_assert,
clippy::match_wildcard_for_single_variants,
clippy::mem_forget,
clippy::string_add_assign,
clippy::string_to_string,
clippy::unnecessary_join,
clippy::unnecessary_self_imports,
clippy::unused_async,
clippy::verbose_file_reads,
clippy::zero_sized_map_values
)]
#![cfg_attr(feature = "unstable", feature(ip))] #![cfg_attr(feature = "unstable", feature(ip))]
// The recursion_limit is mainly triggered by the json!() macro. // The recursion_limit is mainly triggered by the json!() macro.
// The more key/value pairs there are the more recursion occurs. // The more key/value pairs there are the more recursion occurs.
@ -6,7 +32,13 @@
// If you go above 128 it will cause rust-analyzer to fail, // If you go above 128 it will cause rust-analyzer to fail,
#![recursion_limit = "87"] #![recursion_limit = "87"]
extern crate openssl; // When enabled use MiMalloc as malloc instead of the default malloc
#[cfg(feature = "enable_mimalloc")]
use mimalloc::MiMalloc;
#[cfg(feature = "enable_mimalloc")]
#[cfg_attr(feature = "enable_mimalloc", global_allocator)]
static GLOBAL: MiMalloc = MiMalloc;
#[macro_use] #[macro_use]
extern crate rocket; extern crate rocket;
#[macro_use] #[macro_use]
@ -20,8 +52,19 @@ extern crate diesel;
#[macro_use] #[macro_use]
extern crate diesel_migrations; extern crate diesel_migrations;
use job_scheduler::{Job, JobScheduler}; use std::{
use std::{fs::create_dir_all, panic, path::Path, process::exit, str::FromStr, thread, time::Duration}; fs::{canonicalize, create_dir_all},
panic,
path::Path,
process::exit,
str::FromStr,
thread,
};
use tokio::{
fs::File,
io::{AsyncBufReadExt, BufReader},
};
#[macro_use] #[macro_use]
mod error; mod error;
@ -37,9 +80,11 @@ mod util;
pub use config::CONFIG; pub use config::CONFIG;
pub use error::{Error, MapResult}; pub use error::{Error, MapResult};
use rocket::data::{Limits, ToByteUnit};
pub use util::is_running_in_docker; pub use util::is_running_in_docker;
fn main() { #[rocket::main]
async fn main() -> Result<(), Error> {
parse_args(); parse_args();
launch_info(); launch_info();
@ -49,20 +94,23 @@ fn main() {
let extra_debug = matches!(level, LF::Trace | LF::Debug); let extra_debug = matches!(level, LF::Trace | LF::Debug);
check_data_folder(); check_data_folder().await;
check_rsa_keys().unwrap_or_else(|_| { check_rsa_keys().unwrap_or_else(|_| {
error!("Error creating keys, exiting..."); error!("Error creating keys, exiting...");
exit(1); exit(1);
}); });
check_web_vault(); check_web_vault();
create_icon_cache_folder(); create_dir(&CONFIG.icon_cache_folder(), "icon cache");
create_dir(&CONFIG.tmp_folder(), "tmp folder");
create_dir(&CONFIG.sends_folder(), "sends folder");
create_dir(&CONFIG.attachments_folder(), "attachments folder");
let pool = create_db_pool(); let pool = create_db_pool().await;
schedule_jobs(pool.clone()); schedule_jobs(pool.clone()).await;
crate::db::models::TwoFactor::migrate_u2f_to_webauthn(&pool.get().unwrap()).unwrap(); crate::db::models::TwoFactor::migrate_u2f_to_webauthn(&pool.get().await.unwrap()).await.unwrap();
launch_rocket(pool, extra_debug); // Blocks until program termination. launch_rocket(pool, extra_debug).await // Blocks until program termination.
} }
const HELP: &str = "\ const HELP: &str = "\
@ -126,13 +174,13 @@ fn init_logging(level: log::LevelFilter) -> Result<(), fern::InitError> {
// Hide failed to close stream messages // Hide failed to close stream messages
.level_for("hyper::server", log::LevelFilter::Warn) .level_for("hyper::server", log::LevelFilter::Warn)
// Silence rocket logs // Silence rocket logs
.level_for("_", log::LevelFilter::Off) .level_for("_", log::LevelFilter::Warn)
.level_for("launch", log::LevelFilter::Off) .level_for("rocket::launch", log::LevelFilter::Error)
.level_for("launch_", log::LevelFilter::Off) .level_for("rocket::launch_", log::LevelFilter::Error)
.level_for("rocket::rocket", log::LevelFilter::Off) .level_for("rocket::rocket", log::LevelFilter::Warn)
.level_for("rocket::fairing", log::LevelFilter::Off) .level_for("rocket::server", log::LevelFilter::Warn)
// Never show html5ever and hyper::proto logs, too noisy .level_for("rocket::fairing::fairings", log::LevelFilter::Warn)
.level_for("html5ever", log::LevelFilter::Off) .level_for("rocket::shield::shield", log::LevelFilter::Warn)
.level_for("hyper::proto", log::LevelFilter::Off) .level_for("hyper::proto", log::LevelFilter::Off)
.level_for("hyper::client", log::LevelFilter::Off) .level_for("hyper::client", log::LevelFilter::Off)
// Prevent cookie_store logs // Prevent cookie_store logs
@ -243,11 +291,7 @@ fn create_dir(path: &str, description: &str) {
create_dir_all(path).expect(&err_msg); create_dir_all(path).expect(&err_msg);
} }
fn create_icon_cache_folder() { async fn check_data_folder() {
create_dir(&CONFIG.icon_cache_folder(), "icon cache");
}
fn check_data_folder() {
let data_folder = &CONFIG.data_folder(); let data_folder = &CONFIG.data_folder();
let path = Path::new(data_folder); let path = Path::new(data_folder);
if !path.exists() { if !path.exists() {
@ -259,6 +303,53 @@ fn check_data_folder() {
} }
exit(1); exit(1);
} }
if is_running_in_docker()
&& std::env::var("I_REALLY_WANT_VOLATILE_STORAGE").is_err()
&& !docker_data_folder_is_persistent(data_folder).await
{
error!(
"No persistent volume!\n\
########################################################################################\n\
# It looks like you did not configure a persistent volume! #\n\
# This will result in permanent data loss when the container is removed or updated! #\n\
# If you really want to use volatile storage set `I_REALLY_WANT_VOLATILE_STORAGE=true` #\n\
########################################################################################\n"
);
exit(1);
}
}
/// Detect when using Docker or Podman the DATA_FOLDER is either a bind-mount or a volume created manually.
/// If not created manually, then the data will not be persistent.
/// A none persistent volume in either Docker or Podman is represented by a 64 alphanumerical string.
/// If we detect this string, we will alert about not having a persistent self defined volume.
/// This probably means that someone forgot to add `-v /path/to/vaultwarden_data/:/data`
async fn docker_data_folder_is_persistent(data_folder: &str) -> bool {
if let Ok(mountinfo) = File::open("/proc/self/mountinfo").await {
// Since there can only be one mountpoint to the DATA_FOLDER
// We do a basic check for this mountpoint surrounded by a space.
let data_folder_match = if data_folder.starts_with('/') {
format!(" {data_folder} ")
} else {
format!(" /{data_folder} ")
};
let mut lines = BufReader::new(mountinfo).lines();
while let Some(line) = lines.next_line().await.unwrap_or_default() {
// Only execute a regex check if we find the base match
if line.contains(&data_folder_match) {
let re = regex::Regex::new(r"/volumes/[a-z0-9]{64}/_data /").unwrap();
if re.is_match(&line) {
return false;
}
// If we did found a match for the mountpoint, but not the regex, then still stop searching.
break;
}
}
}
// In all other cases, just assume a true.
// This is just an informative check to try and prevent data loss.
true
} }
fn check_rsa_keys() -> Result<(), crate::error::Error> { fn check_rsa_keys() -> Result<(), crate::error::Error> {
@ -275,7 +366,7 @@ fn check_rsa_keys() -> Result<(), crate::error::Error> {
} }
if !util::file_exists(&pub_path) { if !util::file_exists(&pub_path) {
let rsa_key = openssl::rsa::Rsa::private_key_from_pem(&util::read_file(&priv_path)?)?; let rsa_key = openssl::rsa::Rsa::private_key_from_pem(&std::fs::read(&priv_path)?)?;
let pub_key = rsa_key.public_key_to_pem()?; let pub_key = rsa_key.public_key_to_pem()?;
crate::util::write_file(&pub_path, &pub_key)?; crate::util::write_file(&pub_path, &pub_key)?;
@ -304,8 +395,8 @@ fn check_web_vault() {
} }
} }
fn create_db_pool() -> db::DbPool { async fn create_db_pool() -> db::DbPool {
match util::retry_db(db::DbPool::from_config, CONFIG.db_connection_retries()) { match util::retry_db(db::DbPool::from_config, CONFIG.db_connection_retries()).await {
Ok(p) => p, Ok(p) => p,
Err(e) => { Err(e) => {
error!("Error creating database pool: {:?}", e); error!("Error creating database pool: {:?}", e);
@ -314,51 +405,74 @@ fn create_db_pool() -> db::DbPool {
} }
} }
fn launch_rocket(pool: db::DbPool, extra_debug: bool) { async fn launch_rocket(pool: db::DbPool, extra_debug: bool) -> Result<(), Error> {
let basepath = &CONFIG.domain_path(); let basepath = &CONFIG.domain_path();
let mut config = rocket::Config::from(rocket::Config::figment());
config.temp_dir = canonicalize(CONFIG.tmp_folder()).unwrap().into();
config.cli_colors = false; // Make sure Rocket does not color any values for logging.
config.limits = Limits::new()
.limit("json", 20.megabytes()) // 20MB should be enough for very large imports, something like 5000+ vault entries
.limit("data-form", 525.megabytes()) // This needs to match the maximum allowed file size for Send
.limit("file", 525.megabytes()); // This needs to match the maximum allowed file size for attachments
// If adding more paths here, consider also adding them to // If adding more paths here, consider also adding them to
// crate::utils::LOGGED_ROUTES to make sure they appear in the log // crate::utils::LOGGED_ROUTES to make sure they appear in the log
let result = rocket::ignite() let instance = rocket::custom(config)
.mount(&[basepath, "/"].concat(), api::web_routes()) .mount([basepath, "/"].concat(), api::web_routes())
.mount(&[basepath, "/api"].concat(), api::core_routes()) .mount([basepath, "/api"].concat(), api::core_routes())
.mount(&[basepath, "/admin"].concat(), api::admin_routes()) .mount([basepath, "/admin"].concat(), api::admin_routes())
.mount(&[basepath, "/identity"].concat(), api::identity_routes()) .mount([basepath, "/identity"].concat(), api::identity_routes())
.mount(&[basepath, "/icons"].concat(), api::icons_routes()) .mount([basepath, "/icons"].concat(), api::icons_routes())
.mount(&[basepath, "/notifications"].concat(), api::notifications_routes()) .mount([basepath, "/notifications"].concat(), api::notifications_routes())
.manage(pool) .manage(pool)
.manage(api::start_notification_server()) .manage(api::start_notification_server())
.attach(util::AppHeaders()) .attach(util::AppHeaders())
.attach(util::Cors()) .attach(util::Cors())
.attach(util::BetterLogging(extra_debug)) .attach(util::BetterLogging(extra_debug))
.launch(); .ignite()
.await?;
// Launch and print error if there is one CONFIG.set_rocket_shutdown_handle(instance.shutdown());
// The launch will restore the original logging level ctrlc::set_handler(move || {
error!("Launch error {:#?}", result); info!("Exiting vaultwarden!");
CONFIG.shutdown();
})
.expect("Error setting Ctrl-C handler");
let _ = instance.launch().await?;
info!("Vaultwarden process exited!");
Ok(())
} }
fn schedule_jobs(pool: db::DbPool) { async fn schedule_jobs(pool: db::DbPool) {
if CONFIG.job_poll_interval_ms() == 0 { if CONFIG.job_poll_interval_ms() == 0 {
info!("Job scheduler disabled."); info!("Job scheduler disabled.");
return; return;
} }
let runtime = tokio::runtime::Runtime::new().unwrap();
thread::Builder::new() thread::Builder::new()
.name("job-scheduler".to_string()) .name("job-scheduler".to_string())
.spawn(move || { .spawn(move || {
use job_scheduler_ng::{Job, JobScheduler};
let _runtime_guard = runtime.enter();
let mut sched = JobScheduler::new(); let mut sched = JobScheduler::new();
// Purge sends that are past their deletion date. // Purge sends that are past their deletion date.
if !CONFIG.send_purge_schedule().is_empty() { if !CONFIG.send_purge_schedule().is_empty() {
sched.add(Job::new(CONFIG.send_purge_schedule().parse().unwrap(), || { sched.add(Job::new(CONFIG.send_purge_schedule().parse().unwrap(), || {
api::purge_sends(pool.clone()); runtime.spawn(api::purge_sends(pool.clone()));
})); }));
} }
// Purge trashed items that are old enough to be auto-deleted. // Purge trashed items that are old enough to be auto-deleted.
if !CONFIG.trash_purge_schedule().is_empty() { if !CONFIG.trash_purge_schedule().is_empty() {
sched.add(Job::new(CONFIG.trash_purge_schedule().parse().unwrap(), || { sched.add(Job::new(CONFIG.trash_purge_schedule().parse().unwrap(), || {
api::purge_trashed_ciphers(pool.clone()); runtime.spawn(api::purge_trashed_ciphers(pool.clone()));
})); }));
} }
@ -366,7 +480,7 @@ fn schedule_jobs(pool: db::DbPool) {
// indicates that a user's master password has been compromised. // indicates that a user's master password has been compromised.
if !CONFIG.incomplete_2fa_schedule().is_empty() { if !CONFIG.incomplete_2fa_schedule().is_empty() {
sched.add(Job::new(CONFIG.incomplete_2fa_schedule().parse().unwrap(), || { sched.add(Job::new(CONFIG.incomplete_2fa_schedule().parse().unwrap(), || {
api::send_incomplete_2fa_notifications(pool.clone()); runtime.spawn(api::send_incomplete_2fa_notifications(pool.clone()));
})); }));
} }
@ -375,7 +489,7 @@ fn schedule_jobs(pool: db::DbPool) {
// sending reminders for requests that are about to be granted anyway. // sending reminders for requests that are about to be granted anyway.
if !CONFIG.emergency_request_timeout_schedule().is_empty() { if !CONFIG.emergency_request_timeout_schedule().is_empty() {
sched.add(Job::new(CONFIG.emergency_request_timeout_schedule().parse().unwrap(), || { sched.add(Job::new(CONFIG.emergency_request_timeout_schedule().parse().unwrap(), || {
api::emergency_request_timeout_job(pool.clone()); runtime.spawn(api::emergency_request_timeout_job(pool.clone()));
})); }));
} }
@ -383,7 +497,7 @@ fn schedule_jobs(pool: db::DbPool) {
// emergency access requests. // emergency access requests.
if !CONFIG.emergency_notification_reminder_schedule().is_empty() { if !CONFIG.emergency_notification_reminder_schedule().is_empty() {
sched.add(Job::new(CONFIG.emergency_notification_reminder_schedule().parse().unwrap(), || { sched.add(Job::new(CONFIG.emergency_notification_reminder_schedule().parse().unwrap(), || {
api::emergency_notification_reminder_job(pool.clone()); runtime.spawn(api::emergency_notification_reminder_job(pool.clone()));
})); }));
} }
@ -398,7 +512,9 @@ fn schedule_jobs(pool: db::DbPool) {
// tick, the one that was added earlier will run first. // tick, the one that was added earlier will run first.
loop { loop {
sched.tick(); sched.tick();
thread::sleep(Duration::from_millis(CONFIG.job_poll_interval_ms())); runtime.block_on(async move {
tokio::time::sleep(tokio::time::Duration::from_millis(CONFIG.job_poll_interval_ms())).await
});
} }
}) })
.expect("Error spawning job scheduler thread"); .expect("Error spawning job scheduler thread");

17
src/static/global_domains.json

@ -328,6 +328,7 @@
"Type": 33, "Type": 33,
"Domains": [ "Domains": [
"healthcare.gov", "healthcare.gov",
"cuidadodesalud.gov",
"cms.gov" "cms.gov"
], ],
"Excluded": false "Excluded": false
@ -902,6 +903,7 @@
{ {
"Type": 85, "Type": 85,
"Domains": [ "Domains": [
"proton.me",
"protonmail.com", "protonmail.com",
"protonvpn.com" "protonvpn.com"
], ],
@ -922,5 +924,20 @@
"wise.com" "wise.com"
], ],
"Excluded": false "Excluded": false
},
{
"Type": 88,
"Domains": [
"takeaway.com",
"just-eat.dk",
"just-eat.no",
"just-eat.fr",
"just-eat.ch",
"lieferando.de",
"lieferando.at",
"thuisbezorgd.nl",
"pyszne.pl"
],
"Excluded": false
} }
] ]

6042
src/static/scripts/bootstrap-native.js

File diff suppressed because it is too large

3674
src/static/scripts/bootstrap.css

File diff suppressed because it is too large

276
src/static/scripts/datatables.css

@ -4,22 +4,175 @@
* *
* To rebuild or modify this file with the latest versions of the included * To rebuild or modify this file with the latest versions of the included
* software please visit: * software please visit:
* https://datatables.net/download/#bs5/dt-1.11.3 * https://datatables.net/download/#bs5/dt-1.12.1
* *
* Included libraries: * Included libraries:
* DataTables 1.11.3 * DataTables 1.12.1
*/ */
@charset "UTF-8"; @charset "UTF-8";
td.dt-control { table.dataTable td.dt-control {
background: url("https://www.datatables.net/examples/resources/details_open.png") no-repeat center center; text-align: center;
cursor: pointer;
}
table.dataTable td.dt-control:before {
height: 1em;
width: 1em;
margin-top: -9px;
display: inline-block;
color: white;
border: 0.15em solid white;
border-radius: 1em;
box-shadow: 0 0 0.2em #444;
box-sizing: content-box;
text-align: center;
text-indent: 0 !important;
font-family: "Courier New", Courier, monospace;
line-height: 1em;
content: "+";
background-color: #31b131;
}
table.dataTable tr.dt-hasChild td.dt-control:before {
content: "-";
background-color: #d33333;
}
table.dataTable thead > tr > th.sorting, table.dataTable thead > tr > th.sorting_asc, table.dataTable thead > tr > th.sorting_desc, table.dataTable thead > tr > th.sorting_asc_disabled, table.dataTable thead > tr > th.sorting_desc_disabled,
table.dataTable thead > tr > td.sorting,
table.dataTable thead > tr > td.sorting_asc,
table.dataTable thead > tr > td.sorting_desc,
table.dataTable thead > tr > td.sorting_asc_disabled,
table.dataTable thead > tr > td.sorting_desc_disabled {
cursor: pointer; cursor: pointer;
position: relative;
padding-right: 26px;
}
table.dataTable thead > tr > th.sorting:before, table.dataTable thead > tr > th.sorting:after, table.dataTable thead > tr > th.sorting_asc:before, table.dataTable thead > tr > th.sorting_asc:after, table.dataTable thead > tr > th.sorting_desc:before, table.dataTable thead > tr > th.sorting_desc:after, table.dataTable thead > tr > th.sorting_asc_disabled:before, table.dataTable thead > tr > th.sorting_asc_disabled:after, table.dataTable thead > tr > th.sorting_desc_disabled:before, table.dataTable thead > tr > th.sorting_desc_disabled:after,
table.dataTable thead > tr > td.sorting:before,
table.dataTable thead > tr > td.sorting:after,
table.dataTable thead > tr > td.sorting_asc:before,
table.dataTable thead > tr > td.sorting_asc:after,
table.dataTable thead > tr > td.sorting_desc:before,
table.dataTable thead > tr > td.sorting_desc:after,
table.dataTable thead > tr > td.sorting_asc_disabled:before,
table.dataTable thead > tr > td.sorting_asc_disabled:after,
table.dataTable thead > tr > td.sorting_desc_disabled:before,
table.dataTable thead > tr > td.sorting_desc_disabled:after {
position: absolute;
display: block;
opacity: 0.125;
right: 10px;
line-height: 9px;
font-size: 0.9em;
}
table.dataTable thead > tr > th.sorting:before, table.dataTable thead > tr > th.sorting_asc:before, table.dataTable thead > tr > th.sorting_desc:before, table.dataTable thead > tr > th.sorting_asc_disabled:before, table.dataTable thead > tr > th.sorting_desc_disabled:before,
table.dataTable thead > tr > td.sorting:before,
table.dataTable thead > tr > td.sorting_asc:before,
table.dataTable thead > tr > td.sorting_desc:before,
table.dataTable thead > tr > td.sorting_asc_disabled:before,
table.dataTable thead > tr > td.sorting_desc_disabled:before {
bottom: 50%;
content: "▴";
}
table.dataTable thead > tr > th.sorting:after, table.dataTable thead > tr > th.sorting_asc:after, table.dataTable thead > tr > th.sorting_desc:after, table.dataTable thead > tr > th.sorting_asc_disabled:after, table.dataTable thead > tr > th.sorting_desc_disabled:after,
table.dataTable thead > tr > td.sorting:after,
table.dataTable thead > tr > td.sorting_asc:after,
table.dataTable thead > tr > td.sorting_desc:after,
table.dataTable thead > tr > td.sorting_asc_disabled:after,
table.dataTable thead > tr > td.sorting_desc_disabled:after {
top: 50%;
content: "▾";
}
table.dataTable thead > tr > th.sorting_asc:before, table.dataTable thead > tr > th.sorting_desc:after,
table.dataTable thead > tr > td.sorting_asc:before,
table.dataTable thead > tr > td.sorting_desc:after {
opacity: 0.6;
}
table.dataTable thead > tr > th.sorting_desc_disabled:after, table.dataTable thead > tr > th.sorting_asc_disabled:before,
table.dataTable thead > tr > td.sorting_desc_disabled:after,
table.dataTable thead > tr > td.sorting_asc_disabled:before {
display: none;
}
table.dataTable thead > tr > th:active,
table.dataTable thead > tr > td:active {
outline: none;
}
div.dataTables_scrollBody table.dataTable thead > tr > th:before, div.dataTables_scrollBody table.dataTable thead > tr > th:after,
div.dataTables_scrollBody table.dataTable thead > tr > td:before,
div.dataTables_scrollBody table.dataTable thead > tr > td:after {
display: none;
} }
tr.dt-hasChild td.dt-control { div.dataTables_processing {
background: url("https://www.datatables.net/examples/resources/details_close.png") no-repeat center center; position: absolute;
top: 50%;
left: 50%;
width: 200px;
margin-left: -100px;
margin-top: -26px;
text-align: center;
padding: 2px;
}
div.dataTables_processing > div:last-child {
position: relative;
width: 80px;
height: 15px;
margin: 1em auto;
}
div.dataTables_processing > div:last-child > div {
position: absolute;
top: 0;
width: 13px;
height: 13px;
border-radius: 50%;
background: rgba(13, 110, 253, 0.9);
animation-timing-function: cubic-bezier(0, 1, 1, 0);
}
div.dataTables_processing > div:last-child > div:nth-child(1) {
left: 8px;
animation: datatables-loader-1 0.6s infinite;
}
div.dataTables_processing > div:last-child > div:nth-child(2) {
left: 8px;
animation: datatables-loader-2 0.6s infinite;
}
div.dataTables_processing > div:last-child > div:nth-child(3) {
left: 32px;
animation: datatables-loader-2 0.6s infinite;
}
div.dataTables_processing > div:last-child > div:nth-child(4) {
left: 56px;
animation: datatables-loader-3 0.6s infinite;
} }
@keyframes datatables-loader-1 {
0% {
transform: scale(0);
}
100% {
transform: scale(1);
}
}
@keyframes datatables-loader-3 {
0% {
transform: scale(1);
}
100% {
transform: scale(0);
}
}
@keyframes datatables-loader-2 {
0% {
transform: translate(0, 0);
}
100% {
transform: translate(24px, 0);
}
}
table.dataTable.nowrap th, table.dataTable.nowrap td {
white-space: nowrap;
}
table.dataTable th.dt-left, table.dataTable th.dt-left,
table.dataTable td.dt-left { table.dataTable td.dt-left {
text-align: left; text-align: left;
@ -41,6 +194,12 @@ table.dataTable th.dt-nowrap,
table.dataTable td.dt-nowrap { table.dataTable td.dt-nowrap {
white-space: nowrap; white-space: nowrap;
} }
table.dataTable thead th,
table.dataTable thead td,
table.dataTable tfoot th,
table.dataTable tfoot td {
text-align: left;
}
table.dataTable thead th.dt-head-left, table.dataTable thead th.dt-head-left,
table.dataTable thead td.dt-head-left, table.dataTable thead td.dt-head-left,
table.dataTable tfoot th.dt-head-left, table.dataTable tfoot th.dt-head-left,
@ -118,6 +277,28 @@ table.dataTable.nowrap th,
table.dataTable.nowrap td { table.dataTable.nowrap td {
white-space: nowrap; white-space: nowrap;
} }
table.dataTable.table-striped > tbody > tr:nth-of-type(2n+1) > * {
box-shadow: none;
}
table.dataTable > tbody > tr {
background-color: transparent;
}
table.dataTable > tbody > tr.selected > * {
box-shadow: inset 0 0 0 9999px rgba(13, 110, 253, 0.9);
color: white;
}
table.dataTable.table-striped > tbody > tr.odd > * {
box-shadow: inset 0 0 0 9999px rgba(0, 0, 0, 0.05);
}
table.dataTable.table-striped > tbody > tr.odd.selected > * {
box-shadow: inset 0 0 0 9999px rgba(13, 110, 253, 0.95);
}
table.dataTable.table-hover > tbody > tr:hover > * {
box-shadow: inset 0 0 0 9999px rgba(0, 0, 0, 0.075);
}
table.dataTable.table-hover > tbody > tr.selected:hover > * {
box-shadow: inset 0 0 0 9999px rgba(13, 110, 253, 0.975);
}
div.dataTables_wrapper div.dataTables_length label { div.dataTables_wrapper div.dataTables_length label {
font-weight: normal; font-weight: normal;
@ -154,71 +335,6 @@ div.dataTables_wrapper div.dataTables_paginate ul.pagination {
white-space: nowrap; white-space: nowrap;
justify-content: flex-end; justify-content: flex-end;
} }
div.dataTables_wrapper div.dataTables_processing {
position: absolute;
top: 50%;
left: 50%;
width: 200px;
margin-left: -100px;
margin-top: -26px;
text-align: center;
padding: 1em 0;
}
table.dataTable > thead > tr > th:active,
table.dataTable > thead > tr > td:active {
outline: none;
}
table.dataTable > thead > tr > th:not(.sorting_disabled),
table.dataTable > thead > tr > td:not(.sorting_disabled) {
padding-right: 30px;
}
table.dataTable > thead .sorting,
table.dataTable > thead .sorting_asc,
table.dataTable > thead .sorting_desc,
table.dataTable > thead .sorting_asc_disabled,
table.dataTable > thead .sorting_desc_disabled {
cursor: pointer;
position: relative;
}
table.dataTable > thead .sorting:before, table.dataTable > thead .sorting:after,
table.dataTable > thead .sorting_asc:before,
table.dataTable > thead .sorting_asc:after,
table.dataTable > thead .sorting_desc:before,
table.dataTable > thead .sorting_desc:after,
table.dataTable > thead .sorting_asc_disabled:before,
table.dataTable > thead .sorting_asc_disabled:after,
table.dataTable > thead .sorting_desc_disabled:before,
table.dataTable > thead .sorting_desc_disabled:after {
position: absolute;
bottom: 0.5em;
display: block;
opacity: 0.3;
}
table.dataTable > thead .sorting:before,
table.dataTable > thead .sorting_asc:before,
table.dataTable > thead .sorting_desc:before,
table.dataTable > thead .sorting_asc_disabled:before,
table.dataTable > thead .sorting_desc_disabled:before {
right: 1em;
content: "↑";
}
table.dataTable > thead .sorting:after,
table.dataTable > thead .sorting_asc:after,
table.dataTable > thead .sorting_desc:after,
table.dataTable > thead .sorting_asc_disabled:after,
table.dataTable > thead .sorting_desc_disabled:after {
right: 0.5em;
content: "↓";
}
table.dataTable > thead .sorting_asc:before,
table.dataTable > thead .sorting_desc:after {
opacity: 1;
}
table.dataTable > thead .sorting_asc_disabled:before,
table.dataTable > thead .sorting_desc_disabled:after {
opacity: 0;
}
div.dataTables_scrollHead table.dataTable { div.dataTables_scrollHead table.dataTable {
margin-bottom: 0 !important; margin-bottom: 0 !important;
@ -264,17 +380,6 @@ div.dataTables_wrapper div.dataTables_paginate {
table.dataTable.table-sm > thead > tr > th:not(.sorting_disabled) { table.dataTable.table-sm > thead > tr > th:not(.sorting_disabled) {
padding-right: 20px; padding-right: 20px;
} }
table.dataTable.table-sm .sorting:before,
table.dataTable.table-sm .sorting_asc:before,
table.dataTable.table-sm .sorting_desc:before {
top: 5px;
right: 0.85em;
}
table.dataTable.table-sm .sorting:after,
table.dataTable.table-sm .sorting_asc:after,
table.dataTable.table-sm .sorting_desc:after {
top: 5px;
}
table.table-bordered.dataTable { table.table-bordered.dataTable {
border-right-width: 0; border-right-width: 0;
@ -316,11 +421,4 @@ div.table-responsive > div.dataTables_wrapper > div.row > div[class^=col-]:last-
padding-right: 0; padding-right: 0;
} }
table.dataTable.table-striped > tbody > tr:nth-of-type(2n+1) {
--bs-table-accent-bg: transparent;
}
table.dataTable.table-striped > tbody > tr.odd {
--bs-table-accent-bg: var(--bs-table-striped-bg);
}

460
src/static/scripts/datatables.js

@ -4,24 +4,23 @@
* *
* To rebuild or modify this file with the latest versions of the included * To rebuild or modify this file with the latest versions of the included
* software please visit: * software please visit:
* https://datatables.net/download/#bs5/dt-1.11.3 * https://datatables.net/download/#bs5/dt-1.12.1
* *
* Included libraries: * Included libraries:
* DataTables 1.11.3 * DataTables 1.12.1
*/ */
/*! DataTables 1.11.3 /*! DataTables 1.12.1
* ©2008-2021 SpryMedia Ltd - datatables.net/license * ©2008-2022 SpryMedia Ltd - datatables.net/license
*/ */
/** /**
* @summary DataTables * @summary DataTables
* @description Paginate, search and order HTML tables * @description Paginate, search and order HTML tables
* @version 1.11.3 * @version 1.12.1
* @file jquery.dataTables.js
* @author SpryMedia Ltd * @author SpryMedia Ltd
* @contact www.datatables.net * @contact www.datatables.net
* @copyright Copyright 2008-2021 SpryMedia Ltd. * @copyright SpryMedia Ltd.
* *
* This source file is free software, available under the following license: * This source file is free software, available under the following license:
* MIT license - http://datatables.net/license * MIT license - http://datatables.net/license
@ -71,38 +70,7 @@
(function( $, window, document, undefined ) { (function( $, window, document, undefined ) {
"use strict"; "use strict";
/**
* DataTables is a plug-in for the jQuery Javascript library. It is a highly
* flexible tool, based upon the foundations of progressive enhancement,
* which will add advanced interaction controls to any HTML table. For a
* full list of features please refer to
* [DataTables.net](href="http://datatables.net).
*
* Note that the `DataTable` object is not a global variable but is aliased
* to `jQuery.fn.DataTable` and `jQuery.fn.dataTable` through which it may
* be accessed.
*
* @class
* @param {object} [init={}] Configuration object for DataTables. Options
* are defined by {@link DataTable.defaults}
* @requires jQuery 1.7+
*
* @example
* // Basic initialisation
* $(document).ready( function {
* $('#example').dataTable();
* } );
*
* @example
* // Initialisation with configuration options - in this case, disable
* // pagination and sorting.
* $(document).ready( function {
* $('#example').dataTable( {
* "paginate": false,
* "sort": false
* } );
* } );
*/
var DataTable = function ( selector, options ) var DataTable = function ( selector, options )
{ {
// When creating with `new`, create a new DataTable, returning the API instance // When creating with `new`, create a new DataTable, returning the API instance
@ -113,7 +81,7 @@
// Argument switching // Argument switching
options = selector; options = selector;
} }
/** /**
* Perform a jQuery selector action on the table's TR elements (from the tbody) and * Perform a jQuery selector action on the table's TR elements (from the tbody) and
* return the resulting jQuery object. * return the resulting jQuery object.
@ -869,24 +837,24 @@
*/ */
this.fnVersionCheck = _ext.fnVersionCheck; this.fnVersionCheck = _ext.fnVersionCheck;
var _that = this; var _that = this;
var emptyInit = options === undefined; var emptyInit = options === undefined;
var len = this.length; var len = this.length;
if ( emptyInit ) { if ( emptyInit ) {
options = {}; options = {};
} }
this.oApi = this.internal = _ext.internal; this.oApi = this.internal = _ext.internal;
// Extend with old style plug-in API methods // Extend with old style plug-in API methods
for ( var fn in DataTable.ext.internal ) { for ( var fn in DataTable.ext.internal ) {
if ( fn ) { if ( fn ) {
this[fn] = _fnExternApiFunc(fn); this[fn] = _fnExternApiFunc(fn);
} }
} }
this.each(function() { this.each(function() {
// For each initialisation we want to give it a clean initialisation // For each initialisation we want to give it a clean initialisation
// object that can be bashed around // object that can be bashed around
@ -894,7 +862,7 @@
var oInit = len > 1 ? // optimisation for single table case var oInit = len > 1 ? // optimisation for single table case
_fnExtend( o, options, true ) : _fnExtend( o, options, true ) :
options; options;
/*global oInit,_that,emptyInit*/ /*global oInit,_that,emptyInit*/
var i=0, iLen, j, jLen, k, kLen; var i=0, iLen, j, jLen, k, kLen;
var sId = this.getAttribute( 'id' ); var sId = this.getAttribute( 'id' );
@ -1108,7 +1076,7 @@
success: function ( json ) { success: function ( json ) {
_fnCamelToHungarian( defaults.oLanguage, json ); _fnCamelToHungarian( defaults.oLanguage, json );
_fnLanguageCompat( json ); _fnLanguageCompat( json );
$.extend( true, oLanguage, json ); $.extend( true, oLanguage, json, oSettings.oInit.oLanguage );
_fnCallbackFire( oSettings, null, 'i18n', [oSettings]); _fnCallbackFire( oSettings, null, 'i18n', [oSettings]);
_fnInitialise( oSettings ); _fnInitialise( oSettings );
@ -1337,7 +1305,7 @@
_that = null; _that = null;
return this; return this;
}; };
/* /*
* It is useful to have variables which are scoped locally so only the * It is useful to have variables which are scoped locally so only the
@ -2341,9 +2309,17 @@
th.addClass( oOptions.sClass ); th.addClass( oOptions.sClass );
} }
var origClass = oCol.sClass;
$.extend( oCol, oOptions ); $.extend( oCol, oOptions );
_fnMap( oCol, oOptions, "sWidth", "sWidthOrig" ); _fnMap( oCol, oOptions, "sWidth", "sWidthOrig" );
// Merge class from previously defined classes with this one, rather than just
// overwriting it in the extend above
if (origClass !== oCol.sClass) {
oCol.sClass = origClass + ' ' + oCol.sClass;
}
/* iDataSort to be applied (backwards compatibility), but aDataSort will take /* iDataSort to be applied (backwards compatibility), but aDataSort will take
* priority if defined * priority if defined
*/ */
@ -2616,9 +2592,11 @@
def = aoColDefs[i]; def = aoColDefs[i];
/* Each definition can target multiple columns, as it is an array */ /* Each definition can target multiple columns, as it is an array */
var aTargets = def.targets !== undefined ? var aTargets = def.target !== undefined
def.targets : ? def.target
def.aTargets; : def.targets !== undefined
? def.targets
: def.aTargets;
if ( ! Array.isArray( aTargets ) ) if ( ! Array.isArray( aTargets ) )
{ {
@ -3462,6 +3440,9 @@
*/ */
function _fnDraw( oSettings, ajaxComplete ) function _fnDraw( oSettings, ajaxComplete )
{ {
// Allow for state saving and a custom start position
_fnStart( oSettings );
/* Provide a pre-callback function which can be used to cancel the draw is false is returned */ /* Provide a pre-callback function which can be used to cancel the draw is false is returned */
var aPreDraw = _fnCallbackFire( oSettings, 'aoPreDrawCallback', 'preDraw', [oSettings] ); var aPreDraw = _fnCallbackFire( oSettings, 'aoPreDrawCallback', 'preDraw', [oSettings] );
if ( $.inArray( false, aPreDraw ) !== -1 ) if ( $.inArray( false, aPreDraw ) !== -1 )
@ -3470,34 +3451,18 @@
return; return;
} }
var i, iLen, n;
var anRows = []; var anRows = [];
var iRowCount = 0; var iRowCount = 0;
var asStripeClasses = oSettings.asStripeClasses; var asStripeClasses = oSettings.asStripeClasses;
var iStripes = asStripeClasses.length; var iStripes = asStripeClasses.length;
var iOpenRows = oSettings.aoOpenRows.length;
var oLang = oSettings.oLanguage; var oLang = oSettings.oLanguage;
var iInitDisplayStart = oSettings.iInitDisplayStart;
var bServerSide = _fnDataSource( oSettings ) == 'ssp'; var bServerSide = _fnDataSource( oSettings ) == 'ssp';
var aiDisplay = oSettings.aiDisplay; var aiDisplay = oSettings.aiDisplay;
oSettings.bDrawing = true;
/* Check and see if we have an initial draw position from state saving */
if ( iInitDisplayStart !== undefined && iInitDisplayStart !== -1 )
{
oSettings._iDisplayStart = bServerSide ?
iInitDisplayStart :
iInitDisplayStart >= oSettings.fnRecordsDisplay() ?
0 :
iInitDisplayStart;
oSettings.iInitDisplayStart = -1;
}
var iDisplayStart = oSettings._iDisplayStart; var iDisplayStart = oSettings._iDisplayStart;
var iDisplayEnd = oSettings.fnDisplayEnd(); var iDisplayEnd = oSettings.fnDisplayEnd();
oSettings.bDrawing = true;
/* Server-side processing draw intercept */ /* Server-side processing draw intercept */
if ( oSettings.bDeferLoading ) if ( oSettings.bDeferLoading )
{ {
@ -3899,6 +3864,28 @@
return aReturn; return aReturn;
} }
/**
* Set the start position for draw
* @param {object} oSettings dataTables settings object
*/
function _fnStart( oSettings )
{
var bServerSide = _fnDataSource( oSettings ) == 'ssp';
var iInitDisplayStart = oSettings.iInitDisplayStart;
// Check and see if we have an initial draw position from state saving
if ( iInitDisplayStart !== undefined && iInitDisplayStart !== -1 )
{
oSettings._iDisplayStart = bServerSide ?
iInitDisplayStart :
iInitDisplayStart >= oSettings.fnRecordsDisplay() ?
0 :
iInitDisplayStart;
oSettings.iInitDisplayStart = -1;
}
}
/** /**
* Create an Ajax call based on the table's settings, taking into account that * Create an Ajax call based on the table's settings, taking into account that
* parameters can have multiple forms, and backwards compatibility. * parameters can have multiple forms, and backwards compatibility.
@ -3942,8 +3929,8 @@
var ajax = oSettings.ajax; var ajax = oSettings.ajax;
var instance = oSettings.oInstance; var instance = oSettings.oInstance;
var callback = function ( json ) { var callback = function ( json ) {
var status = oSettings.jqXhr var status = oSettings.jqXHR
? oSettings.jqXhr.status ? oSettings.jqXHR.status
: null; : null;
if ( json === null || (typeof status === 'number' && status == 204 ) ) { if ( json === null || (typeof status === 'number' && status == 204 ) ) {
@ -5111,6 +5098,7 @@
'class': settings.oClasses.sProcessing 'class': settings.oClasses.sProcessing
} ) } )
.html( settings.oLanguage.sProcessing ) .html( settings.oLanguage.sProcessing )
.append('<div><div></div><div></div><div></div><div></div></div>')
.insertBefore( settings.nTable )[0]; .insertBefore( settings.nTable )[0];
} }
@ -5360,6 +5348,7 @@
footerCopy = footer.clone().prependTo( table ); footerCopy = footer.clone().prependTo( table );
footerTrgEls = footer.find('tr'); // the original tfoot is in its own table and must be sized footerTrgEls = footer.find('tr'); // the original tfoot is in its own table and must be sized
footerSrcEls = footerCopy.find('tr'); footerSrcEls = footerCopy.find('tr');
footerCopy.find('[id]').removeAttr('id');
} }
// Clone the current header and footer elements and then place it into the inner table // Clone the current header and footer elements and then place it into the inner table
@ -5367,6 +5356,7 @@
headerTrgEls = header.find('tr'); // original header is in its own table headerTrgEls = header.find('tr'); // original header is in its own table
headerSrcEls = headerCopy.find('tr'); headerSrcEls = headerCopy.find('tr');
headerCopy.find('th, td').removeAttr('tabindex'); headerCopy.find('th, td').removeAttr('tabindex');
headerCopy.find('[id]').removeAttr('id');
/* /*
@ -5440,7 +5430,7 @@
nToSize.style.width = headerWidths[i]; nToSize.style.width = headerWidths[i];
}, headerTrgEls ); }, headerTrgEls );
$(headerSrcEls).height(0); $(headerSrcEls).css('height', 0);
/* Same again with the footer if we have one */ /* Same again with the footer if we have one */
if ( footer ) if ( footer )
@ -5487,7 +5477,7 @@
// Sanity check that the table is of a sensible width. If not then we are going to get // Sanity check that the table is of a sensible width. If not then we are going to get
// misalignment - try to prevent this by not allowing the table to shrink below its min width // misalignment - try to prevent this by not allowing the table to shrink below its min width
if ( table.outerWidth() < sanityWidth ) if ( Math.round(table.outerWidth()) < Math.round(sanityWidth) )
{ {
// The min width depends upon if we have a vertical scrollbar visible or not */ // The min width depends upon if we have a vertical scrollbar visible or not */
correction = ((divBodyEl.scrollHeight > divBodyEl.offsetHeight || correction = ((divBodyEl.scrollHeight > divBodyEl.offsetHeight ||
@ -6493,16 +6483,27 @@
// Store the saved state so it might be accessed at any time // Store the saved state so it might be accessed at any time
settings.oLoadedState = $.extend( true, {}, s ); settings.oLoadedState = $.extend( true, {}, s );
// Page Length
if ( s.length !== undefined ) {
// If already initialised just set the value directly so that the select element is also updated
if (api) {
api.page.len(s.length)
}
else {
settings._iDisplayLength = s.length;
}
}
// Restore key features - todo - for 1.11 this needs to be done by // Restore key features - todo - for 1.11 this needs to be done by
// subscribed events // subscribed events
if ( s.start !== undefined ) { if ( s.start !== undefined ) {
settings._iDisplayStart = s.start;
if(api === null) { if(api === null) {
settings._iDisplayStart = s.start;
settings.iInitDisplayStart = s.start; settings.iInitDisplayStart = s.start;
} }
} else {
if ( s.length !== undefined ) { _fnPageChange(settings, s.start/settings._iDisplayLength);
settings._iDisplayLength = s.length; }
} }
// Order // Order
@ -6844,7 +6845,7 @@
return 'dom'; return 'dom';
} }
/** /**
@ -7236,8 +7237,10 @@
pluck: function ( prop ) pluck: function ( prop )
{ {
let fn = DataTable.util.get(prop);
return this.map( function ( el ) { return this.map( function ( el ) {
return el[ prop ]; return fn(el);
} ); } );
}, },
@ -8331,22 +8334,35 @@
$(document).on('plugin-init.dt', function (e, context) { $(document).on('plugin-init.dt', function (e, context) {
var api = new _Api( context ); var api = new _Api( context );
api.on( 'stateSaveParams', function ( e, settings, data ) {
var indexes = api.rows().iterator( 'row', function ( settings, idx ) {
return settings.aoData[idx]._detailsShow ? idx : undefined;
});
data.childRows = api.rows( indexes ).ids( true ).toArray(); api.on( 'stateSaveParams', function ( e, settings, d ) {
// This could be more compact with the API, but it is a lot faster as a simple
// internal loop
var idFn = settings.rowIdFn;
var data = settings.aoData;
var ids = [];
for (var i=0 ; i<data.length ; i++) {
if (data[i]._detailsShow) {
ids.push( '#' + idFn(data[i]._aData) );
}
}
d.childRows = ids;
}) })
var loaded = api.state.loaded(); var loaded = api.state.loaded();
if ( loaded && loaded.childRows ) { if ( loaded && loaded.childRows ) {
api.rows( loaded.childRows ).every( function () { api
_fnCallbackFire( context, null, 'requestChild', [ this ] ) .rows( $.map(loaded.childRows, function (id){
}) return id.replace(/:/g, '\\:')
}) )
.every( function () {
_fnCallbackFire( context, null, 'requestChild', [ this ] )
});
} }
}) });
var __details_add = function ( ctx, row, data, klass ) var __details_add = function ( ctx, row, data, klass )
{ {
@ -8393,6 +8409,15 @@
}; };
// Make state saving of child row details async to allow them to be batch processed
var __details_state = DataTable.util.throttle(
function (ctx) {
_fnSaveState( ctx[0] )
},
500
);
var __details_remove = function ( api, idx ) var __details_remove = function ( api, idx )
{ {
var ctx = api.context; var ctx = api.context;
@ -8406,7 +8431,7 @@
row._detailsShow = undefined; row._detailsShow = undefined;
row._details = undefined; row._details = undefined;
$( row.nTr ).removeClass( 'dt-hasChild' ); $( row.nTr ).removeClass( 'dt-hasChild' );
_fnSaveState( ctx[0] ); __details_state( ctx );
} }
} }
}; };
@ -8433,7 +8458,7 @@
_fnCallbackFire( ctx[0], null, 'childRow', [ show, api.row( api[0] ) ] ) _fnCallbackFire( ctx[0], null, 'childRow', [ show, api.row( api[0] ) ] )
__details_events( ctx[0] ); __details_events( ctx[0] );
_fnSaveState( ctx[0] ); __details_state( ctx );
} }
} }
}; };
@ -8444,7 +8469,7 @@
var api = new _Api( settings ); var api = new _Api( settings );
var namespace = '.dt.DT_details'; var namespace = '.dt.DT_details';
var drawEvent = 'draw'+namespace; var drawEvent = 'draw'+namespace;
var colvisEvent = 'column-visibility'+namespace; var colvisEvent = 'column-sizing'+namespace;
var destroyEvent = 'destroy'+namespace; var destroyEvent = 'destroy'+namespace;
var data = settings.aoData; var data = settings.aoData;
@ -9496,7 +9521,6 @@
remove = remove || false; remove = remove || false;
return this.iterator( 'table', function ( settings ) { return this.iterator( 'table', function ( settings ) {
var orig = settings.nTableWrapper.parentNode;
var classes = settings.oClasses; var classes = settings.oClasses;
var table = settings.nTable; var table = settings.nTable;
var tbody = settings.nTBody; var tbody = settings.nTBody;
@ -9551,6 +9575,8 @@
jqTbody.children().detach(); jqTbody.children().detach();
jqTbody.append( rows ); jqTbody.append( rows );
var orig = settings.nTableWrapper.parentNode;
// Remove the DataTables generated nodes, events and classes // Remove the DataTables generated nodes, events and classes
var removedMethod = remove ? 'remove' : 'detach'; var removedMethod = remove ? 'remove' : 'detach';
jqTable[ removedMethod ](); jqTable[ removedMethod ]();
@ -9635,7 +9661,7 @@
} }
return resolved.replace( '%d', plural ); // nb: plural might be undefined, return resolved.replace( '%d', plural ); // nb: plural might be undefined,
} ); } );
/** /**
* Version string for plug-ins to check compatibility. Allowed format is * Version string for plug-ins to check compatibility. Allowed format is
* `a.b.c-d` where: a:int, b:int, c:int, d:string(dev|beta|alpha). `d` is used * `a.b.c-d` where: a:int, b:int, c:int, d:string(dev|beta|alpha). `d` is used
@ -9644,8 +9670,8 @@
* @type string * @type string
* @default Version number * @default Version number
*/ */
DataTable.version = "1.11.3"; DataTable.version = "1.12.1";
/** /**
* Private data store, containing all of the settings objects that are * Private data store, containing all of the settings objects that are
* created for the tables on a given page. * created for the tables on a given page.
@ -9659,7 +9685,7 @@
* @private * @private
*/ */
DataTable.settings = []; DataTable.settings = [];
/** /**
* Object models container, for the various models that DataTables has * Object models container, for the various models that DataTables has
* available to it. These models define the objects that are used to hold * available to it. These models define the objects that are used to hold
@ -11849,7 +11875,6 @@
* Text which is displayed when the table is processing a user action * Text which is displayed when the table is processing a user action
* (usually a sort command or similar). * (usually a sort command or similar).
* @type string * @type string
* @default Processing...
* *
* @dtopt Language * @dtopt Language
* @name DataTable.defaults.language.processing * @name DataTable.defaults.language.processing
@ -11863,7 +11888,7 @@
* } ); * } );
* } ); * } );
*/ */
"sProcessing": "Processing...", "sProcessing": "",
/** /**
@ -14017,7 +14042,7 @@
*/ */
"rowId": null "rowId": null
}; };
/** /**
* Extension object for DataTables that is used to provide all extension * Extension object for DataTables that is used to provide all extension
* options. * options.
@ -14069,7 +14094,7 @@
* *
* @type string * @type string
*/ */
build:"bs5/dt-1.11.3", build:"bs5/dt-1.12.1",
/** /**
@ -15115,6 +15140,213 @@
d; d;
}; };
// Common logic for moment, luxon or a date action
function __mld( dt, momentFn, luxonFn, dateFn, arg1 ) {
if (window.moment) {
return dt[momentFn]( arg1 );
}
else if (window.luxon) {
return dt[luxonFn]( arg1 );
}
return dateFn ? dt[dateFn]( arg1 ) : dt;
}
var __mlWarning = false;
function __mldObj (d, format, locale) {
var dt;
if (window.moment) {
dt = window.moment.utc( d, format, locale, true );
if (! dt.isValid()) {
return null;
}
}
else if (window.luxon) {
dt = format
? window.luxon.DateTime.fromFormat( d, format )
: window.luxon.DateTime.fromISO( d );
if (! dt.isValid) {
return null;
}
dt.setLocale(locale);
}
else if (! format) {
// No format given, must be ISO
dt = new Date(d);
}
else {
if (! __mlWarning) {
alert('DataTables warning: Formatted date without Moment.js or Luxon - https://datatables.net/tn/17');
}
__mlWarning = true;
}
return dt;
}
// Wrapper for date, datetime and time which all operate the same way with the exception of
// the output string for auto locale support
function __mlHelper (localeString) {
return function ( from, to, locale, def ) {
// Luxon and Moment support
// Argument shifting
if ( arguments.length === 0 ) {
locale = 'en';
to = null; // means toLocaleString
from = null; // means iso8601
}
else if ( arguments.length === 1 ) {
locale = 'en';
to = from;
from = null;
}
else if ( arguments.length === 2 ) {
locale = to;
to = from;
from = null;
}
var typeName = 'datetime-' + to;
// Add type detection and sorting specific to this date format - we need to be able to identify
// date type columns as such, rather than as numbers in extensions. Hence the need for this.
if (! DataTable.ext.type.order[typeName]) {
// The renderer will give the value to type detect as the type!
DataTable.ext.type.detect.unshift(function (d) {
return d === typeName ? typeName : false;
});
// The renderer gives us Moment, Luxon or Date obects for the sorting, all of which have a
// `valueOf` which gives milliseconds epoch
DataTable.ext.type.order[typeName + '-asc'] = function (a, b) {
var x = a.valueOf();
var y = b.valueOf();
return x === y
? 0
: x < y
? -1
: 1;
}
DataTable.ext.type.order[typeName + '-desc'] = function (a, b) {
var x = a.valueOf();
var y = b.valueOf();
return x === y
? 0
: x > y
? -1
: 1;
}
}
return function ( d, type ) {
// Allow for a default value
if (d === null || d === undefined) {
if (def === '--now') {
// We treat everything as UTC further down, so no changes are
// made, as such need to get the local date / time as if it were
// UTC
var local = new Date();
d = new Date( Date.UTC(
local.getFullYear(), local.getMonth(), local.getDate(),
local.getHours(), local.getMinutes(), local.getSeconds()
) );
}
else {
d = '';
}
}
if (type === 'type') {
// Typing uses the type name for fast matching
return typeName;
}
if (d === '') {
return type !== 'sort'
? ''
: __mldObj('0000-01-01 00:00:00', null, locale);
}
// Shortcut. If `from` and `to` are the same, we are using the renderer to
// format for ordering, not display - its already in the display format.
if ( to !== null && from === to && type !== 'sort' && type !== 'type' && ! (d instanceof Date) ) {
return d;
}
var dt = __mldObj(d, from, locale);
if (dt === null) {
return d;
}
if (type === 'sort') {
return dt;
}
var formatted = to === null
? __mld(dt, 'toDate', 'toJSDate', '')[localeString]()
: __mld(dt, 'format', 'toFormat', 'toISOString', to);
// XSS protection
return type === 'display' ?
__htmlEscapeEntities( formatted ) :
formatted;
};
}
}
// Based on locale, determine standard number formatting
// Fallback for legacy browsers is US English
var __thousands = ',';
var __decimal = '.';
if (Intl) {
try {
var num = new Intl.NumberFormat().formatToParts(100000.1);
for (var i=0 ; i<num.length ; i++) {
if (num[i].type === 'group') {
__thousands = num[i].value;
}
else if (num[i].type === 'decimal') {
__decimal = num[i].value;
}
}
}
catch (e) {
// noop
}
}
// Formatted date time detection - use by declaring the formats you are going to use
DataTable.datetime = function ( format, locale ) {
var typeName = 'datetime-detect-' + format;
if (! locale) {
locale = 'en';
}
if (! DataTable.ext.type.order[typeName]) {
DataTable.ext.type.detect.unshift(function (d) {
var dt = __mldObj(d, format, locale);
return d === '' || dt ? typeName : false;
});
DataTable.ext.type.order[typeName + '-pre'] = function (d) {
return __mldObj(d, format, locale) || 0;
}
}
}
/** /**
* Helpers for `columns.render`. * Helpers for `columns.render`.
* *
@ -15142,13 +15374,29 @@
* @namespace * @namespace
*/ */
DataTable.render = { DataTable.render = {
date: __mlHelper('toLocaleDateString'),
datetime: __mlHelper('toLocaleString'),
time: __mlHelper('toLocaleTimeString'),
number: function ( thousands, decimal, precision, prefix, postfix ) { number: function ( thousands, decimal, precision, prefix, postfix ) {
// Auto locale detection
if (thousands === null || thousands === undefined) {
thousands = __thousands;
}
if (decimal === null || decimal === undefined) {
decimal = __decimal;
}
return { return {
display: function ( d ) { display: function ( d ) {
if ( typeof d !== 'number' && typeof d !== 'string' ) { if ( typeof d !== 'number' && typeof d !== 'string' ) {
return d; return d;
} }
if (d === '' || d === null) {
return d;
}
var negative = d < 0 ? '-' : ''; var negative = d < 0 ? '-' : '';
var flo = parseFloat( d ); var flo = parseFloat( d );
@ -15317,29 +15565,29 @@
// added to prevent errors // added to prevent errors
} ); } );
// jQuery access // jQuery access
$.fn.dataTable = DataTable; $.fn.dataTable = DataTable;
// Provide access to the host jQuery object (circular reference) // Provide access to the host jQuery object (circular reference)
DataTable.$ = $; DataTable.$ = $;
// Legacy aliases // Legacy aliases
$.fn.dataTableSettings = DataTable.settings; $.fn.dataTableSettings = DataTable.settings;
$.fn.dataTableExt = DataTable.ext; $.fn.dataTableExt = DataTable.ext;
// With a capital `D` we return a DataTables API instance rather than a // With a capital `D` we return a DataTables API instance rather than a
// jQuery object // jQuery object
$.fn.DataTable = function ( opts ) { $.fn.DataTable = function ( opts ) {
return $(this).dataTable( opts ).api(); return $(this).dataTable( opts ).api();
}; };
// All properties that are available to $.fn.dataTable should also be // All properties that are available to $.fn.dataTable should also be
// available on $.fn.DataTable // available on $.fn.DataTable
$.each( DataTable, function ( prop, val ) { $.each( DataTable, function ( prop, val ) {
$.fn.DataTable[ prop ] = val; $.fn.DataTable[ prop ] = val;
} ); } );
return DataTable; return DataTable;
})); }));

7
src/static/templates/admin/base.hbs

@ -20,6 +20,13 @@
width: auto; width: auto;
margin: -5px 0 0 0; margin: -5px 0 0 0;
} }
/* Special alert-row class to use Bootstrap v5.2+ variable colors */
.alert-row {
--bs-alert-border: 1px solid var(--bs-alert-border-color);
color: var(--bs-alert-color);
background-color: var(--bs-alert-bg);
border: var(--bs-alert-border);
}
</style> </style>
<script src="/vw_static/identicon.js"></script> <script src="/vw_static/identicon.js"></script>
<script> <script>

4
src/static/templates/admin/login.hbs

@ -13,9 +13,9 @@
<small>Please provide it below:</small> <small>Please provide it below:</small>
<form class="form-inline" method="post"> <form class="form-inline" method="post">
<input type="password" class="form-control w-50 mr-2" name="token" placeholder="Enter admin token"> <input type="password" class="form-control w-50 mr-2" name="token" placeholder="Enter admin token" autofocus="autofocus">
<button type="submit" class="btn btn-primary">Enter</button> <button type="submit" class="btn btn-primary">Enter</button>
</form> </form>
</div> </div>
</div> </div>
</main> </main>

63
src/static/templates/admin/settings.hbs

@ -5,7 +5,7 @@
<div class="small text-white mb-3"> <div class="small text-white mb-3">
<span class="font-weight-bolder">NOTE:</span> The settings here override the environment variables. Once saved, it's recommended to stop setting them to avoid confusion.<br> <span class="font-weight-bolder">NOTE:</span> The settings here override the environment variables. Once saved, it's recommended to stop setting them to avoid confusion.<br>
This does not apply to the read-only section, which can only be set via environment variables.<br> This does not apply to the read-only section, which can only be set via environment variables.<br>
Settings which are overridden are shown with <span class="is-overridden-true">double underscores</span>. Settings which are overridden are shown with <span class="is-overridden-true alert-row px-1">a yellow colored background</span>.
</div> </div>
<form class="form needs-validation" id="config-form" onsubmit="saveConfig(); return false;" novalidate> <form class="form needs-validation" id="config-form" onsubmit="saveConfig(); return false;" novalidate>
@ -16,7 +16,7 @@
<div id="g_{{group}}" class="card-body collapse"> <div id="g_{{group}}" class="card-body collapse">
{{#each elements}} {{#each elements}}
{{#if editable}} {{#if editable}}
<div class="row my-2 align-items-center is-overridden-{{overridden}}" title="[{{name}}] {{doc.description}}"> <div class="row my-2 align-items-center is-overridden-{{overridden}} alert-row" title="[{{name}}] {{doc.description}}">
{{#case type "text" "number" "password"}} {{#case type "text" "number" "password"}}
<label for="input_{{name}}" class="col-sm-3 col-form-label">{{doc.name}}</label> <label for="input_{{name}}" class="col-sm-3 col-form-label">{{doc.name}}</label>
<div class="col-sm-8"> <div class="col-sm-8">
@ -71,16 +71,25 @@
{{#each config}} {{#each config}}
{{#each elements}} {{#each elements}}
{{#unless editable}} {{#unless editable}}
<div class="row my-2 align-items-center" title="[{{name}}] {{doc.description}}"> <div class="row my-2 align-items-center alert-row" title="[{{name}}] {{doc.description}}">
{{#case type "text" "number" "password"}} {{#case type "text" "number" "password"}}
<label for="input_{{name}}" class="col-sm-3 col-form-label">{{doc.name}}</label> <label for="input_{{name}}" class="col-sm-3 col-form-label">{{doc.name}}</label>
<div class="col-sm-8"> <div class="col-sm-8">
<div class="input-group"> <div class="input-group">
<input readonly class="form-control" id="input_{{name}}" type="{{type}}" {{!--
value="{{value}}" {{#if default}} placeholder="Default: {{default}}" {{/if}}> Also set the database_url input as password here.
{{#case type "password"}} If we would set it to password in config.rs it will not be character masked for the support string.
And sometimes this is more useful for providing support than just 3 asterisk.
--}}
{{#if (eq name "database_url")}}
<input readonly class="form-control" id="input_{{name}}" type="password" value="{{value}}" {{#if default}} placeholder="Default: {{default}}" {{/if}}>
<button class="btn btn-outline-secondary" type="button" onclick="toggleVis('input_{{name}}');">Show/hide</button> <button class="btn btn-outline-secondary" type="button" onclick="toggleVis('input_{{name}}');">Show/hide</button>
{{/case}} {{else}}
<input readonly class="form-control" id="input_{{name}}" type="{{type}}" value="{{value}}" {{#if default}} placeholder="Default: {{default}}" {{/if}}>
{{#case type "password"}}
<button class="btn btn-outline-secondary" type="button" onclick="toggleVis('input_{{name}}');">Show/hide</button>
{{/case}}
{{/if}}
</div> </div>
</div> </div>
{{/case}} {{/case}}
@ -134,7 +143,9 @@
} }
.is-overridden-true { .is-overridden-true {
text-decoration: underline double; --bs-alert-color: #664d03;
--bs-alert-bg: #fff3cd;
--bs-alert-border-color: #ffecb5;
} }
</style> </style>
@ -238,19 +249,45 @@
return Array.from(form).some(el => 'origValue' in el.dataset && ( el.dataset.origValue !== el.value)); return Array.from(form).some(el => 'origValue' in el.dataset && ( el.dataset.origValue !== el.value));
} }
// Trigger Form Change Detection // This function will prevent submitting a from when someone presses enter.
function preventFormSubmitOnEnter(form) {
form.onkeypress = function(e) {
let key = e.charCode || e.keyCode || 0;
if (key == 13) {
e.preventDefault();
}
}
}
// Initialize Form Change Detection
const config_form = document.getElementById('config-form'); const config_form = document.getElementById('config-form');
initChangeDetection(config_form); initChangeDetection(config_form);
// Prevent enter to submitting the form and save the config.
// Users need to really click on save, this also to prevent accidental submits.
preventFormSubmitOnEnter(config_form);
// This function will hook into the smtp-test-email input field and will call the smtpTest() function when enter is pressed.
function submitTestEmailOnEnter() {
const smtp_test_email_input = document.getElementById('smtp-test-email');
smtp_test_email_input.onkeypress = function(e) {
let key = e.charCode || e.keyCode || 0;
if (key == 13) {
e.preventDefault();
smtpTest();
}
}
}
submitTestEmailOnEnter();
// Colorize some settings which are high risk // Colorize some settings which are high risk
const risk_items = document.getElementsByClassName('col-form-label'); function colorRiskSettings() {
function colorRiskSettings(risk_el) { const risk_items = document.getElementsByClassName('col-form-label');
Array.from(risk_el).forEach((el) => { Array.from(risk_items).forEach((el) => {
if (el.innerText.toLowerCase().includes('risks') ) { if (el.innerText.toLowerCase().includes('risks') ) {
el.parentElement.className += ' alert-danger' el.parentElement.className += ' alert-danger'
} }
}); });
} }
colorRiskSettings(risk_items); colorRiskSettings();
</script> </script>

172
src/util.rs

@ -5,19 +5,22 @@ use std::io::Cursor;
use rocket::{ use rocket::{
fairing::{Fairing, Info, Kind}, fairing::{Fairing, Info, Kind},
http::{ContentType, Header, HeaderMap, Method, RawStr, Status}, http::{ContentType, Header, HeaderMap, Method, Status},
request::FromParam, request::FromParam,
response::{self, Responder}, response::{self, Responder},
Data, Request, Response, Rocket, Data, Orbit, Request, Response, Rocket,
}; };
use std::thread::sleep; use tokio::{
use std::time::Duration; runtime::Handle,
time::{sleep, Duration},
};
use crate::CONFIG; use crate::CONFIG;
pub struct AppHeaders(); pub struct AppHeaders();
#[rocket::async_trait]
impl Fairing for AppHeaders { impl Fairing for AppHeaders {
fn info(&self) -> Info { fn info(&self) -> Info {
Info { Info {
@ -26,20 +29,57 @@ impl Fairing for AppHeaders {
} }
} }
fn on_response(&self, _req: &Request, res: &mut Response) { async fn on_response<'r>(&self, req: &'r Request<'_>, res: &mut Response<'r>) {
res.set_raw_header("Permissions-Policy", "accelerometer=(), ambient-light-sensor=(), autoplay=(), camera=(), encrypted-media=(), fullscreen=(), geolocation=(), gyroscope=(), magnetometer=(), microphone=(), midi=(), payment=(), picture-in-picture=(), sync-xhr=(self \"https://haveibeenpwned.com\" \"https://2fa.directory\"), usb=(), vr=()"); res.set_raw_header("Permissions-Policy", "accelerometer=(), ambient-light-sensor=(), autoplay=(), battery=(), camera=(), display-capture=(), document-domain=(), encrypted-media=(), execution-while-not-rendered=(), execution-while-out-of-viewport=(), fullscreen=(), geolocation=(), gyroscope=(), keyboard-map=(), magnetometer=(), microphone=(), midi=(), payment=(), picture-in-picture=(), screen-wake-lock=(), sync-xhr=(), usb=(), web-share=(), xr-spatial-tracking=()");
res.set_raw_header("Referrer-Policy", "same-origin"); res.set_raw_header("Referrer-Policy", "same-origin");
res.set_raw_header("X-Frame-Options", "SAMEORIGIN");
res.set_raw_header("X-Content-Type-Options", "nosniff"); res.set_raw_header("X-Content-Type-Options", "nosniff");
res.set_raw_header("X-XSS-Protection", "1; mode=block"); // Obsolete in modern browsers, unsafe (XS-Leak), and largely replaced by CSP
let csp = format!( res.set_raw_header("X-XSS-Protection", "0");
let req_uri_path = req.uri().path();
// Do not send the Content-Security-Policy (CSP) Header and X-Frame-Options for the *-connector.html files.
// This can cause issues when some MFA requests needs to open a popup or page within the clients like WebAuthn, or Duo.
// This is the same behaviour as upstream Bitwarden.
if !req_uri_path.ends_with("connector.html") {
// Check if we are requesting an admin page, if so, allow unsafe-inline for scripts.
// TODO: In the future maybe we need to see if we can generate a sha256 hash or have no scripts inline at all.
let admin_path = format!("{}/admin", CONFIG.domain_path());
let mut script_src = "";
if req_uri_path.starts_with(admin_path.as_str()) {
script_src = " 'unsafe-inline'";
}
// # Frame Ancestors:
// Chrome Web Store: https://chrome.google.com/webstore/detail/bitwarden-free-password-m/nngceckbapebfimnlniiiahkandclblb // Chrome Web Store: https://chrome.google.com/webstore/detail/bitwarden-free-password-m/nngceckbapebfimnlniiiahkandclblb
// Edge Add-ons: https://microsoftedge.microsoft.com/addons/detail/bitwarden-free-password/jbkfoedolllekgbhcbcoahefnbanhhlh?hl=en-US // Edge Add-ons: https://microsoftedge.microsoft.com/addons/detail/bitwarden-free-password/jbkfoedolllekgbhcbcoahefnbanhhlh?hl=en-US
// Firefox Browser Add-ons: https://addons.mozilla.org/en-US/firefox/addon/bitwarden-password-manager/ // Firefox Browser Add-ons: https://addons.mozilla.org/en-US/firefox/addon/bitwarden-password-manager/
"frame-ancestors 'self' chrome-extension://nngceckbapebfimnlniiiahkandclblb chrome-extension://jbkfoedolllekgbhcbcoahefnbanhhlh moz-extension://* {};", // # img/child/frame src:
CONFIG.allowed_iframe_ancestors() // Have I Been Pwned and Gravator to allow those calls to work.
); // # Connect src:
res.set_raw_header("Content-Security-Policy", csp); // Leaked Passwords check: api.pwnedpasswords.com
// 2FA/MFA Site check: 2fa.directory
// # Mail Relay: https://bitwarden.com/blog/add-privacy-and-security-using-email-aliases-with-bitwarden/
// app.simplelogin.io, app.anonaddy.com, relay.firefox.com
let csp = format!(
"default-src 'self'; \
script-src 'self'{script_src}; \
style-src 'self' 'unsafe-inline'; \
img-src 'self' data: https://haveibeenpwned.com/ https://www.gravatar.com {icon_service_csp}; \
child-src 'self' https://*.duosecurity.com https://*.duofederal.com; \
frame-src 'self' https://*.duosecurity.com https://*.duofederal.com; \
connect-src 'self' https://api.pwnedpasswords.com/range/ https://2fa.directory/api/ https://app.simplelogin.io/api/ https://app.anonaddy.com/api/ https://relay.firefox.com/api/; \
object-src 'self' blob:; \
frame-ancestors 'self' chrome-extension://nngceckbapebfimnlniiiahkandclblb chrome-extension://jbkfoedolllekgbhcbcoahefnbanhhlh moz-extension://* {allowed_iframe_ancestors};",
icon_service_csp=CONFIG._icon_service_csp(),
allowed_iframe_ancestors=CONFIG.allowed_iframe_ancestors()
);
res.set_raw_header("Content-Security-Policy", csp);
res.set_raw_header("X-Frame-Options", "SAMEORIGIN");
} else {
// It looks like this header get's set somewhere else also, make sure this is not sent for these files, it will cause MFA issues.
res.remove_header("X-Frame-Options");
}
// Disable cache unless otherwise specified // Disable cache unless otherwise specified
if !res.headers().contains("cache-control") { if !res.headers().contains("cache-control") {
@ -51,7 +91,7 @@ impl Fairing for AppHeaders {
pub struct Cors(); pub struct Cors();
impl Cors { impl Cors {
fn get_header(headers: &HeaderMap, name: &str) -> String { fn get_header(headers: &HeaderMap<'_>, name: &str) -> String {
match headers.get_one(name) { match headers.get_one(name) {
Some(h) => h.to_string(), Some(h) => h.to_string(),
_ => "".to_string(), _ => "".to_string(),
@ -60,7 +100,7 @@ impl Cors {
// Check a request's `Origin` header against the list of allowed origins. // Check a request's `Origin` header against the list of allowed origins.
// If a match exists, return it. Otherwise, return None. // If a match exists, return it. Otherwise, return None.
fn get_allowed_origin(headers: &HeaderMap) -> Option<String> { fn get_allowed_origin(headers: &HeaderMap<'_>) -> Option<String> {
let origin = Cors::get_header(headers, "Origin"); let origin = Cors::get_header(headers, "Origin");
let domain_origin = CONFIG.domain_origin(); let domain_origin = CONFIG.domain_origin();
let safari_extension_origin = "file://"; let safari_extension_origin = "file://";
@ -72,6 +112,7 @@ impl Cors {
} }
} }
#[rocket::async_trait]
impl Fairing for Cors { impl Fairing for Cors {
fn info(&self) -> Info { fn info(&self) -> Info {
Info { Info {
@ -80,7 +121,7 @@ impl Fairing for Cors {
} }
} }
fn on_response(&self, request: &Request, response: &mut Response) { async fn on_response<'r>(&self, request: &'r Request<'_>, response: &mut Response<'r>) {
let req_headers = request.headers(); let req_headers = request.headers();
if let Some(origin) = Cors::get_allowed_origin(req_headers) { if let Some(origin) = Cors::get_allowed_origin(req_headers) {
@ -97,7 +138,7 @@ impl Fairing for Cors {
response.set_header(Header::new("Access-Control-Allow-Credentials", "true")); response.set_header(Header::new("Access-Control-Allow-Credentials", "true"));
response.set_status(Status::Ok); response.set_status(Status::Ok);
response.set_header(ContentType::Plain); response.set_header(ContentType::Plain);
response.set_sized_body(Cursor::new("")); response.set_sized_body(Some(0), Cursor::new(""));
} }
} }
} }
@ -134,32 +175,28 @@ impl<R> Cached<R> {
} }
} }
impl<'r, R: Responder<'r>> Responder<'r> for Cached<R> { impl<'r, R: 'r + Responder<'r, 'static> + Send> Responder<'r, 'static> for Cached<R> {
fn respond_to(self, req: &Request) -> response::Result<'r> { fn respond_to(self, request: &'r Request<'_>) -> response::Result<'static> {
let mut res = self.response.respond_to(request)?;
let cache_control_header = if self.is_immutable { let cache_control_header = if self.is_immutable {
format!("public, immutable, max-age={}", self.ttl) format!("public, immutable, max-age={}", self.ttl)
} else { } else {
format!("public, max-age={}", self.ttl) format!("public, max-age={}", self.ttl)
}; };
res.set_raw_header("Cache-Control", cache_control_header);
let time_now = chrono::Local::now(); let time_now = chrono::Local::now();
let expiry_time = time_now + chrono::Duration::seconds(self.ttl.try_into().unwrap());
match self.response.respond_to(req) { res.set_raw_header("Expires", format_datetime_http(&expiry_time));
Ok(mut res) => { Ok(res)
res.set_raw_header("Cache-Control", cache_control_header);
let expiry_time = time_now + chrono::Duration::seconds(self.ttl.try_into().unwrap());
res.set_raw_header("Expires", format_datetime_http(&expiry_time));
Ok(res)
}
e @ Err(_) => e,
}
} }
} }
pub struct SafeString(String); pub struct SafeString(String);
impl std::fmt::Display for SafeString { impl std::fmt::Display for SafeString {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
self.0.fmt(f) self.0.fmt(f)
} }
} }
@ -175,11 +212,9 @@ impl<'r> FromParam<'r> for SafeString {
type Error = (); type Error = ();
#[inline(always)] #[inline(always)]
fn from_param(param: &'r RawStr) -> Result<Self, Self::Error> { fn from_param(param: &'r str) -> Result<Self, Self::Error> {
let s = param.percent_decode().map(|cow| cow.into_owned()).map_err(|_| ())?; if param.chars().all(|c| matches!(c, 'a'..='z' | 'A'..='Z' |'0'..='9' | '-')) {
Ok(SafeString(param.to_string()))
if s.chars().all(|c| matches!(c, 'a'..='z' | 'A'..='Z' |'0'..='9' | '-')) {
Ok(SafeString(s))
} else { } else {
Err(()) Err(())
} }
@ -193,15 +228,16 @@ const LOGGED_ROUTES: [&str; 6] =
// Boolean is extra debug, when true, we ignore the whitelist above and also print the mounts // Boolean is extra debug, when true, we ignore the whitelist above and also print the mounts
pub struct BetterLogging(pub bool); pub struct BetterLogging(pub bool);
#[rocket::async_trait]
impl Fairing for BetterLogging { impl Fairing for BetterLogging {
fn info(&self) -> Info { fn info(&self) -> Info {
Info { Info {
name: "Better Logging", name: "Better Logging",
kind: Kind::Launch | Kind::Request | Kind::Response, kind: Kind::Liftoff | Kind::Request | Kind::Response,
} }
} }
fn on_launch(&self, rocket: &Rocket) { async fn on_liftoff(&self, rocket: &Rocket<Orbit>) {
if self.0 { if self.0 {
info!(target: "routes", "Routes loaded:"); info!(target: "routes", "Routes loaded:");
let mut routes: Vec<_> = rocket.routes().collect(); let mut routes: Vec<_> = rocket.routes().collect();
@ -225,34 +261,36 @@ impl Fairing for BetterLogging {
info!(target: "start", "Rocket has launched from {}", addr); info!(target: "start", "Rocket has launched from {}", addr);
} }
fn on_request(&self, request: &mut Request<'_>, _data: &Data) { async fn on_request(&self, request: &mut Request<'_>, _data: &mut Data<'_>) {
let method = request.method(); let method = request.method();
if !self.0 && method == Method::Options { if !self.0 && method == Method::Options {
return; return;
} }
let uri = request.uri(); let uri = request.uri();
let uri_path = uri.path(); let uri_path = uri.path();
let uri_subpath = uri_path.strip_prefix(&CONFIG.domain_path()).unwrap_or(uri_path); let uri_path_str = uri_path.url_decode_lossy();
let uri_subpath = uri_path_str.strip_prefix(&CONFIG.domain_path()).unwrap_or(&uri_path_str);
if self.0 || LOGGED_ROUTES.iter().any(|r| uri_subpath.starts_with(r)) { if self.0 || LOGGED_ROUTES.iter().any(|r| uri_subpath.starts_with(r)) {
match uri.query() { match uri.query() {
Some(q) => info!(target: "request", "{} {}?{}", method, uri_path, &q[..q.len().min(30)]), Some(q) => info!(target: "request", "{} {}?{}", method, uri_path_str, &q[..q.len().min(30)]),
None => info!(target: "request", "{} {}", method, uri_path), None => info!(target: "request", "{} {}", method, uri_path_str),
}; };
} }
} }
fn on_response(&self, request: &Request, response: &mut Response) { async fn on_response<'r>(&self, request: &'r Request<'_>, response: &mut Response<'r>) {
if !self.0 && request.method() == Method::Options { if !self.0 && request.method() == Method::Options {
return; return;
} }
let uri_path = request.uri().path(); let uri_path = request.uri().path();
let uri_subpath = uri_path.strip_prefix(&CONFIG.domain_path()).unwrap_or(uri_path); let uri_path_str = uri_path.url_decode_lossy();
let uri_subpath = uri_path_str.strip_prefix(&CONFIG.domain_path()).unwrap_or(&uri_path_str);
if self.0 || LOGGED_ROUTES.iter().any(|r| uri_subpath.starts_with(r)) { if self.0 || LOGGED_ROUTES.iter().any(|r| uri_subpath.starts_with(r)) {
let status = response.status(); let status = response.status();
if let Some(route) = request.route() { if let Some(ref route) = request.route() {
info!(target: "response", "{} => {} {}", route, status.code, status.reason) info!(target: "response", "{} => {}", route, status)
} else { } else {
info!(target: "response", "{} {}", status.code, status.reason) info!(target: "response", "{}", status)
} }
} }
} }
@ -263,7 +301,7 @@ impl Fairing for BetterLogging {
// //
use std::{ use std::{
fs::{self, File}, fs::{self, File},
io::{Read, Result as IOResult}, io::Result as IOResult,
path::Path, path::Path,
}; };
@ -271,15 +309,6 @@ pub fn file_exists(path: &str) -> bool {
Path::new(path).exists() Path::new(path).exists()
} }
pub fn read_file(path: &str) -> IOResult<Vec<u8>> {
let mut contents: Vec<u8> = Vec::new();
let mut file = File::open(Path::new(path))?;
file.read_to_end(&mut contents)?;
Ok(contents)
}
pub fn write_file(path: &str, content: &[u8]) -> Result<(), crate::error::Error> { pub fn write_file(path: &str, content: &[u8]) -> Result<(), crate::error::Error> {
use std::io::Write; use std::io::Write;
let mut f = File::create(path)?; let mut f = File::create(path)?;
@ -288,15 +317,6 @@ pub fn write_file(path: &str, content: &[u8]) -> Result<(), crate::error::Error>
Ok(()) Ok(())
} }
pub fn read_file_string(path: &str) -> IOResult<String> {
let mut contents = String::new();
let mut file = File::open(Path::new(path))?;
file.read_to_string(&mut contents)?;
Ok(contents)
}
pub fn delete_file(path: &str) -> IOResult<()> { pub fn delete_file(path: &str) -> IOResult<()> {
let res = fs::remove_file(path); let res = fs::remove_file(path);
@ -501,7 +521,7 @@ struct UpCaseVisitor;
impl<'de> Visitor<'de> for UpCaseVisitor { impl<'de> Visitor<'de> for UpCaseVisitor {
type Value = Value; type Value = Value;
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
formatter.write_str("an object or an array") formatter.write_str("an object or an array")
} }
@ -582,14 +602,13 @@ where
if tries >= max_tries { if tries >= max_tries {
return err; return err;
} }
Handle::current().block_on(async move { sleep(Duration::from_millis(500)).await });
sleep(Duration::from_millis(500));
} }
} }
} }
} }
pub fn retry_db<F, T, E>(func: F, max_tries: u32) -> Result<T, E> pub async fn retry_db<F, T, E>(func: F, max_tries: u32) -> Result<T, E>
where where
F: Fn() -> Result<T, E>, F: Fn() -> Result<T, E>,
E: std::error::Error, E: std::error::Error,
@ -608,19 +627,22 @@ where
warn!("Can't connect to database, retrying: {:?}", e); warn!("Can't connect to database, retrying: {:?}", e);
sleep(Duration::from_millis(1_000)); sleep(Duration::from_millis(1_000)).await;
} }
} }
} }
} }
use reqwest::{ use reqwest::{header, Client, ClientBuilder};
blocking::{Client, ClientBuilder},
header,
};
pub fn get_reqwest_client() -> Client { pub fn get_reqwest_client() -> Client {
get_reqwest_client_builder().build().expect("Failed to build client") match get_reqwest_client_builder().build() {
Ok(client) => client,
Err(e) => {
error!("Possible trust-dns error, trying with trust-dns disabled: '{e}'");
get_reqwest_client_builder().trust_dns(false).build().expect("Failed to build client")
}
}
} }
pub fn get_reqwest_client_builder() -> ClientBuilder { pub fn get_reqwest_client_builder() -> ClientBuilder {

Loading…
Cancel
Save