Browse Source

refactor: use RwLock on DbPool instead Mutex

pull/5037/head
Dmitry Ulyanov 6 months ago
parent
commit
2fcc353db0
  1. 6
      src/api/core/accounts.rs
  2. 6
      src/api/core/ciphers.rs
  3. 10
      src/api/core/emergency_access.rs
  4. 6
      src/api/core/events.rs
  5. 7
      src/api/core/sends.rs
  6. 6
      src/api/core/two_factor/duo_oidc.rs
  7. 6
      src/api/core/two_factor/mod.rs
  8. 6
      src/db/mod.rs
  9. 10
      src/main.rs

6
src/api/core/accounts.rs

@ -4,7 +4,7 @@ use crate::db::DbPool;
use chrono::{SecondsFormat, Utc};
use rocket::serde::json::Json;
use serde_json::Value;
use tokio::sync::Mutex;
use tokio::sync::RwLock;
use crate::{
api::{
@ -1285,9 +1285,9 @@ async fn get_auth_requests(headers: Headers, mut conn: DbConn) -> JsonResult {
})))
}
pub async fn purge_auth_requests(pool: Arc<Mutex<DbPool>>) {
pub async fn purge_auth_requests(pool: Arc<RwLock<DbPool>>) {
debug!("Purging auth requests");
if let Ok(mut conn) = pool.lock().await.get().await {
if let Ok(mut conn) = pool.read().await.get().await {
AuthRequest::purge_expired_auth_requests(&mut conn).await;
} else {
error!("Failed to get DB connection while purging trashed ciphers")

6
src/api/core/ciphers.rs

@ -10,7 +10,7 @@ use rocket::{
Route,
};
use serde_json::Value;
use tokio::sync::Mutex;
use tokio::sync::RwLock;
use crate::util::NumberOrString;
use crate::{
@ -90,9 +90,9 @@ pub fn routes() -> Vec<Route> {
]
}
pub async fn purge_trashed_ciphers(pool: Arc<Mutex<DbPool>>) {
pub async fn purge_trashed_ciphers(pool: Arc<RwLock<DbPool>>) {
debug!("Purging trashed ciphers");
if let Ok(mut conn) = pool.lock().await.get().await {
if let Ok(mut conn) = pool.read().await.get().await {
Cipher::purge_trash(&mut conn).await;
} else {
error!("Failed to get DB connection while purging trashed ciphers")

10
src/api/core/emergency_access.rs

@ -3,7 +3,7 @@ use std::sync::Arc;
use chrono::{TimeDelta, Utc};
use rocket::{serde::json::Json, Route};
use serde_json::Value;
use tokio::sync::Mutex;
use tokio::sync::RwLock;
use crate::{
api::{
@ -732,13 +732,13 @@ fn check_emergency_access_enabled() -> EmptyResult {
Ok(())
}
pub async fn emergency_request_timeout_job(pool: Arc<Mutex<DbPool>>) {
pub async fn emergency_request_timeout_job(pool: Arc<RwLock<DbPool>>) {
debug!("Start emergency_request_timeout_job");
if !CONFIG.emergency_access_allowed() {
return;
}
if let Ok(mut conn) = pool.lock().await.get().await {
if let Ok(mut conn) = pool.read().await.get().await {
let emergency_access_list = EmergencyAccess::find_all_recoveries_initiated(&mut conn).await;
if emergency_access_list.is_empty() {
@ -787,13 +787,13 @@ pub async fn emergency_request_timeout_job(pool: Arc<Mutex<DbPool>>) {
}
}
pub async fn emergency_notification_reminder_job(pool: Arc<Mutex<DbPool>>) {
pub async fn emergency_notification_reminder_job(pool: Arc<RwLock<DbPool>>) {
debug!("Start emergency_notification_reminder_job");
if !CONFIG.emergency_access_allowed() {
return;
}
if let Ok(mut conn) = pool.lock().await.get().await {
if let Ok(mut conn) = pool.read().await.get().await {
let emergency_access_list = EmergencyAccess::find_all_recoveries_initiated(&mut conn).await;
if emergency_access_list.is_empty() {

6
src/api/core/events.rs

@ -3,7 +3,7 @@ use std::{net::IpAddr, sync::Arc};
use chrono::NaiveDateTime;
use rocket::{form::FromForm, serde::json::Json, Route};
use serde_json::Value;
use tokio::sync::Mutex;
use tokio::sync::RwLock;
use crate::{
api::{EmptyResult, JsonResult},
@ -321,14 +321,14 @@ async fn _log_event(
event.save(conn).await.unwrap_or(());
}
pub async fn event_cleanup_job(pool: Arc<Mutex<DbPool>>) {
pub async fn event_cleanup_job(pool: Arc<RwLock<DbPool>>) {
debug!("Start events cleanup job");
if CONFIG.events_days_retain().is_none() {
debug!("events_days_retain is not configured, abort");
return;
}
if let Ok(mut conn) = pool.lock().await.get().await {
if let Ok(mut conn) = pool.read().await.get().await {
Event::clean_events(&mut conn).await.ok();
} else {
error!("Failed to get DB connection while trying to cleanup the events table")

7
src/api/core/sends.rs

@ -8,7 +8,8 @@ use rocket::fs::NamedFile;
use rocket::fs::TempFile;
use rocket::serde::json::Json;
use serde_json::Value;
use tokio::sync::Mutex;
// use tokio::sync::Mutex;
use tokio::sync::RwLock;
use crate::{
api::{ApiResult, EmptyResult, JsonResult, Notify, UpdateType},
@ -40,9 +41,9 @@ pub fn routes() -> Vec<rocket::Route> {
]
}
pub async fn purge_sends(pool: Arc<Mutex<DbPool>>) {
pub async fn purge_sends(pool: Arc<RwLock<DbPool>>) {
debug!("Purging sends");
if let Ok(mut conn) = pool.lock().await.get().await {
if let Ok(mut conn) = pool.read().await.get().await {
Send::purge(&mut conn).await;
} else {
error!("Failed to get DB connection while purging sends")

6
src/api/core/two_factor/duo_oidc.rs

@ -5,7 +5,7 @@ use reqwest::{header, StatusCode};
use ring::digest::{digest, Digest, SHA512_256};
use serde::Serialize;
use std::{collections::HashMap, sync::Arc};
use tokio::sync::Mutex;
use tokio::sync::RwLock;
use crate::{
api::{core::two_factor::duo::get_duo_keys_email, EmptyResult},
@ -346,9 +346,9 @@ async fn extract_context(state: &str, conn: &mut DbConn) -> Option<DuoAuthContex
}
// Task to clean up expired Duo authentication contexts that may have accumulated in the database.
pub async fn purge_duo_contexts(pool: Arc<Mutex<DbPool>>) {
pub async fn purge_duo_contexts(pool: Arc<RwLock<DbPool>>) {
debug!("Purging Duo authentication contexts");
if let Ok(mut conn) = pool.lock().await.get().await {
if let Ok(mut conn) = pool.read().await.get().await {
TwoFactorDuoContext::purge_expired_duo_contexts(&mut conn).await;
} else {
error!("Failed to get DB connection while purging expired Duo authentications")

6
src/api/core/two_factor/mod.rs

@ -5,7 +5,7 @@ use data_encoding::BASE32;
use rocket::serde::json::Json;
use rocket::Route;
use serde_json::Value;
use tokio::sync::Mutex;
use tokio::sync::RwLock;
use crate::{
api::{
@ -247,14 +247,14 @@ pub async fn enforce_2fa_policy_for_org(
Ok(())
}
pub async fn send_incomplete_2fa_notifications(pool: Arc<Mutex<DbPool>>) {
pub async fn send_incomplete_2fa_notifications(pool: Arc<RwLock<DbPool>>) {
debug!("Sending notifications for incomplete 2FA logins");
if CONFIG.incomplete_2fa_time_limit() <= 0 || !CONFIG.mail_enabled() {
return;
}
let mut conn = match pool.lock().await.get().await {
let mut conn = match pool.read().await.get().await {
Ok(conn) => conn,
_ => {
error!("Failed to get DB connection in send_incomplete_2fa_notifications()");

6
src/db/mod.rs

@ -12,7 +12,7 @@ use rocket::{
};
use tokio::{
sync::{Mutex, OwnedSemaphorePermit, Semaphore},
sync::{Mutex, OwnedSemaphorePermit, RwLock, Semaphore},
time::timeout,
};
@ -417,8 +417,8 @@ impl<'r> FromRequest<'r> for DbConn {
type Error = ();
async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
match request.rocket().state::<Arc<Mutex<DbPool>>>() {
Some(p) => match p.lock().await.get().await {
match request.rocket().state::<Arc<RwLock<DbPool>>>() {
Some(p) => match p.read().await.get().await {
Ok(dbconn) => Outcome::Success(dbconn),
_ => Outcome::Error((Status::ServiceUnavailable, ())),
},

10
src/main.rs

@ -39,7 +39,7 @@ use tokio::{
fs::File,
io::{AsyncBufReadExt, BufReader},
signal::unix::SignalKind,
sync::Mutex,
sync::RwLock,
};
#[macro_use]
@ -83,10 +83,10 @@ async fn main() -> Result<(), Error> {
create_dir(&CONFIG.sends_folder(), "sends folder");
create_dir(&CONFIG.attachments_folder(), "attachments folder");
let pool = Arc::new(Mutex::new(create_db_pool().await));
let pool = Arc::new(RwLock::new(create_db_pool().await));
schedule_jobs(Arc::clone(&pool));
{
db::models::TwoFactor::migrate_u2f_to_webauthn(&mut pool.lock().await.get().await.unwrap()).await.unwrap();
db::models::TwoFactor::migrate_u2f_to_webauthn(&mut pool.read().await.get().await.unwrap()).await.unwrap();
}
let extra_debug = matches!(level, log::LevelFilter::Trace | log::LevelFilter::Debug);
@ -561,7 +561,7 @@ async fn create_db_pool() -> db::DbPool {
}
}
async fn launch_rocket(pool: Arc<Mutex<db::DbPool>>, extra_debug: bool) -> Result<(), Error> {
async fn launch_rocket(pool: Arc<RwLock<db::DbPool>>, extra_debug: bool) -> Result<(), Error> {
let basepath = &CONFIG.domain_path();
let mut config = rocket::Config::from(rocket::Config::figment());
@ -624,7 +624,7 @@ async fn launch_rocket(pool: Arc<Mutex<db::DbPool>>, extra_debug: bool) -> Resul
Ok(())
}
fn schedule_jobs(pool: Arc<Mutex<db::DbPool>>) {
fn schedule_jobs(pool: Arc<RwLock<db::DbPool>>) {
if CONFIG.job_poll_interval_ms() == 0 {
info!("Job scheduler disabled.");
return;

Loading…
Cancel
Save