You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

483 lines
16 KiB

mod query_logger;
use std::{
sync::{Arc, OnceLock},
time::Duration,
};
use diesel::{
connection::SimpleConnection,
r2d2::{CustomizeConnection, Pool, PooledConnection},
Connection, RunQueryDsl,
};
use rocket::{
http::Status,
request::{FromRequest, Outcome},
Request,
};
use tokio::{
sync::{Mutex, OwnedSemaphorePermit, Semaphore},
time::timeout,
};
use crate::{
error::{Error, MapResult},
CONFIG,
};
// These changes are based on Rocket 0.5-rc wrapper of Diesel: https://github.com/SergioBenitez/Rocket/blob/v0.5-rc/contrib/sync_db_pools
/// Runs `job` through `tokio::task::spawn_blocking` and hands back its result.
///
/// A panic raised inside `job` is not swallowed into a join error: it is
/// resumed on the task that awaits this function.
pub async fn run_blocking<F, R>(job: F) -> R
where
    F: FnOnce() -> R + Send + 'static,
    R: Send + 'static,
{
    let join_error = match tokio::task::spawn_blocking(job).await {
        Ok(value) => return value,
        Err(join_error) => join_error,
    };

    // A failed join that is not a panic would have to be a cancellation,
    // which this code treats as impossible for blocking tasks.
    match join_error.try_into_panic() {
        Ok(payload) => std::panic::resume_unwind(payload),
        Err(_) => unreachable!("spawn_blocking tasks are never cancelled"),
    }
}
// The raw database connection type used throughout this module.
// It has one variant per supported database backend, and each variant is only
// compiled in when the matching `mysql` / `postgresql` / `sqlite` cfg flag is set.
// `diesel::MultiConnection` is derived so the enum itself can be used as a Diesel connection.
#[derive(diesel::MultiConnection)]
pub enum DbConnInner {
#[cfg(mysql)]
Mysql(diesel::mysql::MysqlConnection),
#[cfg(postgresql)]
Postgresql(diesel::pg::PgConnection),
#[cfg(sqlite)]
Sqlite(diesel::sqlite::SqliteConnection),
}
/// Custom connection manager that implements manual connection establishment.
///
/// It only stores the database URL; the backend to connect to is derived from
/// that URL each time a connection is established.
pub struct DbConnManager {
// URL passed to the backend-specific `establish` call for every new connection.
database_url: String,
}
impl DbConnManager {
pub fn new(database_url: &str) -> Self {
Self {
database_url: database_url.to_string(),
}
}
fn establish_connection(&self) -> Result<DbConnInner, diesel::r2d2::Error> {
match DbConnType::from_url(&self.database_url) {
#[cfg(mysql)]
Ok(DbConnType::Mysql) => {
let conn = diesel::mysql::MysqlConnection::establish(&self.database_url)?;
Ok(DbConnInner::Mysql(conn))
}
#[cfg(postgresql)]
Ok(DbConnType::Postgresql) => {
let conn = diesel::pg::PgConnection::establish(&self.database_url)?;
Ok(DbConnInner::Postgresql(conn))
}
#[cfg(sqlite)]
Ok(DbConnType::Sqlite) => {
let conn = diesel::sqlite::SqliteConnection::establish(&self.database_url)?;
Ok(DbConnInner::Sqlite(conn))
}
Err(e) => Err(diesel::r2d2::Error::ConnectionError(diesel::ConnectionError::InvalidConnectionUrl(
format!("Unable to estabilsh a connection: {e:?}"),
))),
}
}
}
impl diesel::r2d2::ManageConnection for DbConnManager {
type Connection = DbConnInner;
type Error = diesel::r2d2::Error;
fn connect(&self) -> Result<Self::Connection, Self::Error> {
self.establish_connection()
}
fn is_valid(&self, conn: &mut Self::Connection) -> Result<(), Self::Error> {
use diesel::r2d2::R2D2Connection;
conn.ping().map_err(diesel::r2d2::Error::QueryError)
}
fn has_broken(&self, conn: &mut Self::Connection) -> bool {
use diesel::r2d2::R2D2Connection;
conn.is_broken()
}
}
/// Identifies a database backend without holding a connection to it.
/// Each variant only exists when the matching `mysql` / `postgresql` / `sqlite`
/// cfg flag is set, mirroring the variants of `DbConnInner`.
#[derive(Eq, PartialEq)]
pub enum DbConnType {
#[cfg(mysql)]
Mysql,
#[cfg(postgresql)]
Postgresql,
#[cfg(sqlite)]
Sqlite,
}
/// The database backend in use; written once by `DbPool::from_config` after the pool has been built.
pub static ACTIVE_DB_TYPE: OnceLock<DbConnType> = OnceLock::new();
/// A database connection checked out from `DbPool`, together with the
/// semaphore permit that was acquired for it.
pub struct DbConn {
// The pooled connection. It sits behind an async mutex and inside an `Option`
// so that `Drop` can lock it and take it out on a blocking thread.
conn: Arc<Mutex<Option<PooledConnection<DbConnManager>>>>,
// Permit from the pool's semaphore; `Drop` releases it only after the connection is dropped.
permit: Option<OwnedSemaphorePermit>,
}
#[derive(Debug)]
pub struct DbConnOptions {
pub init_stmts: String,
}
impl CustomizeConnection<DbConnInner, diesel::r2d2::Error> for DbConnOptions {
fn on_acquire(&self, conn: &mut DbConnInner) -> Result<(), diesel::r2d2::Error> {
if !self.init_stmts.is_empty() {
conn.batch_execute(&self.init_stmts).map_err(diesel::r2d2::Error::QueryError)?;
}
Ok(())
}
}
/// Handle to the r2d2 connection pool plus the semaphore that `get` waits on
/// before asking the pool for a connection.
#[derive(Clone)]
pub struct DbPool {
// This is an 'Option' so that we can drop the pool in a 'spawn_blocking'.
pool: Option<Pool<DbConnManager>>,
// Created in `from_config` with as many permits as the pool's maximum size.
semaphore: Arc<Semaphore>,
}
impl Drop for DbConn {
    /// Hands the connection and its permit to a blocking task for clean-up.
    fn drop(&mut self) {
        let shared_conn = Arc::clone(&self.conn);
        let permit = self.permit.take();

        // The connection must not be dropped from async context, so the actual
        // release happens on a blocking-safe thread.
        tokio::task::spawn_blocking(move || {
            // Re-enter the runtime just long enough to wait for the async mutex.
            let mut slot = tokio::runtime::Handle::current().block_on(shared_conn.lock_owned());

            // Release the pooled connection first (a no-op when it is already gone)...
            drop(slot.take());

            // ...and only afterwards give the permit back.
            drop(permit);
        });
    }
}
impl Drop for DbPool {
    /// Drops the inner pool on a blocking thread when a Tokio runtime is available.
    fn drop(&mut self) {
        let pool = self.pool.take();

        match tokio::runtime::Handle::try_current() {
            // A runtime is still around: let one of its blocking threads do the drop.
            Ok(runtime) => {
                runtime.spawn_blocking(move || drop(pool));
            }
            // No runtime any more: drop the pool right here on the current thread.
            Err(_) => drop(pool),
        }
    }
}
impl DbPool {
// For the given database URL, guess its type, run migrations, create pool, and return it
pub fn from_config() -> Result<Self, Error> {
let db_url = CONFIG.database_url();
let conn_type = DbConnType::from_url(&db_url)?;
// Only set the default instrumentation if the log level is specifically set to either warn, info or debug
if log_enabled!(target: "vaultwarden::db::query_logger", log::Level::Warn)
|| log_enabled!(target: "vaultwarden::db::query_logger", log::Level::Info)
|| log_enabled!(target: "vaultwarden::db::query_logger", log::Level::Debug)
{
drop(diesel::connection::set_default_instrumentation(query_logger::simple_logger));
}
match conn_type {
#[cfg(mysql)]
DbConnType::Mysql => {
mysql_migrations::run_migrations(&db_url)?;
}
#[cfg(postgresql)]
DbConnType::Postgresql => {
postgresql_migrations::run_migrations(&db_url)?;
}
#[cfg(sqlite)]
DbConnType::Sqlite => {
sqlite_migrations::run_migrations(&db_url)?;
}
}
let max_conns = CONFIG.database_max_conns();
let manager = DbConnManager::new(&db_url);
let pool = Pool::builder()
.max_size(max_conns)
.min_idle(Some(CONFIG.database_min_conns()))
.idle_timeout(Some(Duration::from_secs(CONFIG.database_idle_timeout())))
.connection_timeout(Duration::from_secs(CONFIG.database_timeout()))
.connection_customizer(Box::new(DbConnOptions {
init_stmts: conn_type.get_init_stmts(),
}))
.build(manager)
.map_res("Failed to create pool")?;
// Set a global to determine the database more easily throughout the rest of the code
if ACTIVE_DB_TYPE.set(conn_type).is_err() {
error!("Tried to set the active database connection type more than once.")
}
Ok(DbPool {
pool: Some(pool),
semaphore: Arc::new(Semaphore::new(max_conns as usize)),
})
}
// Get a connection from the pool
pub async fn get(&self) -> Result<DbConn, Error> {
let duration = Duration::from_secs(CONFIG.database_timeout());
let permit = match timeout(duration, Arc::clone(&self.semaphore).acquire_owned()).await {
Ok(p) => p.expect("Semaphore should be open"),
Err(_) => {
err!("Timeout waiting for database connection");
}
};
let p = self.pool.as_ref().expect("DbPool.pool should always be Some()");
let pool = p.clone();
let c =
run_blocking(move || pool.get_timeout(duration)).await.map_res("Error retrieving connection from pool")?;
Ok(DbConn {
conn: Arc::new(Mutex::new(Some(c))),
permit: Some(permit),
})
}
}
impl DbConnType {
    /// Guesses the database backend from the scheme prefix of `url`.
    ///
    /// * `mysql:` selects MySQL.
    /// * `postgresql:` or `postgres:` selects PostgreSQL.
    /// * Anything else is treated as a SQLite database path.
    ///
    /// # Errors
    ///
    /// Returns an error when the detected backend was not compiled into this build.
    pub fn from_url(url: &str) -> Result<Self, Error> {
        // `starts_with` is used instead of slicing (`&url[..6]`): slicing panics when the
        // cut-off is not on a UTF-8 character boundary, which can happen for SQLite paths
        // containing non-ASCII characters. It also recognises short but valid URLs such as
        // `postgres://`, which a fixed minimum length would misclassify as SQLite.

        // Mysql
        if url.starts_with("mysql:") {
            #[cfg(mysql)]
            return Ok(DbConnType::Mysql);

            #[cfg(not(mysql))]
            err!("`DATABASE_URL` is a MySQL URL, but the 'mysql' feature is not enabled")

        // Postgresql
        } else if url.starts_with("postgresql:") || url.starts_with("postgres:") {
            #[cfg(postgresql)]
            return Ok(DbConnType::Postgresql);

            #[cfg(not(postgresql))]
            err!("`DATABASE_URL` is a PostgreSQL URL, but the 'postgresql' feature is not enabled")

        //Sqlite
        } else {
            #[cfg(sqlite)]
            return Ok(DbConnType::Sqlite);

            #[cfg(not(sqlite))]
            err!("`DATABASE_URL` looks like a SQLite URL, but 'sqlite' feature is not enabled")
        }
    }

