
refactor: use RwLock on DbPool instead of Mutex

pull/5037/head
Dmitry Ulyanov 7 months ago
commit 2fcc353db0
  1. src/api/core/accounts.rs (6 changed lines)
  2. src/api/core/ciphers.rs (6 changed lines)
  3. src/api/core/emergency_access.rs (10 changed lines)
  4. src/api/core/events.rs (6 changed lines)
  5. src/api/core/sends.rs (7 changed lines)
  6. src/api/core/two_factor/duo_oidc.rs (6 changed lines)
  7. src/api/core/two_factor/mod.rs (6 changed lines)
  8. src/db/mod.rs (6 changed lines)
  9. src/main.rs (10 changed lines)
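
The change is the same in every file below: the shared pool moves from Arc<Mutex<DbPool>> to Arc<RwLock<DbPool>>, and every call site that did pool.lock().await.get().await now takes a shared read guard via pool.read().await.get().await. Because checking out a connection only needs &DbPool, a Mutex forces background jobs and request handlers to take the pool one at a time, while an RwLock lets any number of them hold a read guard at once. A minimal sketch of the pattern, using a hypothetical FakePool stand-in rather than the project's DbPool:

use std::sync::Arc;
use tokio::sync::RwLock;

// Stand-in for DbPool: checking out a connection needs only `&self`,
// which is why a shared read guard is enough at the call sites below.
struct FakePool;

impl FakePool {
    async fn get(&self) -> Result<&'static str, ()> {
        Ok("connection")
    }
}

#[tokio::main]
async fn main() {
    let pool = Arc::new(RwLock::new(FakePool));

    // Several tasks hold the read guard at the same time; with a Mutex
    // each task would have to wait for the previous one to unlock.
    let handles: Vec<_> = (0..4)
        .map(|i| {
            let pool = Arc::clone(&pool);
            tokio::spawn(async move {
                if let Ok(conn) = pool.read().await.get().await {
                    println!("task {i} got a {conn}");
                }
            })
        })
        .collect();
    for h in handles {
        h.await.unwrap();
    }
}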

src/api/core/accounts.rs

@@ -4,7 +4,7 @@ use crate::db::DbPool;
 use chrono::{SecondsFormat, Utc};
 use rocket::serde::json::Json;
 use serde_json::Value;
-use tokio::sync::Mutex;
+use tokio::sync::RwLock;
 use crate::{
     api::{
@@ -1285,9 +1285,9 @@ async fn get_auth_requests(headers: Headers, mut conn: DbConn) -> JsonResult {
     })))
 }
-pub async fn purge_auth_requests(pool: Arc<Mutex<DbPool>>) {
+pub async fn purge_auth_requests(pool: Arc<RwLock<DbPool>>) {
     debug!("Purging auth requests");
-    if let Ok(mut conn) = pool.lock().await.get().await {
+    if let Ok(mut conn) = pool.read().await.get().await {
         AuthRequest::purge_expired_auth_requests(&mut conn).await;
     } else {
         error!("Failed to get DB connection while purging trashed ciphers")

src/api/core/ciphers.rs

@@ -10,7 +10,7 @@ use rocket::{
     Route,
 };
 use serde_json::Value;
-use tokio::sync::Mutex;
+use tokio::sync::RwLock;
 use crate::util::NumberOrString;
 use crate::{
@@ -90,9 +90,9 @@ pub fn routes() -> Vec<Route> {
     ]
 }
-pub async fn purge_trashed_ciphers(pool: Arc<Mutex<DbPool>>) {
+pub async fn purge_trashed_ciphers(pool: Arc<RwLock<DbPool>>) {
     debug!("Purging trashed ciphers");
-    if let Ok(mut conn) = pool.lock().await.get().await {
+    if let Ok(mut conn) = pool.read().await.get().await {
         Cipher::purge_trash(&mut conn).await;
     } else {
         error!("Failed to get DB connection while purging trashed ciphers")

src/api/core/emergency_access.rs

@@ -3,7 +3,7 @@ use std::sync::Arc;
 use chrono::{TimeDelta, Utc};
 use rocket::{serde::json::Json, Route};
 use serde_json::Value;
-use tokio::sync::Mutex;
+use tokio::sync::RwLock;
 use crate::{
     api::{
@@ -732,13 +732,13 @@ fn check_emergency_access_enabled() -> EmptyResult {
     Ok(())
 }
-pub async fn emergency_request_timeout_job(pool: Arc<Mutex<DbPool>>) {
+pub async fn emergency_request_timeout_job(pool: Arc<RwLock<DbPool>>) {
     debug!("Start emergency_request_timeout_job");
     if !CONFIG.emergency_access_allowed() {
         return;
     }
-    if let Ok(mut conn) = pool.lock().await.get().await {
+    if let Ok(mut conn) = pool.read().await.get().await {
         let emergency_access_list = EmergencyAccess::find_all_recoveries_initiated(&mut conn).await;
         if emergency_access_list.is_empty() {
@@ -787,13 +787,13 @@ pub async fn emergency_request_timeout_job(pool: Arc<Mutex<DbPool>>) {
     }
 }
-pub async fn emergency_notification_reminder_job(pool: Arc<Mutex<DbPool>>) {
+pub async fn emergency_notification_reminder_job(pool: Arc<RwLock<DbPool>>) {
     debug!("Start emergency_notification_reminder_job");
     if !CONFIG.emergency_access_allowed() {
         return;
     }
-    if let Ok(mut conn) = pool.lock().await.get().await {
+    if let Ok(mut conn) = pool.read().await.get().await {
         let emergency_access_list = EmergencyAccess::find_all_recoveries_initiated(&mut conn).await;
         if emergency_access_list.is_empty() {

src/api/core/events.rs

@@ -3,7 +3,7 @@ use std::{net::IpAddr, sync::Arc};
 use chrono::NaiveDateTime;
 use rocket::{form::FromForm, serde::json::Json, Route};
 use serde_json::Value;
-use tokio::sync::Mutex;
+use tokio::sync::RwLock;
 use crate::{
     api::{EmptyResult, JsonResult},
@@ -321,14 +321,14 @@ async fn _log_event(
     event.save(conn).await.unwrap_or(());
 }
-pub async fn event_cleanup_job(pool: Arc<Mutex<DbPool>>) {
+pub async fn event_cleanup_job(pool: Arc<RwLock<DbPool>>) {
     debug!("Start events cleanup job");
     if CONFIG.events_days_retain().is_none() {
         debug!("events_days_retain is not configured, abort");
         return;
     }
-    if let Ok(mut conn) = pool.lock().await.get().await {
+    if let Ok(mut conn) = pool.read().await.get().await {
         Event::clean_events(&mut conn).await.ok();
     } else {
         error!("Failed to get DB connection while trying to cleanup the events table")

src/api/core/sends.rs

@@ -8,7 +8,8 @@ use rocket::fs::NamedFile;
 use rocket::fs::TempFile;
 use rocket::serde::json::Json;
 use serde_json::Value;
-use tokio::sync::Mutex;
+// use tokio::sync::Mutex;
+use tokio::sync::RwLock;
 use crate::{
     api::{ApiResult, EmptyResult, JsonResult, Notify, UpdateType},
@@ -40,9 +41,9 @@ pub fn routes() -> Vec<rocket::Route> {
     ]
 }
-pub async fn purge_sends(pool: Arc<Mutex<DbPool>>) {
+pub async fn purge_sends(pool: Arc<RwLock<DbPool>>) {
     debug!("Purging sends");
-    if let Ok(mut conn) = pool.lock().await.get().await {
+    if let Ok(mut conn) = pool.read().await.get().await {
         Send::purge(&mut conn).await;
     } else {
         error!("Failed to get DB connection while purging sends")

src/api/core/two_factor/duo_oidc.rs

@@ -5,7 +5,7 @@ use reqwest::{header, StatusCode};
 use ring::digest::{digest, Digest, SHA512_256};
 use serde::Serialize;
 use std::{collections::HashMap, sync::Arc};
-use tokio::sync::Mutex;
+use tokio::sync::RwLock;
 use crate::{
     api::{core::two_factor::duo::get_duo_keys_email, EmptyResult},
@@ -346,9 +346,9 @@ async fn extract_context(state: &str, conn: &mut DbConn) -> Option<DuoAuthContex
 }
 // Task to clean up expired Duo authentication contexts that may have accumulated in the database.
-pub async fn purge_duo_contexts(pool: Arc<Mutex<DbPool>>) {
+pub async fn purge_duo_contexts(pool: Arc<RwLock<DbPool>>) {
     debug!("Purging Duo authentication contexts");
-    if let Ok(mut conn) = pool.lock().await.get().await {
+    if let Ok(mut conn) = pool.read().await.get().await {
         TwoFactorDuoContext::purge_expired_duo_contexts(&mut conn).await;
     } else {
         error!("Failed to get DB connection while purging expired Duo authentications")

src/api/core/two_factor/mod.rs

@@ -5,7 +5,7 @@ use data_encoding::BASE32;
 use rocket::serde::json::Json;
 use rocket::Route;
 use serde_json::Value;
-use tokio::sync::Mutex;
+use tokio::sync::RwLock;
 use crate::{
     api::{
@@ -247,14 +247,14 @@ pub async fn enforce_2fa_policy_for_org(
     Ok(())
 }
-pub async fn send_incomplete_2fa_notifications(pool: Arc<Mutex<DbPool>>) {
+pub async fn send_incomplete_2fa_notifications(pool: Arc<RwLock<DbPool>>) {
     debug!("Sending notifications for incomplete 2FA logins");
     if CONFIG.incomplete_2fa_time_limit() <= 0 || !CONFIG.mail_enabled() {
         return;
     }
-    let mut conn = match pool.lock().await.get().await {
+    let mut conn = match pool.read().await.get().await {
         Ok(conn) => conn,
         _ => {
             error!("Failed to get DB connection in send_incomplete_2fa_notifications()");

src/db/mod.rs

@@ -12,7 +12,7 @@ use rocket::{
 };
 use tokio::{
-    sync::{Mutex, OwnedSemaphorePermit, Semaphore},
+    sync::{Mutex, OwnedSemaphorePermit, RwLock, Semaphore},
     time::timeout,
 };
@@ -417,8 +417,8 @@ impl<'r> FromRequest<'r> for DbConn {
     type Error = ();
     async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
-        match request.rocket().state::<Arc<Mutex<DbPool>>>() {
-            Some(p) => match p.lock().await.get().await {
+        match request.rocket().state::<Arc<RwLock<DbPool>>>() {
+            Some(p) => match p.read().await.get().await {
                 Ok(dbconn) => Outcome::Success(dbconn),
                 _ => Outcome::Error((Status::ServiceUnavailable, ())),
             },
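
Nothing in this diff ever takes the write half of the lock: the request guard above and all the background jobs only read. The write lock would only matter if the pool itself had to be replaced at runtime, which is the usual reason for keeping an RwLock around instead of a bare Arc<DbPool>. A hypothetical sketch of that case (Pool and create_pool are stand-ins, not the crate's types):

use std::sync::Arc;
use tokio::sync::RwLock;

struct Pool {
    id: u32,
}

async fn create_pool(id: u32) -> Pool {
    Pool { id }
}

// Swapping the pool needs exclusive access; `write()` waits until
// every outstanding read guard has been dropped.
async fn swap_pool(shared: &Arc<RwLock<Pool>>) {
    let new_pool = create_pool(2).await;
    *shared.write().await = new_pool;
}

#[tokio::main]
async fn main() {
    let shared = Arc::new(RwLock::new(create_pool(1).await));
    swap_pool(&shared).await;
    assert_eq!(shared.read().await.id, 2);
}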

src/main.rs

@@ -39,7 +39,7 @@ use tokio::{
     fs::File,
     io::{AsyncBufReadExt, BufReader},
     signal::unix::SignalKind,
-    sync::Mutex,
+    sync::RwLock,
 };
 #[macro_use]
@@ -83,10 +83,10 @@ async fn main() -> Result<(), Error> {
     create_dir(&CONFIG.sends_folder(), "sends folder");
     create_dir(&CONFIG.attachments_folder(), "attachments folder");
-    let pool = Arc::new(Mutex::new(create_db_pool().await));
+    let pool = Arc::new(RwLock::new(create_db_pool().await));
     schedule_jobs(Arc::clone(&pool));
     {
-        db::models::TwoFactor::migrate_u2f_to_webauthn(&mut pool.lock().await.get().await.unwrap()).await.unwrap();
+        db::models::TwoFactor::migrate_u2f_to_webauthn(&mut pool.read().await.get().await.unwrap()).await.unwrap();
     }
     let extra_debug = matches!(level, log::LevelFilter::Trace | log::LevelFilter::Debug);
@@ -561,7 +561,7 @@ async fn create_db_pool() -> db::DbPool {
     }
 }
-async fn launch_rocket(pool: Arc<Mutex<db::DbPool>>, extra_debug: bool) -> Result<(), Error> {
+async fn launch_rocket(pool: Arc<RwLock<db::DbPool>>, extra_debug: bool) -> Result<(), Error> {
     let basepath = &CONFIG.domain_path();
     let mut config = rocket::Config::from(rocket::Config::figment());
@@ -624,7 +624,7 @@ async fn launch_rocket(pool: Arc<Mutex<db::DbPool>>, extra_debug: bool) -> Resul
     Ok(())
 }
-fn schedule_jobs(pool: Arc<Mutex<db::DbPool>>) {
+fn schedule_jobs(pool: Arc<RwLock<db::DbPool>>) {
     if CONFIG.job_poll_interval_ms() == 0 {
         info!("Job scheduler disabled.");
         return;
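
The scheduler side only changes in its signature: schedule_jobs still receives one Arc and each job gets its own clone. A rough sketch of how a background purge job can be driven from such a clone (the plain interval loop is illustrative, not the crate's actual job scheduler; Pool is a stand-in for DbPool):

use std::sync::Arc;
use std::time::Duration;
use tokio::sync::RwLock;

struct Pool;

impl Pool {
    async fn get(&self) -> Result<&'static str, ()> {
        Ok("connection")
    }
}

// Same shape as purge_sends / purge_trashed_ciphers above.
async fn purge_job(pool: Arc<RwLock<Pool>>) {
    if let Ok(_conn) = pool.read().await.get().await {
        // ... run the cleanup query on the connection ...
    } else {
        eprintln!("Failed to get DB connection");
    }
}

#[tokio::main]
async fn main() {
    let pool = Arc::new(RwLock::new(Pool));

    let job_pool = Arc::clone(&pool);
    tokio::spawn(async move {
        let mut tick = tokio::time::interval(Duration::from_secs(3600));
        loop {
            tick.tick().await;
            purge_job(Arc::clone(&job_pool)).await;
        }
    });

    // main() would go on to hand the original `pool` to the web server.
}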
