mod query_logger; use std::{ sync::{Arc, OnceLock}, time::Duration, }; use diesel::{ connection::SimpleConnection, r2d2::{CustomizeConnection, Pool, PooledConnection}, Connection, RunQueryDsl, }; use rocket::{ http::Status, request::{FromRequest, Outcome}, Request, }; use tokio::{ sync::{Mutex, OwnedSemaphorePermit, Semaphore}, time::timeout, }; use crate::{ error::{Error, MapResult}, CONFIG, }; // These changes are based on Rocket 0.5-rc wrapper of Diesel: https://github.com/SergioBenitez/Rocket/blob/v0.5-rc/contrib/sync_db_pools // A wrapper around spawn_blocking that propagates panics to the calling code. pub async fn run_blocking(job: F) -> R where F: FnOnce() -> R + Send + 'static, R: Send + 'static, { match tokio::task::spawn_blocking(job).await { Ok(ret) => ret, Err(e) => match e.try_into_panic() { Ok(panic) => std::panic::resume_unwind(panic), Err(_) => unreachable!("spawn_blocking tasks are never cancelled"), }, } } // This is used to generate the main DbConn and DbPool enums, which contain one variant for each database supported #[derive(diesel::MultiConnection)] pub enum DbConnInner { #[cfg(mysql)] Mysql(diesel::mysql::MysqlConnection), #[cfg(postgresql)] Postgresql(diesel::pg::PgConnection), #[cfg(sqlite)] Sqlite(diesel::sqlite::SqliteConnection), } /// Custom connection manager that implements manual connection establishment pub struct DbConnManager { database_url: String, } impl DbConnManager { pub fn new(database_url: &str) -> Self { Self { database_url: database_url.to_string(), } } fn establish_connection(&self) -> Result { match DbConnType::from_url(&self.database_url) { #[cfg(mysql)] Ok(DbConnType::Mysql) => { let conn = diesel::mysql::MysqlConnection::establish(&self.database_url)?; Ok(DbConnInner::Mysql(conn)) } #[cfg(postgresql)] Ok(DbConnType::Postgresql) => { let conn = diesel::pg::PgConnection::establish(&self.database_url)?; Ok(DbConnInner::Postgresql(conn)) } #[cfg(sqlite)] Ok(DbConnType::Sqlite) => { let conn = diesel::sqlite::SqliteConnection::establish(&self.database_url)?; Ok(DbConnInner::Sqlite(conn)) } Err(e) => Err(diesel::r2d2::Error::ConnectionError(diesel::ConnectionError::InvalidConnectionUrl( format!("Unable to estabilsh a connection: {e:?}"), ))), } } } impl diesel::r2d2::ManageConnection for DbConnManager { type Connection = DbConnInner; type Error = diesel::r2d2::Error; fn connect(&self) -> Result { self.establish_connection() } fn is_valid(&self, conn: &mut Self::Connection) -> Result<(), Self::Error> { use diesel::r2d2::R2D2Connection; conn.ping().map_err(diesel::r2d2::Error::QueryError) } fn has_broken(&self, conn: &mut Self::Connection) -> bool { use diesel::r2d2::R2D2Connection; conn.is_broken() } } #[derive(Eq, PartialEq)] pub enum DbConnType { #[cfg(mysql)] Mysql, #[cfg(postgresql)] Postgresql, #[cfg(sqlite)] Sqlite, } pub static ACTIVE_DB_TYPE: OnceLock = OnceLock::new(); pub struct DbConn { conn: Arc>>>, permit: Option, } #[derive(Debug)] pub struct DbConnOptions { pub init_stmts: String, } impl CustomizeConnection for DbConnOptions { fn on_acquire(&self, conn: &mut DbConnInner) -> Result<(), diesel::r2d2::Error> { if !self.init_stmts.is_empty() { conn.batch_execute(&self.init_stmts).map_err(diesel::r2d2::Error::QueryError)?; } Ok(()) } } #[derive(Clone)] pub struct DbPool { // This is an 'Option' so that we can drop the pool in a 'spawn_blocking'. pool: Option>, semaphore: Arc, } impl Drop for DbConn { fn drop(&mut self) { let conn = Arc::clone(&self.conn); let permit = self.permit.take(); // Since connection can't be on the stack in an async fn during an // await, we have to spawn a new blocking-safe thread... tokio::task::spawn_blocking(move || { // And then re-enter the runtime to wait on the async mutex, but in a blocking fashion. let mut conn = tokio::runtime::Handle::current().block_on(conn.lock_owned()); if let Some(conn) = conn.take() { drop(conn); } // Drop permit after the connection is dropped drop(permit); }); } } impl Drop for DbPool { fn drop(&mut self) { let pool = self.pool.take(); // Only use spawn_blocking if the Tokio runtime is still available // Otherwise the pool will be dropped on the current thread if let Ok(handle) = tokio::runtime::Handle::try_current() { handle.spawn_blocking(move || drop(pool)); } } } impl DbPool { // For the given database URL, guess its type, run migrations, create pool, and return it pub fn from_config() -> Result { let db_url = CONFIG.database_url(); let conn_type = DbConnType::from_url(&db_url)?; // Only set the default instrumentation if the log level is specifically set to either warn, info or debug if log_enabled!(target: "vaultwarden::db::query_logger", log::Level::Warn) || log_enabled!(target: "vaultwarden::db::query_logger", log::Level::Info) || log_enabled!(target: "vaultwarden::db::query_logger", log::Level::Debug) { drop(diesel::connection::set_default_instrumentation(query_logger::simple_logger)); } match conn_type { #[cfg(mysql)] DbConnType::Mysql => { mysql_migrations::run_migrations(&db_url)?; } #[cfg(postgresql)] DbConnType::Postgresql => { postgresql_migrations::run_migrations(&db_url)?; } #[cfg(sqlite)] DbConnType::Sqlite => { sqlite_migrations::run_migrations(&db_url)?; } } let max_conns = CONFIG.database_max_conns(); let manager = DbConnManager::new(&db_url); let pool = Pool::builder() .max_size(max_conns) .min_idle(Some(CONFIG.database_min_conns())) .idle_timeout(Some(Duration::from_secs(CONFIG.database_idle_timeout()))) .connection_timeout(Duration::from_secs(CONFIG.database_timeout())) .connection_customizer(Box::new(DbConnOptions { init_stmts: conn_type.get_init_stmts(), })) .build(manager) .map_res("Failed to create pool")?; // Set a global to determine the database more easily throughout the rest of the code if ACTIVE_DB_TYPE.set(conn_type).is_err() { error!("Tried to set the active database connection type more than once.") } Ok(DbPool { pool: Some(pool), semaphore: Arc::new(Semaphore::new(max_conns as usize)), }) } // Get a connection from the pool pub async fn get(&self) -> Result { let duration = Duration::from_secs(CONFIG.database_timeout()); let permit = match timeout(duration, Arc::clone(&self.semaphore).acquire_owned()).await { Ok(p) => p.expect("Semaphore should be open"), Err(_) => { err!("Timeout waiting for database connection"); } }; let p = self.pool.as_ref().expect("DbPool.pool should always be Some()"); let pool = p.clone(); let c = run_blocking(move || pool.get_timeout(duration)).await.map_res("Error retrieving connection from pool")?; Ok(DbConn { conn: Arc::new(Mutex::new(Some(c))), permit: Some(permit), }) } } impl DbConnType { pub fn from_url(url: &str) -> Result { // Mysql if url.len() > 6 && &url[..6] == "mysql:" { #[cfg(mysql)] return Ok(DbConnType::Mysql); #[cfg(not(mysql))] err!("`DATABASE_URL` is a MySQL URL, but the 'mysql' feature is not enabled") // Postgresql } else if url.len() > 11 && (&url[..11] == "postgresql:" || &url[..9] == "postgres:") { #[cfg(postgresql)] return Ok(DbConnType::Postgresql); #[cfg(not(postgresql))] err!("`DATABASE_URL` is a PostgreSQL URL, but the 'postgresql' feature is not enabled") //Sqlite } else { #[cfg(sqlite)] return Ok(DbConnType::Sqlite); #[cfg(not(sqlite))] err!("`DATABASE_URL` looks like a SQLite URL, but 'sqlite' feature is not enabled") } } pub fn get_init_stmts(&self) -> String { let init_stmts = CONFIG.database_conn_init(); if !init_stmts.is_empty() { init_stmts } else { self.default_init_stmts() } } pub fn default_init_stmts(&self) -> String { match self { #[cfg(mysql)] Self::Mysql => String::new(), #[cfg(postgresql)] Self::Postgresql => String::new(), #[cfg(sqlite)] Self::Sqlite => "PRAGMA busy_timeout = 5000; PRAGMA synchronous = NORMAL;".to_string(), } } } impl DbConn { pub async fn run(&self, f: F) -> R where F: FnOnce(&mut DbConnInner) -> R + Send, R: Send + 'static, { let conn = Arc::clone(&self.conn); let mut conn = conn.lock_owned().await; let conn = conn.as_mut().expect("Internal invariant broken: self.conn is Some"); // Run blocking can't be used due to the 'static limitation, use block_in_place instead tokio::task::block_in_place(move || f(conn)) } } #[macro_export] macro_rules! db_run { ( $conn:ident: $body:block ) => { $conn.run(move |$conn| $body).await }; ( $conn:ident: $( $($db:ident),+ $body:block )+ ) => { $conn.run(move |$conn| { match $conn { $($( #[cfg($db)] pastey::paste!(&mut $crate::db::DbConnInner::[<$db:camel>](ref mut $conn)) => { $body }, )+)+} }).await }; } pub mod schema; // Reexport the models, needs to be after the macros are defined so it can access them pub mod models; /// Creates a back-up of the sqlite database /// MySQL/MariaDB and PostgreSQL are not supported. #[cfg(sqlite)] pub fn backup_sqlite() -> Result { use diesel::Connection; use std::{fs::File, io::Write}; let db_url = CONFIG.database_url(); if DbConnType::from_url(&CONFIG.database_url()).map(|t| t == DbConnType::Sqlite).unwrap_or(false) { // Since we do not allow any schema for sqlite database_url's like `file:` or `sqlite:` to be set, we can assume here it isn't // This way we can set a readonly flag on the opening mode without issues. let mut conn = diesel::sqlite::SqliteConnection::establish(&format!("sqlite://{db_url}?mode=ro"))?; let db_path = std::path::Path::new(&db_url).parent().unwrap(); let backup_file = db_path .join(format!("db_{}.sqlite3", chrono::Utc::now().format("%Y%m%d_%H%M%S"))) .to_string_lossy() .into_owned(); match File::create(backup_file.clone()) { Ok(mut f) => { let serialized_db = conn.serialize_database_to_buffer(); f.write_all(serialized_db.as_slice()).expect("Error writing SQLite backup"); Ok(backup_file) } Err(e) => { err_silent!(format!("Unable to save SQLite backup: {e:?}")) } } } else { err_silent!("The database type is not SQLite. Backups only works for SQLite databases") } } #[cfg(not(sqlite))] pub fn backup_sqlite() -> Result { err_silent!("The database type is not SQLite. Backups only works for SQLite databases") } /// Get the SQL Server version pub async fn get_sql_server_version(conn: &DbConn) -> String { db_run! { conn: postgresql,mysql { diesel::select(diesel::dsl::sql::("version();")) .get_result::(conn) .unwrap_or_else(|_| "Unknown".to_string()) } sqlite { diesel::select(diesel::dsl::sql::("sqlite_version();")) .get_result::(conn) .unwrap_or_else(|_| "Unknown".to_string()) } } } /// Attempts to retrieve a single connection from the managed database pool. If /// no pool is currently managed, fails with an `InternalServerError` status. If /// no connections are available, fails with a `ServiceUnavailable` status. #[rocket::async_trait] impl<'r> FromRequest<'r> for DbConn { type Error = (); async fn from_request(request: &'r Request<'_>) -> Outcome { match request.rocket().state::() { Some(p) => match p.get().await { Ok(dbconn) => Outcome::Success(dbconn), _ => Outcome::Error((Status::ServiceUnavailable, ())), }, None => Outcome::Error((Status::InternalServerError, ())), } } } // Embed the migrations from the migrations folder into the application // This way, the program automatically migrates the database to the latest version // https://docs.rs/diesel_migrations/*/diesel_migrations/macro.embed_migrations.html #[cfg(sqlite)] mod sqlite_migrations { use diesel::{Connection, RunQueryDsl}; use diesel_migrations::{EmbeddedMigrations, MigrationHarness}; pub const MIGRATIONS: EmbeddedMigrations = embed_migrations!("migrations/sqlite"); pub fn run_migrations(db_url: &str) -> Result<(), super::Error> { // Establish a connection to the sqlite database (this will create a new one, if it does // not exist, and exit if there is an error). let mut connection = diesel::sqlite::SqliteConnection::establish(db_url)?; // Run the migrations after successfully establishing a connection // Disable Foreign Key Checks during migration // Scoped to a connection. diesel::sql_query("PRAGMA foreign_keys = OFF") .execute(&mut connection) .expect("Failed to disable Foreign Key Checks during migrations"); // Turn on WAL in SQLite if crate::CONFIG.enable_db_wal() { diesel::sql_query("PRAGMA journal_mode=wal").execute(&mut connection).expect("Failed to turn on WAL"); } connection.run_pending_migrations(MIGRATIONS).expect("Error running migrations"); Ok(()) } } #[cfg(mysql)] mod mysql_migrations { use diesel::{Connection, RunQueryDsl}; use diesel_migrations::{EmbeddedMigrations, MigrationHarness}; pub const MIGRATIONS: EmbeddedMigrations = embed_migrations!("migrations/mysql"); pub fn run_migrations(db_url: &str) -> Result<(), super::Error> { // Make sure the database is up to date (create if it doesn't exist, or run the migrations) let mut connection = diesel::mysql::MysqlConnection::establish(db_url)?; // Disable Foreign Key Checks during migration // Scoped to a connection/session. diesel::sql_query("SET FOREIGN_KEY_CHECKS = 0") .execute(&mut connection) .expect("Failed to disable Foreign Key Checks during migrations"); connection.run_pending_migrations(MIGRATIONS).expect("Error running migrations"); Ok(()) } } #[cfg(postgresql)] mod postgresql_migrations { use diesel::Connection; use diesel_migrations::{EmbeddedMigrations, MigrationHarness}; pub const MIGRATIONS: EmbeddedMigrations = embed_migrations!("migrations/postgresql"); pub fn run_migrations(db_url: &str) -> Result<(), super::Error> { // Make sure the database is up to date (create if it doesn't exist, or run the migrations) let mut connection = diesel::pg::PgConnection::establish(db_url)?; connection.run_pending_migrations(MIGRATIONS).expect("Error running migrations"); Ok(()) } }