    /// Returns the statements to run on each new connection: the configured
    /// `database_conn_init` value when it is non-empty, otherwise the backend default.
    pub fn get_init_stmts(&self) -> String {
        let init_stmts = CONFIG.database_conn_init();
        if init_stmts.is_empty() {
            self.default_init_stmts()
        } else {
            init_stmts
        }
    }

    /// Returns the built-in init statements for this backend
    /// (empty for MySQL and PostgreSQL, two PRAGMAs for SQLite).
    pub fn default_init_stmts(&self) -> String {
        match self {
            #[cfg(mysql)]
            Self::Mysql => String::new(),
            #[cfg(postgresql)]
            Self::Postgresql => String::new(),
            #[cfg(sqlite)]
            Self::Sqlite => "PRAGMA busy_timeout = 5000; PRAGMA synchronous = NORMAL;".to_string(),
        }
    }
}
impl DbConn {
    /// Runs `f` with exclusive, mutable access to the underlying connection and returns its result.
    ///
    /// # Panics
    ///
    /// Panics if the connection has already been taken out of this `DbConn`.
    pub async fn run<F, R>(&self, f: F) -> R
    where
        F: FnOnce(&mut DbConnInner) -> R + Send,
        R: Send + 'static,
    {
        let shared = Arc::clone(&self.conn);
        let mut guard = shared.lock_owned().await;
        let inner = guard.as_mut().expect("Internal invariant broken: self.conn is Some");

        // `run_blocking` would force a 'static bound on `f`, so the closure is
        // executed in place on the current worker thread instead.
        tokio::task::block_in_place(move || f(inner))
    }
}
// Runs a block of database code through `DbConn::run` and awaits the result.
//
// Two forms are accepted:
//   db_run! { conn: { ... } }
//     runs the block as-is; inside it `conn` is the closure argument of `run`.
//   db_run! { conn: sqlite, mysql { ... } postgresql { ... } }
//     matches on the `DbConnInner` variant; every listed backend gets its own arm,
//     gated by `#[cfg(<backend>)]`, and inside the arm `conn` is rebound to the
//     backend-specific connection. The variant name is built by camel-casing the
//     backend identifier with `pastey::paste!` (e.g. `sqlite` -> `Sqlite`).
#[macro_export]
macro_rules! db_run {
( $conn:ident: $body:block ) => {
$conn.run(move |$conn| $body).await
};
( $conn:ident: $( $($db:ident),+ $body:block )+ ) => {
$conn.run(move |$conn| {
match $conn {
$($(
#[cfg($db)]
pastey::paste!(&mut $crate::db::DbConnInner::[<$db:camel>](ref mut $conn)) => {
$body
},
)+)+}
}).await
};
}
pub mod schema;
// Reexport the models, needs to be after the macros are defined so it can access them
pub mod models;
/// Creates a back-up of the sqlite database
/// MySQL/MariaDB and PostgreSQL are not supported.
///
/// The database is opened read-only, serialized into memory and written to
/// `db_<UTC timestamp>.sqlite3` next to the original database file.
/// On success the path of the backup file is returned.
///
/// # Errors
///
/// Returns an error when the configured database is not SQLite, when the
/// database cannot be opened, when its directory cannot be determined, or
/// when the backup file cannot be created or written.
#[cfg(sqlite)]
pub fn backup_sqlite() -> Result<String, Error> {
    use diesel::Connection;
    use std::{fs::File, io::Write};

    // Read the URL once so the type check and the backup use the same value.
    let db_url = CONFIG.database_url();

    if !matches!(DbConnType::from_url(&db_url), Ok(DbConnType::Sqlite)) {
        err_silent!("The database type is not SQLite. Backups only works for SQLite databases");
    }

    // Since we do not allow any schema for sqlite database_url's like `file:` or `sqlite:` to be set, we can assume here it isn't
    // This way we can set a readonly flag on the opening mode without issues.
    let mut conn = diesel::sqlite::SqliteConnection::establish(&format!("sqlite://{db_url}?mode=ro"))?;

    // `parent()` is `None` for an empty path or a filesystem root; report that instead of panicking.
    let Some(db_path) = std::path::Path::new(&db_url).parent() else {
        err_silent!(format!("Unable to determine the directory of the SQLite database: {db_url}"));
    };
    let backup_file = db_path
        .join(format!("db_{}.sqlite3", chrono::Utc::now().format("%Y%m%d_%H%M%S")))
        .to_string_lossy()
        .into_owned();

    // Creating the file and writing to it can both fail for ordinary reasons
    // (permissions, full disk), so both are reported as errors rather than panics.
    let written = File::create(&backup_file).and_then(|mut f| {
        let serialized_db = conn.serialize_database_to_buffer();
        f.write_all(serialized_db.as_slice())
    });

    match written {
        Ok(()) => Ok(backup_file),
        Err(e) => {
            err_silent!(format!("Unable to save SQLite backup: {e:?}"))
        }
    }
}

