diff --git a/src/db/mod.rs b/src/db/mod.rs index 7d87877d..367961fa 100644 --- a/src/db/mod.rs +++ b/src/db/mod.rs @@ -7,7 +7,7 @@ use std::{ use diesel::{ connection::SimpleConnection, - r2d2::{ConnectionManager, CustomizeConnection, Pool, PooledConnection}, + r2d2::{CustomizeConnection, Pool, PooledConnection}, Connection, RunQueryDsl, }; @@ -54,6 +54,64 @@ pub enum DbConnInner { Sqlite(diesel::sqlite::SqliteConnection), } +/// Custom connection manager that implements manual connection establishment +pub struct DbConnManager { + database_url: String, +} + +impl DbConnManager { + pub fn new(database_url: &str) -> Self { + Self { + database_url: database_url.to_string(), + } + } + + fn establish_connection(&self) -> Result { + let url = &self.database_url; + + match DbConnType::from_url(url) { + #[cfg(mysql)] + Ok(DbConnType::Mysql) => { + let conn = diesel::mysql::MysqlConnection::establish(url)?; + Ok(DbConnInner::Mysql(conn)) + } + #[cfg(postgresql)] + Ok(DbConnType::Postgresql) => { + let conn = diesel::pg::PgConnection::establish(url)?; + Ok(DbConnInner::Postgresql(conn)) + } + #[cfg(sqlite)] + Ok(DbConnType::Sqlite) => { + let conn = diesel::sqlite::SqliteConnection::establish(url)?; + Ok(DbConnInner::Sqlite(conn)) + } + + Err(e) => Err(diesel::r2d2::Error::ConnectionError(diesel::ConnectionError::InvalidConnectionUrl( + format!("Unable to estabilsh a connection: {e:?}"), + ))), + } + } +} + +impl diesel::r2d2::ManageConnection for DbConnManager { + type Connection = DbConnInner; + type Error = diesel::r2d2::Error; + + fn connect(&self) -> Result { + self.establish_connection() + } + + fn is_valid(&self, conn: &mut Self::Connection) -> Result<(), Self::Error> { + use diesel::r2d2::R2D2Connection; + conn.ping().map_err(diesel::r2d2::Error::QueryError) + } + + fn has_broken(&self, conn: &mut Self::Connection) -> bool { + use diesel::r2d2::R2D2Connection; + conn.is_broken() + } +} + #[derive(Eq, PartialEq)] pub enum DbConnType { #[cfg(mysql)] @@ -67,7 +125,7 @@ pub enum DbConnType { pub static ACTIVE_DB_TYPE: OnceLock = OnceLock::new(); pub struct DbConn { - conn: Arc>>>>, + conn: Arc>>>, permit: Option, } @@ -88,7 +146,7 @@ impl CustomizeConnection for DbConnOptions { #[derive(Clone)] pub struct DbPool { // This is an 'Option' so that we can drop the pool in a 'spawn_blocking'. - pool: Option>>, + pool: Option>, semaphore: Arc, } @@ -154,7 +212,7 @@ impl DbPool { } let max_conns = CONFIG.database_max_conns(); - let manager = ConnectionManager::::new(&db_url); + let manager = DbConnManager::new(&db_url); let pool = Pool::builder() .max_size(max_conns) .min_idle(Some(CONFIG.database_min_conns())) diff --git a/src/db/query_logger.rs b/src/db/query_logger.rs index a407fd0d..0a207918 100644 --- a/src/db/query_logger.rs +++ b/src/db/query_logger.rs @@ -7,26 +7,23 @@ thread_local! { pub fn simple_logger() -> Option> { Some(Box::new(|event: InstrumentationEvent<'_>| match event { - // TODO: Figure out where the invalid connection errors are coming from - // There seem to be some invalid errors when connecting to a SQLite database - // Until the cause of this is found and resolved, disable the Connection logging - // InstrumentationEvent::StartEstablishConnection { - // url, - // .. - // } => { - // debug!("Establishing connection: {url}") - // } - // InstrumentationEvent::FinishEstablishConnection { - // url, - // error, - // .. - // } => { - // if let Some(e) = error { - // error!("Error during establishing a connection with {url}: {e:?}") - // } else { - // debug!("Connection established: {url}") - // } - // } + InstrumentationEvent::StartEstablishConnection { + url, + .. + } => { + debug!("Establishing connection: {url}") + } + InstrumentationEvent::FinishEstablishConnection { + url, + error, + .. + } => { + if let Some(e) = error { + error!("Error during establishing a connection with {url}: {e:?}") + } else { + debug!("Connection established: {url}") + } + } InstrumentationEvent::StartQuery { query, .. diff --git a/src/error.rs b/src/error.rs index 06ebf3aa..37316e57 100644 --- a/src/error.rs +++ b/src/error.rs @@ -38,7 +38,8 @@ macro_rules! make_error { }; } -use diesel::r2d2::PoolError as R2d2Err; +use diesel::r2d2::Error as R2d2Err; +use diesel::r2d2::PoolError as R2d2PoolErr; use diesel::result::Error as DieselErr; use diesel::ConnectionError as DieselConErr; use handlebars::RenderError as HbErr; @@ -78,12 +79,13 @@ make_error! { CustomHttpClient(CustomHttpClientError): _has_source, _api_error, // Used for special return values, like 2FA errors - Json(Value): _no_source, _serialize, - Db(DieselErr): _has_source, _api_error, - R2d2(R2d2Err): _has_source, _api_error, - Serde(SerdeErr): _has_source, _api_error, - JWt(JwtErr): _has_source, _api_error, - Handlebars(HbErr): _has_source, _api_error, + Json(Value): _no_source, _serialize, + Db(DieselErr): _has_source, _api_error, + R2d2(R2d2Err): _has_source, _api_error, + R2d2Pool(R2d2PoolErr): _has_source, _api_error, + Serde(SerdeErr): _has_source, _api_error, + JWt(JwtErr): _has_source, _api_error, + Handlebars(HbErr): _has_source, _api_error, Io(IoErr): _has_source, _api_error, Time(TimeErr): _has_source, _api_error,