Browse Source

use arc for pool of db

pull/5037/head
Dmitry Ulyanov 6 months ago
parent
commit
e8b7a3fa63
  1. 7
      src/api/core/accounts.rs
  2. 6
      src/api/core/ciphers.rs
  3. 11
      src/api/core/emergency_access.rs
  4. 7
      src/api/core/events.rs
  5. 6
      src/api/core/sends.rs
  6. 7
      src/api/core/two_factor/duo_oidc.rs
  7. 7
      src/api/core/two_factor/mod.rs
  8. 6
      src/main.rs

7
src/api/core/accounts.rs

@ -1,7 +1,10 @@
use std::sync::Arc;
use crate::db::DbPool;
use chrono::{SecondsFormat, Utc};
use rocket::serde::json::Json;
use serde_json::Value;
use tokio::sync::Mutex;
use crate::{
api::{
@ -1282,9 +1285,9 @@ async fn get_auth_requests(headers: Headers, mut conn: DbConn) -> JsonResult {
})))
}
pub async fn purge_auth_requests(pool: DbPool) {
pub async fn purge_auth_requests(pool: Arc<Mutex<DbPool>>) {
debug!("Purging auth requests");
if let Ok(mut conn) = pool.get().await {
if let Ok(mut conn) = pool.lock().await.get().await {
AuthRequest::purge_expired_auth_requests(&mut conn).await;
} else {
error!("Failed to get DB connection while purging trashed ciphers")

6
src/api/core/ciphers.rs

@ -1,4 +1,5 @@
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use chrono::{NaiveDateTime, Utc};
use num_traits::ToPrimitive;
@ -9,6 +10,7 @@ use rocket::{
Route,
};
use serde_json::Value;
use tokio::sync::Mutex;
use crate::util::NumberOrString;
use crate::{
@ -88,9 +90,9 @@ pub fn routes() -> Vec<Route> {
]
}
pub async fn purge_trashed_ciphers(pool: DbPool) {
pub async fn purge_trashed_ciphers(pool: Arc<Mutex<DbPool>>) {
debug!("Purging trashed ciphers");
if let Ok(mut conn) = pool.get().await {
if let Ok(mut conn) = pool.lock().await.get().await {
Cipher::purge_trash(&mut conn).await;
} else {
error!("Failed to get DB connection while purging trashed ciphers")

11
src/api/core/emergency_access.rs

@ -1,6 +1,9 @@
use std::sync::Arc;
use chrono::{TimeDelta, Utc};
use rocket::{serde::json::Json, Route};
use serde_json::Value;
use tokio::sync::Mutex;
use crate::{
api::{
@ -729,13 +732,13 @@ fn check_emergency_access_enabled() -> EmptyResult {
Ok(())
}
pub async fn emergency_request_timeout_job(pool: DbPool) {
pub async fn emergency_request_timeout_job(pool: Arc<Mutex<DbPool>>) {
debug!("Start emergency_request_timeout_job");
if !CONFIG.emergency_access_allowed() {
return;
}
if let Ok(mut conn) = pool.get().await {
if let Ok(mut conn) = pool.lock().await.get().await {
let emergency_access_list = EmergencyAccess::find_all_recoveries_initiated(&mut conn).await;
if emergency_access_list.is_empty() {
@ -784,13 +787,13 @@ pub async fn emergency_request_timeout_job(pool: DbPool) {
}
}
pub async fn emergency_notification_reminder_job(pool: DbPool) {
pub async fn emergency_notification_reminder_job(pool: Arc<Mutex<DbPool>>) {
debug!("Start emergency_notification_reminder_job");
if !CONFIG.emergency_access_allowed() {
return;
}
if let Ok(mut conn) = pool.get().await {
if let Ok(mut conn) = pool.lock().await.get().await {
let emergency_access_list = EmergencyAccess::find_all_recoveries_initiated(&mut conn).await;
if emergency_access_list.is_empty() {

7
src/api/core/events.rs

@ -1,8 +1,9 @@
use std::net::IpAddr;
use std::{net::IpAddr, sync::Arc};
use chrono::NaiveDateTime;
use rocket::{form::FromForm, serde::json::Json, Route};
use serde_json::Value;
use tokio::sync::Mutex;
use crate::{
api::{EmptyResult, JsonResult},
@ -320,14 +321,14 @@ async fn _log_event(
event.save(conn).await.unwrap_or(());
}
pub async fn event_cleanup_job(pool: DbPool) {
pub async fn event_cleanup_job(pool: Arc<Mutex<DbPool>>) {
debug!("Start events cleanup job");
if CONFIG.events_days_retain().is_none() {
debug!("events_days_retain is not configured, abort");
return;
}
if let Ok(mut conn) = pool.get().await {
if let Ok(mut conn) = pool.lock().await.get().await {
Event::clean_events(&mut conn).await.ok();
} else {
error!("Failed to get DB connection while trying to cleanup the events table")

6
src/api/core/sends.rs

@ -1,4 +1,5 @@
use std::path::Path;
use std::sync::Arc;
use chrono::{DateTime, TimeDelta, Utc};
use num_traits::ToPrimitive;
@ -7,6 +8,7 @@ use rocket::fs::NamedFile;
use rocket::fs::TempFile;
use rocket::serde::json::Json;
use serde_json::Value;
use tokio::sync::Mutex;
use crate::{
api::{ApiResult, EmptyResult, JsonResult, Notify, UpdateType},
@ -38,9 +40,9 @@ pub fn routes() -> Vec<rocket::Route> {
]
}
pub async fn purge_sends(pool: DbPool) {
pub async fn purge_sends(pool: Arc<Mutex<DbPool>>) {
debug!("Purging sends");
if let Ok(mut conn) = pool.get().await {
if let Ok(mut conn) = pool.lock().await.get().await {
Send::purge(&mut conn).await;
} else {
error!("Failed to get DB connection while purging sends")

7
src/api/core/two_factor/duo_oidc.rs

@ -4,7 +4,8 @@ use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, Validation};
use reqwest::{header, StatusCode};
use ring::digest::{digest, Digest, SHA512_256};
use serde::Serialize;
use std::collections::HashMap;
use std::{collections::HashMap, sync::Arc};
use tokio::sync::Mutex;
use crate::{
api::{core::two_factor::duo::get_duo_keys_email, EmptyResult},
@ -345,9 +346,9 @@ async fn extract_context(state: &str, conn: &mut DbConn) -> Option<DuoAuthContex
}
// Task to clean up expired Duo authentication contexts that may have accumulated in the database.
pub async fn purge_duo_contexts(pool: DbPool) {
pub async fn purge_duo_contexts(pool: Arc<Mutex<DbPool>>) {
debug!("Purging Duo authentication contexts");
if let Ok(mut conn) = pool.get().await {
if let Ok(mut conn) = pool.lock().await.get().await {
TwoFactorDuoContext::purge_expired_duo_contexts(&mut conn).await;
} else {
error!("Failed to get DB connection while purging expired Duo authentications")

7
src/api/core/two_factor/mod.rs

@ -1,8 +1,11 @@
use std::sync::Arc;
use chrono::{TimeDelta, Utc};
use data_encoding::BASE32;
use rocket::serde::json::Json;
use rocket::Route;
use serde_json::Value;
use tokio::sync::Mutex;
use crate::{
api::{
@ -244,14 +247,14 @@ pub async fn enforce_2fa_policy_for_org(
Ok(())
}
pub async fn send_incomplete_2fa_notifications(pool: DbPool) {
pub async fn send_incomplete_2fa_notifications(pool: Arc<Mutex<DbPool>>) {
debug!("Sending notifications for incomplete 2FA logins");
if CONFIG.incomplete_2fa_time_limit() <= 0 || !CONFIG.mail_enabled() {
return;
}
let mut conn = match pool.get().await {
let mut conn = match pool.lock().await.get().await {
Ok(conn) => conn,
_ => {
error!("Failed to get DB connection in send_incomplete_2fa_notifications()");

6
src/main.rs

@ -39,6 +39,7 @@ use tokio::{
fs::File,
io::{AsyncBufReadExt, BufReader},
signal::unix::SignalKind,
sync::Mutex,
};
#[macro_use]
@ -83,7 +84,8 @@ async fn main() -> Result<(), Error> {
create_dir(&CONFIG.attachments_folder(), "attachments folder");
let pool = create_db_pool().await;
schedule_jobs(pool.clone());
let poolArc = Arc::new(Mutex::new(pool.clone()));
schedule_jobs(poolArc.clone());
db::models::TwoFactor::migrate_u2f_to_webauthn(&mut pool.get().await.unwrap()).await.unwrap();
let extra_debug = matches!(level, log::LevelFilter::Trace | log::LevelFilter::Debug);
@ -621,7 +623,7 @@ async fn launch_rocket(pool: db::DbPool, extra_debug: bool) -> Result<(), Error>
Ok(())
}
fn schedule_jobs(pool: db::DbPool) {
fn schedule_jobs(pool: Arc<Mutex<db::DbPool>>) {
if CONFIG.job_poll_interval_ms() == 0 {
info!("Job scheduler disabled.");
return;

Loading…
Cancel
Save