/// Stand-in used when SQLite support is not compiled in: always returns an error.
#[cfg(not(sqlite))]
pub fn backup_sqlite() -> Result<String, Error> {
    err_silent!("The database type is not SQLite. Backups only works for SQLite databases")
}
/// Asks the connected database for its version string.
///
/// PostgreSQL and MySQL are queried with `version()`, SQLite with
/// `sqlite_version()`. If the query fails, `"Unknown"` is returned.
pub async fn get_sql_server_version(conn: &DbConn) -> String {
    db_run! { conn:
        postgresql,mysql {
            let version_expr = diesel::dsl::sql::<diesel::sql_types::Text>("version();");
            match diesel::select(version_expr).get_result::<String>(conn) {
                Ok(version) => version,
                Err(_) => "Unknown".to_string(),
            }
        }
        sqlite {
            let version_expr = diesel::dsl::sql::<diesel::sql_types::Text>("sqlite_version();");
            match diesel::select(version_expr).get_result::<String>(conn) {
                Ok(version) => version,
                Err(_) => "Unknown".to_string(),
            }
        }
    }
}
/// Attempts to retrieve a single connection from the managed database pool. If
/// no pool is currently managed, fails with an `InternalServerError` status. If
/// no connections are available, fails with a `ServiceUnavailable` status.
#[rocket::async_trait]
impl<'r> FromRequest<'r> for DbConn {
type Error = ();
async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
match request.rocket().state::<DbPool>() {
Some(p) => match p.get().await {
Ok(dbconn) => Outcome::Success(dbconn),
_ => Outcome::Error((Status::ServiceUnavailable, ())),
},
None => Outcome::Error((Status::InternalServerError, ())),
}
}
}
// The migrations from the migrations folder are embedded into the binary, so
// the program can bring the database to the latest version by itself.
// https://docs.rs/diesel_migrations/*/diesel_migrations/macro.embed_migrations.html
#[cfg(sqlite)]
mod sqlite_migrations {
    use diesel::{Connection, RunQueryDsl};
    use diesel_migrations::{EmbeddedMigrations, MigrationHarness};

    pub const MIGRATIONS: EmbeddedMigrations = embed_migrations!("migrations/sqlite");

    /// Connects to the SQLite database at `db_url` and applies all pending migrations.
    ///
    /// A connection failure is returned as an error; a failing PRAGMA or migration panics.
    pub fn run_migrations(db_url: &str) -> Result<(), super::Error> {
        let mut conn = diesel::sqlite::SqliteConnection::establish(db_url)?;

        // Foreign key checks stay off while migrating; the setting is scoped to this connection.
        let disable_fk = diesel::sql_query("PRAGMA foreign_keys = OFF");
        disable_fk.execute(&mut conn).expect("Failed to disable Foreign Key Checks during migrations");

        // Switch the journal to WAL when the configuration asks for it.
        let use_wal = crate::CONFIG.enable_db_wal();
        if use_wal {
            let enable_wal = diesel::sql_query("PRAGMA journal_mode=wal");
            enable_wal.execute(&mut conn).expect("Failed to turn on WAL");
        }

        conn.run_pending_migrations(MIGRATIONS).expect("Error running migrations");
        Ok(())
    }
}
#[cfg(mysql)]
mod mysql_migrations {
    use diesel::{Connection, RunQueryDsl};
    use diesel_migrations::{EmbeddedMigrations, MigrationHarness};

    pub const MIGRATIONS: EmbeddedMigrations = embed_migrations!("migrations/mysql");

    /// Connects to the MySQL database at `db_url` and applies all pending migrations.
    ///
    /// A connection failure is returned as an error; a failing statement or migration panics.
    pub fn run_migrations(db_url: &str) -> Result<(), super::Error> {
        let mut conn = diesel::mysql::MysqlConnection::establish(db_url)?;

        // Foreign key checks stay off while migrating; the setting is scoped to this connection/session.
        let disable_fk = diesel::sql_query("SET FOREIGN_KEY_CHECKS = 0");
        disable_fk.execute(&mut conn).expect("Failed to disable Foreign Key Checks during migrations");

        conn.run_pending_migrations(MIGRATIONS).expect("Error running migrations");
        Ok(())
    }
}
#[cfg(postgresql)]
mod postgresql_migrations {
    use diesel::Connection;
    use diesel_migrations::{EmbeddedMigrations, MigrationHarness};

    pub const MIGRATIONS: EmbeddedMigrations = embed_migrations!("migrations/postgresql");

    /// Connects to the PostgreSQL database at `db_url` and applies all pending migrations.
    ///
    /// A connection failure is returned as an error; a failing migration panics.
    pub fn run_migrations(db_url: &str) -> Result<(), super::Error> {
        let mut conn = diesel::pg::PgConnection::establish(db_url)?;

        conn.run_pending_migrations(MIGRATIONS).expect("Error running migrations");
        Ok(())
    }
}