diff --git a/Cargo.lock b/Cargo.lock
index a4697493..14f0bbbf 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -3875,6 +3875,20 @@ dependencies = [
  "yansi",
 ]
 
+[[package]]
+name = "prometheus"
+version = "0.13.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "3d33c28a30771f7f96db69893f78b857f7450d7e0237e9c8fc6427a81bae7ed1"
+dependencies = [
+ "cfg-if",
+ "fnv",
+ "lazy_static",
+ "memchr",
+ "parking_lot",
+ "thiserror 1.0.69",
+]
+
 [[package]]
 name = "psl-types"
 version = "2.0.11"
@@ -5824,12 +5838,14 @@ dependencies = [
  "mini-moka",
  "num-derive",
  "num-traits",
+ "once_cell",
  "opendal",
  "openidconnect",
  "openssl",
  "pastey 0.2.1",
  "percent-encoding",
  "pico-args",
+ "prometheus",
  "rand 0.9.2",
  "regex",
  "reqsign",
diff --git a/Cargo.toml b/Cargo.toml
index 9d54590e..62def08c 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -37,6 +37,8 @@ vendored_openssl = ["openssl/vendored"]
 # Enable MiMalloc memory allocator to replace the default malloc
 # This can improve performance for Alpine builds
 enable_mimalloc = ["dep:mimalloc"]
+# Enable Prometheus metrics endpoint
+enable_metrics = ["dep:prometheus"]
 s3 = ["opendal/services-s3", "dep:aws-config", "dep:aws-credential-types", "dep:aws-smithy-runtime-api", "dep:anyhow", "dep:http", "dep:reqsign"]
 
 # OIDC specific features
@@ -77,6 +79,9 @@ rmpv = "1.3.1" # MessagePack library
 # Concurrent HashMap used for WebSocket messaging and favicons
 dashmap = "6.1.0"
 
+# Lazy static initialization
+once_cell = "1.20.2"
+
 # Async futures
 futures = "0.3.31"
 tokio = { version = "1.49.0", features = ["rt-multi-thread", "fs", "io-util", "parking_lot", "time", "signal", "net"] }
@@ -182,6 +187,9 @@ semver = "1.0.27"
 # Mainly used for the musl builds, since the default musl malloc is very slow
 mimalloc = { version = "0.1.48", features = ["secure"], default-features = false, optional = true }
 
+# Prometheus metrics
+prometheus = { version = "0.13.1", default-features = false, optional = true }
+
 which = "8.0.0"
 
 # Argon2 library with support for the PHC format
diff --git a/README.md b/README.md
index c84a9c40..3835968f 100644
--- a/README.md
+++ b/README.md
@@ -52,6 +52,7 @@ A nearly complete implementation of the Bitwarden Client API is provided, including:
    [Duo](https://bitwarden.com/help/setup-two-step-login-duo/)
  * [Emergency Access](https://bitwarden.com/help/emergency-access/)
  * [Vaultwarden Admin Backend](https://github.com/dani-garcia/vaultwarden/wiki/Enabling-admin-page)
+ * [Prometheus Metrics](https://github.com/dani-garcia/vaultwarden/wiki/Metrics) - Optional monitoring and observability via a token-protected `/metrics` endpoint
  * [Modified Web Vault client](https://github.com/dani-garcia/bw_web_builds) (Bundled within our containers)
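For reviewers who want to poke at the new endpoint: a minimal scrape client, assuming a local instance built with `enable_metrics` and a plain-text `METRICS_TOKEN` of `example-token` (URL and token are placeholders, not values shipped in this PR). It uses `reqwest`, which is already in the dependency tree:

```rust
// Minimal sketch: fetch the Prometheus exposition text from a local Vaultwarden.
use std::error::Error;

#[tokio::main]
async fn main() -> Result<(), Box<dyn Error>> {
    let body = reqwest::Client::new()
        .get("http://localhost:8000/metrics") // placeholder address
        .bearer_auth("example-token") // matches the `Authorization: Bearer` path in MetricsToken
        .send()
        .await?
        .error_for_status()?
        .text()
        .await?;
    println!("{body}");
    Ok(())
}
```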
diff --git a/docker/Dockerfile.alpine b/docker/Dockerfile.alpine
index 95aae642..0658a9e0 100644
--- a/docker/Dockerfile.alpine
+++ b/docker/Dockerfile.alpine
@@ -82,7 +82,7 @@ ARG CARGO_PROFILE=release
 
 # Configure the DB ARG as late as possible to not invalidate the cached layers above
 # Enable MiMalloc to improve performance on Alpine builds
-ARG DB=sqlite,mysql,postgresql,enable_mimalloc
+ARG DB=sqlite,mysql,postgresql,enable_mimalloc,enable_metrics
 
 # Builds your dependencies and removes the
 # dummy project, except the target folder
diff --git a/docker/Dockerfile.debian b/docker/Dockerfile.debian
index 113304b8..bbc952aa 100644
--- a/docker/Dockerfile.debian
+++ b/docker/Dockerfile.debian
@@ -116,7 +116,7 @@ COPY ./macros ./macros
 ARG CARGO_PROFILE=release
 
 # Configure the DB ARG as late as possible to not invalidate the cached layers above
-ARG DB=sqlite,mysql,postgresql
+ARG DB=sqlite,mysql,postgresql,enable_metrics
 
 # Builds your dependencies and removes the
 # dummy project, except the target folder
diff --git a/docker/Dockerfile.j2 b/docker/Dockerfile.j2
index f745780e..2f55843d 100644
--- a/docker/Dockerfile.j2
+++ b/docker/Dockerfile.j2
@@ -144,10 +144,10 @@ ARG CARGO_PROFILE=release
 
 # Configure the DB ARG as late as possible to not invalidate the cached layers above
 {% if base == "debian" %}
-ARG DB=sqlite,mysql,postgresql
+ARG DB=sqlite,mysql,postgresql,enable_metrics
 {% elif base == "alpine" %}
 # Enable MiMalloc to improve performance on Alpine builds
-ARG DB=sqlite,mysql,postgresql,enable_mimalloc
+ARG DB=sqlite,mysql,postgresql,enable_mimalloc,enable_metrics
 {% endif %}
 
 # Builds your dependencies and removes the
diff --git a/src/api/identity.rs b/src/api/identity.rs
index 9eaa6b36..cca87ea0 100644
--- a/src/api/identity.rs
+++ b/src/api/identity.rs
@@ -32,7 +32,7 @@ use crate::{
     error::MapResult,
     mail, sso,
     sso::{OIDCCode, OIDCCodeChallenge, OIDCCodeVerifier, OIDCState},
-    util, CONFIG,
+    metrics, util, CONFIG,
 };
 
 pub fn routes() -> Vec<Route> {
@@ -60,7 +60,8 @@ async fn login(
 
     let mut user_id: Option<UserId> = None;
 
-    let login_result = match data.grant_type.as_ref() {
+    let auth_method = data.grant_type.clone();
+    let login_result = match auth_method.as_ref() {
         "refresh_token" => {
             _check_is_some(&data.refresh_token, "refresh_token cannot be blank")?;
             _refresh_login(data, &conn, &client_header.ip).await
@@ -104,6 +105,10 @@ async fn login(
         t => err!("Invalid type", t),
     };
 
+    // Record authentication metrics
+    let auth_status = if login_result.is_ok() { "success" } else { "failed" };
+    metrics::increment_auth_attempts(&auth_method, auth_status);
+
     if let Some(user_id) = user_id {
         match &login_result {
             Ok(_) => {
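The net effect of the `login()` change is one counter increment per attempt, labelled by grant type and outcome. A crate-internal sketch of the resulting series (values illustrative):

```rust
// Sketch: two login attempts as login() would record them.
fn record_example_attempts() {
    crate::metrics::increment_auth_attempts("password", "success");
    crate::metrics::increment_auth_attempts("refresh_token", "failed");
}
// The /metrics output then contains, roughly:
// vaultwarden_auth_attempts_total{method="password",status="success"} 1
// vaultwarden_auth_attempts_total{method="refresh_token",status="failed"} 1
```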
diff --git a/src/api/metrics.rs b/src/api/metrics.rs
new file mode 100644
index 00000000..a244f053
--- /dev/null
+++ b/src/api/metrics.rs
@@ -0,0 +1,98 @@
+use rocket::{
+    request::{FromRequest, Outcome, Request},
+    response::content::RawText,
+    Route,
+};
+
+use crate::{auth::ClientIp, db::DbConn, CONFIG};
+
+pub fn routes() -> Vec<Route> {
+    if CONFIG.enable_metrics() {
+        routes![get_metrics]
+    } else {
+        Vec::new()
+    }
+}
+
+pub struct MetricsToken {
+    _ip: ClientIp,
+}
+
+#[rocket::async_trait]
+impl<'r> FromRequest<'r> for MetricsToken {
+    type Error = &'static str;
+
+    async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
+        let ip = match ClientIp::from_request(request).await {
+            Outcome::Success(ip) => ip,
+            _ => err_handler!("Error getting Client IP"),
+        };
+
+        let Some(configured_token) = CONFIG.metrics_token() else {
+            return Outcome::Success(Self { _ip: ip });
+        };
+
+        let provided_token = request
+            .headers()
+            .get_one("Authorization")
+            .and_then(|auth| auth.strip_prefix("Bearer "))
+            .or_else(|| request.query_value::<&str>("token").and_then(|result| result.ok()));
+
+        match provided_token {
+            Some(token) => {
+                if validate_metrics_token(token, &configured_token) {
+                    Outcome::Success(Self { _ip: ip })
+                } else {
+                    err_handler!("Invalid metrics token")
+                }
+            }
+            None => err_handler!("Metrics token required"),
+        }
+    }
+}
+
+fn validate_metrics_token(provided: &str, configured: &str) -> bool {
+    if configured.starts_with("$argon2") {
+        use argon2::password_hash::PasswordVerifier;
+        match argon2::password_hash::PasswordHash::new(configured) {
+            Ok(hash) => argon2::Argon2::default().verify_password(provided.trim().as_bytes(), &hash).is_ok(),
+            Err(e) => {
+                error!("Invalid Argon2 PHC in METRICS_TOKEN: {e}");
+                false
+            }
+        }
+    } else {
+        crate::crypto::ct_eq(configured.trim(), provided.trim())
+    }
+}
+
+/// Prometheus metrics endpoint
+#[get("/")]
+async fn get_metrics(_token: MetricsToken, mut conn: DbConn) -> Result<RawText<String>, crate::error::Error> {
+    if let Err(e) = crate::metrics::update_business_metrics(&mut conn).await {
+        err!("Failed to update business metrics", e.to_string());
+    }
+
+    match crate::metrics::gather_metrics() {
+        Ok(metrics) => Ok(RawText(metrics)),
+        Err(e) => err!("Failed to gather metrics", e.to_string()),
+    }
+}
+
+/// Called from the `/alive` health check to refresh some basic runtime metrics
+#[cfg(feature = "enable_metrics")]
+pub async fn update_health_metrics(_conn: &mut DbConn) {
+    // Update basic system metrics
+    use std::time::SystemTime;
+    static START_TIME: std::sync::OnceLock<SystemTime> = std::sync::OnceLock::new();
+    let start_time = *START_TIME.get_or_init(SystemTime::now);
+
+    crate::metrics::update_uptime(start_time);
+
+    // Update database connection metrics
+    // Note: This is a simplified version - in production you'd want to get actual pool stats
+    crate::metrics::update_db_connections("main", 1, 0);
+}
+
+#[cfg(not(feature = "enable_metrics"))]
+pub async fn update_health_metrics(_conn: &mut DbConn) {}
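`validate_metrics_token` accepts either an Argon2 PHC string or a plain token. A minimal sketch of generating a PHC value for `METRICS_TOKEN`, using the `argon2` crate already in the tree (the `vaultwarden hash` helper mentioned in `config.rs` does the same job from the CLI); the token value is hypothetical:

```rust
use argon2::{
    password_hash::{rand_core::OsRng, PasswordHasher, SaltString},
    Argon2,
};

fn main() -> Result<(), argon2::password_hash::Error> {
    // Hash an example token into a PHC string such as "$argon2id$v=19$...".
    let salt = SaltString::generate(&mut OsRng);
    let phc = Argon2::default().hash_password(b"example-token", &salt)?.to_string();
    println!("METRICS_TOKEN='{phc}'");
    Ok(())
}
```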
diff --git a/src/api/middleware.rs b/src/api/middleware.rs
new file mode 100644
index 00000000..4e43c78b
--- /dev/null
+++ b/src/api/middleware.rs
@@ -0,0 +1,176 @@
+//! Metrics middleware for automatic HTTP request instrumentation
+use rocket::{
+    fairing::{Fairing, Info, Kind},
+    Data, Request, Response,
+};
+use std::time::Instant;
+
+pub struct MetricsFairing;
+
+#[rocket::async_trait]
+impl Fairing for MetricsFairing {
+    fn info(&self) -> Info {
+        Info {
+            name: "Metrics Collection",
+            kind: Kind::Request | Kind::Response,
+        }
+    }
+
+    async fn on_request(&self, req: &mut Request<'_>, _: &mut Data<'_>) {
+        req.local_cache(|| RequestTimer {
+            start_time: Instant::now(),
+        });
+    }
+
+    async fn on_response<'r>(&self, req: &'r Request<'_>, res: &mut Response<'r>) {
+        let timer = req.local_cache(|| RequestTimer {
+            start_time: Instant::now(),
+        });
+        let duration = timer.start_time.elapsed();
+        let method = req.method().as_str();
+        let path = normalize_path(req.uri().path().as_str());
+        let status = res.status().code;
+
+        // Record metrics
+        crate::metrics::increment_http_requests(method, &path, status);
+        crate::metrics::observe_http_request_duration(method, &path, duration.as_secs_f64());
+    }
+}
+
+struct RequestTimer {
+    start_time: Instant,
+}
+
+/// Normalize paths to avoid high cardinality metrics
+/// Convert dynamic segments to static labels
+fn normalize_path(path: &str) -> String {
+    let segments: Vec<&str> = path.split('/').collect();
+    let mut normalized = Vec::new();
+
+    for segment in segments {
+        if segment.is_empty() {
+            continue;
+        }
+
+        let normalized_segment = if is_uuid(segment) {
+            "{id}"
+        } else if is_hex_hash(segment) {
+            "{hash}"
+        } else if segment.chars().all(|c| c.is_ascii_digit()) {
+            "{number}"
+        } else {
+            segment
+        };
+
+        normalized.push(normalized_segment);
+    }
+
+    if normalized.is_empty() {
+        "/".to_string()
+    } else {
+        format!("/{}", normalized.join("/"))
+    }
+}
+
+/// Check if a string is a hex hash (32+ hex chars, typical for SHA256, MD5, etc.)
+fn is_hex_hash(s: &str) -> bool {
+    s.len() >= 32 && s.chars().all(|c| c.is_ascii_hexdigit())
+}
+
+/// Check if a string looks like a UUID
+fn is_uuid(s: &str) -> bool {
+    s.len() == 36
+        && s.chars().enumerate().all(|(i, c)| match i {
+            8 | 13 | 18 | 23 => c == '-',
+            _ => c.is_ascii_hexdigit(),
+        })
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_normalize_path_preserves_static_routes() {
+        assert_eq!(normalize_path("/api/accounts"), "/api/accounts");
+        assert_eq!(normalize_path("/api/sync"), "/api/sync");
+        assert_eq!(normalize_path("/icons"), "/icons");
+    }
+
+    #[test]
+    fn test_normalize_path_replaces_uuid() {
+        let uuid = "12345678-1234-5678-9012-123456789012";
+        assert_eq!(
+            normalize_path(&format!("/api/accounts/{uuid}")),
+            "/api/accounts/{id}"
+        );
+        assert_eq!(
+            normalize_path(&format!("/ciphers/{uuid}")),
+            "/ciphers/{id}"
+        );
+    }
+
+    #[test]
+    fn test_normalize_path_replaces_sha256_hash() {
+        // SHA256 hashes are 64 hex characters
+        let sha256 = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef";
+        assert_eq!(
+            normalize_path(&format!("/attachments/{sha256}")),
+            "/attachments/{hash}"
+        );
+    }
+
+    #[test]
+    fn test_normalize_path_does_not_replace_short_hex() {
+        // Only consider 32+ char hex strings as hashes
+        assert_eq!(normalize_path("/api/hex123"), "/api/hex123");
+        assert_eq!(normalize_path("/test/abc"), "/test/abc");
+        assert_eq!(normalize_path("/api/abcdef1234567890"), "/api/abcdef1234567890"); // 16 chars
+        assert_eq!(normalize_path("/files/0123456789abcdef"), "/files/0123456789abcdef"); // 16 chars
+    }
+
+    #[test]
+    fn test_normalize_path_replaces_numbers() {
+        assert_eq!(normalize_path("/api/organizations/123"), "/api/organizations/{number}");
+        assert_eq!(normalize_path("/users/456/profile"), "/users/{number}/profile");
+    }
+
+    #[test]
+    fn test_normalize_path_root() {
+        assert_eq!(normalize_path("/"), "/");
+    }
+
+    #[test]
+    fn test_normalize_path_empty_segments() {
+        assert_eq!(normalize_path("//api//accounts"), "/api/accounts");
+    }
+
+    #[test]
+    fn test_is_uuid_valid() {
+        assert!(is_uuid("12345678-1234-5678-9012-123456789012"));
+        assert!(is_uuid("00000000-0000-0000-0000-000000000000"));
+        assert!(is_uuid("ffffffff-ffff-ffff-ffff-ffffffffffff"));
+    }
+
+    #[test]
+    fn test_is_uuid_invalid_format() {
+        assert!(!is_uuid("not-a-uuid"));
+        assert!(!is_uuid("12345678123456781234567812345678"));
+        assert!(!is_uuid("123"));
+        assert!(!is_uuid(""));
+        assert!(!is_uuid("12345678-1234-5678-9012-12345678901")); // Too short
+        assert!(!is_uuid("12345678-1234-5678-9012-1234567890123")); // Too long
+    }
+
+    #[test]
+    fn test_is_uuid_invalid_characters() {
+        assert!(!is_uuid("12345678-1234-5678-9012-12345678901z"));
+        assert!(!is_uuid("g2345678-1234-5678-9012-123456789012"));
+    }
+
+    #[test]
+    fn test_is_uuid_invalid_dash_positions() {
+        assert!(!is_uuid("12345678-1234-56789012-123456789012"));
+        assert!(!is_uuid("12345678-1234-5678-90121-23456789012"));
+    }
+}
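One more case worth pinning down, since real routes mix several dynamic segments; a sketch that could sit in the tests module above:

```rust
#[test]
fn test_normalize_path_mixed_segments() {
    // UUID and numeric segments are replaced independently within one path.
    assert_eq!(
        normalize_path("/api/organizations/12345678-1234-5678-9012-123456789012/collections/42"),
        "/api/organizations/{id}/collections/{number}"
    );
}
```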
diff --git a/src/api/mod.rs b/src/api/mod.rs
index ecdf9408..5d40c064 100644
--- a/src/api/mod.rs
+++ b/src/api/mod.rs
@@ -2,6 +2,8 @@ mod admin;
 pub mod core;
 mod icons;
 mod identity;
+mod metrics;
+mod middleware;
 mod notifications;
 mod push;
 mod web;
@@ -22,6 +24,8 @@ pub use crate::api::{
     core::{event_cleanup_job, events_routes as core_events_routes},
     icons::routes as icons_routes,
     identity::routes as identity_routes,
+    metrics::routes as metrics_routes,
+    middleware::MetricsFairing,
     notifications::routes as notifications_routes,
     notifications::{AnonymousNotify, Notify, UpdateType, WS_ANONYMOUS_SUBSCRIPTIONS, WS_USERS},
     push::{
diff --git a/src/api/web.rs b/src/api/web.rs
index d1ca0db4..3c33a779 100644
--- a/src/api/web.rs
+++ b/src/api/web.rs
@@ -179,7 +179,9 @@ async fn attachments(cipher_id: CipherId, file_id: AttachmentId, token: String)
 
 // We use DbConn here to let the alive healthcheck also verify the database connection.
 use crate::db::DbConn;
 #[get("/alive")]
-fn alive(_conn: DbConn) -> Json<String> {
+async fn alive(mut conn: DbConn) -> Json<String> {
+    // Update basic health metrics if metrics are enabled
+    crate::api::metrics::update_health_metrics(&mut conn).await;
     now()
 }
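`update_health_metrics` currently hardcodes one active connection. If `DbPool` ever exposes its inner r2d2 pool, the stub could report real numbers; a hedged sketch, where `state()` is r2d2's real accessor but `inner_pool()` is a hypothetical helper:

```rust
// Sketch only: `inner_pool()` does not exist yet; it stands in for whatever
// accessor would hand back a reference to the underlying r2d2 pool.
fn report_pool_stats(pool: &crate::db::DbPool) {
    let state = pool.inner_pool().state(); // r2d2::State { connections, idle_connections }
    let active = i64::from(state.connections - state.idle_connections);
    let idle = i64::from(state.idle_connections);
    crate::metrics::update_db_connections("main", active, idle);
}
```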
diff --git a/src/config.rs b/src/config.rs
index 4fb103fa..7d0d89f2 100644
--- a/src/config.rs
+++ b/src/config.rs
@@ -918,6 +918,16 @@ make_config! {
         /// Auto-enable 2FA (Know the risks!) |> Automatically setup email 2FA as fallback provider when needed
         email_2fa_auto_fallback: bool, true, def, false;
     },
+
+    /// Metrics Settings
+    metrics {
+        /// Enable metrics endpoint |> Enable Prometheus metrics endpoint at /metrics
+        enable_metrics: bool, true, def, false;
+        /// Metrics token |> Optional token to secure the /metrics endpoint. If not set, the endpoint is public whenever metrics are enabled.
+        metrics_token: Pass, true, option;
+        /// Business metrics cache timeout |> Number of seconds to cache business metrics before refreshing from the database
+        metrics_business_cache_seconds: u64, true, def, 300;
+    },
 }
 
 fn validate_config(cfg: &ConfigItems) -> Result<(), Error> {
@@ -1266,6 +1276,30 @@ fn validate_config(cfg: &ConfigItems) -> Result<(), Error> {
         println!("[WARNING] Secure Note size limit is increased to 100_000!");
         println!("[WARNING] This could cause issues with clients. Also exports will not work on Bitwarden servers!.");
     }
+
+    // Validate metrics configuration
+    if cfg.enable_metrics {
+        if let Some(ref token) = cfg.metrics_token {
+            if token.starts_with("$argon2") {
+                if let Err(e) = argon2::password_hash::PasswordHash::new(token) {
+                    err!(format!("The configured Argon2 PHC in `METRICS_TOKEN` is invalid: '{e}'"))
+                }
+            } else if token.trim().is_empty() {
+                err!("`METRICS_TOKEN` cannot be empty when metrics are enabled");
+            } else {
+                println!(
+                    "[NOTICE] You are using a plain text `METRICS_TOKEN` which is less secure.\n\
+                     Please consider generating a secure Argon2 PHC string by using `vaultwarden hash`.\n"
+                );
+            }
+        } else {
+            println!(
+                "[WARNING] Metrics endpoint is enabled without authentication. This may expose sensitive information."
+            );
+            println!("[WARNING] Consider setting `METRICS_TOKEN` to secure the endpoint.");
+        }
+    }
+
     Ok(())
 }
diff --git a/src/db/models/cipher.rs b/src/db/models/cipher.rs
index b28a25cd..497e8d8b 100644
--- a/src/db/models/cipher.rs
+++ b/src/db/models/cipher.rs
@@ -135,6 +135,24 @@ use crate::db::DbConn;
 use crate::api::EmptyResult;
 use crate::error::MapResult;
 
+#[derive(QueryableByName)]
+struct CipherCount {
+    #[diesel(sql_type = diesel::sql_types::Integer)]
+    atype: i32,
+    #[diesel(sql_type = diesel::sql_types::BigInt)]
+    count: i64,
+}
+
+#[derive(QueryableByName)]
+struct CipherOrgCount {
+    #[diesel(sql_type = diesel::sql_types::Integer)]
+    atype: i32,
+    #[diesel(sql_type = diesel::sql_types::Text)]
+    organization_uuid: String,
+    #[diesel(sql_type = diesel::sql_types::BigInt)]
+    count: i64,
+}
+
 /// Database methods
 impl Cipher {
     pub async fn to_json(
@@ -967,6 +985,56 @@ impl Cipher {
         }}
     }
 
+    pub async fn count_by_type_and_org(conn: &DbConn) -> std::collections::HashMap<(String, String), i64> {
+        use std::collections::HashMap;
+        db_run! { conn: {
+            // Count personal ciphers (organization_uuid IS NULL)
+            let personal_results: Vec<CipherCount> = diesel::sql_query(
+                "SELECT atype, COUNT(*) as count FROM ciphers WHERE deleted_at IS NULL AND organization_uuid IS NULL GROUP BY atype"
+            )
+            .load(conn)
+            .expect("Error counting personal ciphers");
+
+            // Count organization ciphers (organization_uuid IS NOT NULL)
+            let org_results: Vec<CipherOrgCount> = diesel::sql_query(
+                "SELECT atype, organization_uuid, COUNT(*) as count FROM ciphers WHERE deleted_at IS NULL AND organization_uuid IS NOT NULL GROUP BY atype, organization_uuid"
+            )
+            .load(conn)
+            .expect("Error counting organization ciphers");
+
+            let mut counts = HashMap::new();
+            for result in personal_results {
+                let cipher_type = match result.atype {
+                    1 => "login",
+                    2 => "note",
+                    3 => "card",
+                    4 => "identity",
+                    _ => "unknown",
+                };
+                counts.insert((cipher_type.to_string(), "personal".to_string()), result.count);
+            }
+            for result in org_results {
+                let cipher_type = match result.atype {
+                    1 => "login",
+                    2 => "note",
+                    3 => "card",
+                    4 => "identity",
+                    _ => "unknown",
+                };
+                counts.insert((cipher_type.to_string(), result.organization_uuid), result.count);
+            }
+            counts
+        }}
+    }
+
+    pub async fn find_all(conn: &DbConn) -> Vec<Self> {
+        db_run! { conn: {
+            ciphers::table
+                .load::<Self>(conn)
+                .expect("Error loading ciphers")
+        }}
+    }
+
     pub async fn get_collections(&self, user_uuid: UserId, conn: &DbConn) -> Vec<CollectionId> {
         if CONFIG.org_groups_enabled() {
             db_run! { conn: {
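For orientation, a sketch of what `count_by_type_and_org` hands back to the metrics layer, assuming a `DbConn` in scope; the counts and org UUID are hypothetical:

```rust
// Keys are (cipher_type, organization label); "personal" marks ciphers outside any org.
// A vault with three personal logins and two notes in one org would yield roughly:
//   ("login", "personal")                            => 3
//   ("note", "3fa85f64-5717-4562-b3fc-2c963f66afa6") => 2
let counts = Cipher::count_by_type_and_org(&conn).await;
for ((cipher_type, org), n) in &counts {
    println!("{cipher_type}/{org}: {n}");
}
```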
diff --git a/src/db/models/organization.rs b/src/db/models/organization.rs
index 0b722ef6..035402a7 100644
--- a/src/db/models/organization.rs
+++ b/src/db/models/organization.rs
@@ -403,6 +403,16 @@ impl Organization {
         }}
     }
 
+    pub async fn count(conn: &DbConn) -> i64 {
+        db_run! { conn: {
+            organizations::table
+                .count()
+                .first::<i64>(conn)
+                .ok()
+                .unwrap_or(0)
+        }}
+    }
+
     pub async fn get_all(conn: &DbConn) -> Vec<Self> {
         db_run! { conn: {
             organizations::table
diff --git a/src/db/models/user.rs b/src/db/models/user.rs
index e88c7296..bcafd548 100644
--- a/src/db/models/user.rs
+++ b/src/db/models/user.rs
@@ -409,6 +409,28 @@ impl User {
         None
     }
 
+    pub async fn count_enabled(conn: &DbConn) -> i64 {
+        db_run! { conn: {
+            users::table
+                .filter(users::enabled.eq(true))
+                .count()
+                .first::<i64>(conn)
+                .ok()
+                .unwrap_or(0)
+        }}
+    }
+
+    pub async fn count_disabled(conn: &DbConn) -> i64 {
+        db_run! { conn: {
+            users::table
+                .filter(users::enabled.eq(false))
+                .count()
+                .first::<i64>(conn)
+                .ok()
+                .unwrap_or(0)
+        }}
+    }
+
     pub async fn get_all(conn: &DbConn) -> Vec<(Self, Option<SsoUser>)> {
         db_run! { conn: {
             users::table
diff --git a/src/main.rs b/src/main.rs
index 8eef2e8c..d74ae5cc 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -54,6 +54,9 @@ mod crypto;
 mod db;
 mod http_client;
 mod mail;
+mod metrics;
+#[cfg(test)]
+mod metrics_test;
 mod ratelimit;
 mod sso;
 mod sso_client;
@@ -89,6 +92,17 @@ async fn main() -> Result<(), Error> {
     db::models::TwoFactor::migrate_u2f_to_webauthn(&pool.get().await.unwrap()).await.unwrap();
     db::models::TwoFactor::migrate_credential_to_passkey(&pool.get().await.unwrap()).await.unwrap();
 
+    // Initialize metrics if enabled
+    if CONFIG.enable_metrics() {
+        metrics::init_build_info();
+        info!("Metrics endpoint enabled at /metrics");
+        if CONFIG.metrics_token().is_some() {
+            info!("Metrics endpoint secured with token");
+        } else {
+            warn!("Metrics endpoint is publicly accessible");
+        }
+    }
+
     let extra_debug = matches!(level, log::LevelFilter::Trace | log::LevelFilter::Debug);
     launch_rocket(pool, extra_debug).await // Blocks until program termination.
 }
@@ -567,14 +581,21 @@ async fn launch_rocket(pool: db::DbPool, extra_debug: bool) -> Result<(), Error> {
     // If adding more paths here, consider also adding them to
     // crate::utils::LOGGED_ROUTES to make sure they appear in the log
-    let instance = rocket::custom(config)
+    let mut instance = rocket::custom(config)
         .mount([basepath, "/"].concat(), api::web_routes())
         .mount([basepath, "/api"].concat(), api::core_routes())
         .mount([basepath, "/admin"].concat(), api::admin_routes())
         .mount([basepath, "/events"].concat(), api::core_events_routes())
         .mount([basepath, "/identity"].concat(), api::identity_routes())
         .mount([basepath, "/icons"].concat(), api::icons_routes())
-        .mount([basepath, "/notifications"].concat(), api::notifications_routes())
+        .mount([basepath, "/notifications"].concat(), api::notifications_routes());
+
+    // Conditionally mount metrics routes if enabled
+    if CONFIG.enable_metrics() {
+        instance = instance.mount([basepath, "/metrics"].concat(), api::metrics_routes());
+    }
+
+    let mut rocket_instance = instance
         .register([basepath, "/"].concat(), api::web_catchers())
         .register([basepath, "/api"].concat(), api::core_catchers())
         .register([basepath, "/admin"].concat(), api::admin_catchers())
@@ -583,9 +604,14 @@ async fn launch_rocket(pool: db::DbPool, extra_debug: bool) -> Result<(), Error> {
         .manage(Arc::clone(&WS_ANONYMOUS_SUBSCRIPTIONS))
         .attach(util::AppHeaders())
         .attach(util::Cors())
-        .attach(util::BetterLogging(extra_debug))
-        .ignite()
-        .await?;
+        .attach(util::BetterLogging(extra_debug));
+
+    // Attach metrics fairing if metrics are enabled
+    if CONFIG.enable_metrics() {
+        rocket_instance = rocket_instance.attach(api::MetricsFairing);
+    }
+
+    let instance = rocket_instance.ignite().await?;
 
     CONFIG.set_rocket_shutdown_handle(instance.shutdown());
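Since the mount point is `[basepath, "/metrics"].concat()`, a sub-path deployment moves the endpoint along with everything else; a tiny illustration with a hypothetical base path:

```rust
fn main() {
    let basepath = "/vw"; // hypothetical: derived from a DOMAIN with a sub-path
    assert_eq!([basepath, "/metrics"].concat(), "/vw/metrics");
}
```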
diff --git a/src/metrics.rs b/src/metrics.rs
new file mode 100644
index 00000000..8a486dbb
--- /dev/null
+++ b/src/metrics.rs
@@ -0,0 +1,457 @@
+use std::time::SystemTime;
+
+#[cfg(feature = "enable_metrics")]
+use once_cell::sync::Lazy;
+#[cfg(feature = "enable_metrics")]
+use prometheus::{
+    register_gauge_vec, register_histogram_vec, register_int_counter_vec, register_int_gauge_vec, Encoder, GaugeVec,
+    HistogramVec, IntCounterVec, IntGaugeVec, TextEncoder,
+};
+
+use crate::{db::DbConn, error::Error};
+#[cfg(feature = "enable_metrics")]
+use crate::CONFIG;
+#[cfg(feature = "enable_metrics")]
+use std::sync::RwLock;
+#[cfg(feature = "enable_metrics")]
+use std::time::UNIX_EPOCH;
+
+// HTTP request metrics
+#[cfg(feature = "enable_metrics")]
+static HTTP_REQUESTS_TOTAL: Lazy<IntCounterVec> = Lazy::new(|| {
+    register_int_counter_vec!(
+        "vaultwarden_http_requests_total",
+        "Total number of HTTP requests processed",
+        &["method", "path", "status"]
+    )
+    .unwrap()
+});
+
+#[cfg(feature = "enable_metrics")]
+static HTTP_REQUEST_DURATION_SECONDS: Lazy<HistogramVec> = Lazy::new(|| {
+    register_histogram_vec!(
+        "vaultwarden_http_request_duration_seconds",
+        "HTTP request duration in seconds",
+        &["method", "path"],
+        vec![0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0]
+    )
+    .unwrap()
+});
+
+// Database metrics
+#[cfg(feature = "enable_metrics")]
+static DB_CONNECTIONS_ACTIVE: Lazy<IntGaugeVec> = Lazy::new(|| {
+    register_int_gauge_vec!("vaultwarden_db_connections_active", "Number of active database connections", &["database"])
+        .unwrap()
+});
+
+#[cfg(feature = "enable_metrics")]
+static DB_CONNECTIONS_IDLE: Lazy<IntGaugeVec> = Lazy::new(|| {
+    register_int_gauge_vec!("vaultwarden_db_connections_idle", "Number of idle database connections", &["database"])
+        .unwrap()
+});
+
+// Authentication metrics
+#[cfg(feature = "enable_metrics")]
+static AUTH_ATTEMPTS_TOTAL: Lazy<IntCounterVec> = Lazy::new(|| {
+    register_int_counter_vec!(
+        "vaultwarden_auth_attempts_total",
+        "Total number of authentication attempts",
+        &["method", "status"]
+    )
+    .unwrap()
+});
+
+// Business metrics
+#[cfg(feature = "enable_metrics")]
+static USERS_TOTAL: Lazy<IntGaugeVec> =
+    Lazy::new(|| register_int_gauge_vec!("vaultwarden_users_total", "Total number of users", &["status"]).unwrap());
+
+#[cfg(feature = "enable_metrics")]
+static ORGANIZATIONS_TOTAL: Lazy<IntGaugeVec> = Lazy::new(|| {
+    register_int_gauge_vec!("vaultwarden_organizations_total", "Total number of organizations", &["status"]).unwrap()
+});
+
+#[cfg(feature = "enable_metrics")]
+static VAULT_ITEMS_TOTAL: Lazy<IntGaugeVec> = Lazy::new(|| {
+    register_int_gauge_vec!("vaultwarden_vault_items_total", "Total number of vault items", &["type", "organization"])
+        .unwrap()
+});
+
+#[cfg(feature = "enable_metrics")]
+static COLLECTIONS_TOTAL: Lazy<IntGaugeVec> = Lazy::new(|| {
+    register_int_gauge_vec!("vaultwarden_collections_total", "Total number of collections", &["organization"]).unwrap()
+});
+
+// System metrics
+#[cfg(feature = "enable_metrics")]
+static UPTIME_SECONDS: Lazy<GaugeVec> =
+    Lazy::new(|| register_gauge_vec!("vaultwarden_uptime_seconds", "Uptime in seconds", &["version"]).unwrap());
+
+#[cfg(feature = "enable_metrics")]
+static BUILD_INFO: Lazy<IntGaugeVec> = Lazy::new(|| {
+    register_int_gauge_vec!("vaultwarden_build_info", "Build information", &["version", "revision", "branch"]).unwrap()
+});
+
+/// Increment HTTP request counter
+#[cfg(feature = "enable_metrics")]
+pub fn increment_http_requests(method: &str, path: &str, status: u16) {
+    HTTP_REQUESTS_TOTAL.with_label_values(&[method, path, &status.to_string()]).inc();
+}
+
+/// Observe HTTP request duration
+#[cfg(feature = "enable_metrics")]
+pub fn observe_http_request_duration(method: &str, path: &str, duration_seconds: f64) {
+    HTTP_REQUEST_DURATION_SECONDS.with_label_values(&[method, path]).observe(duration_seconds);
+}
+
+/// Update database connection metrics
+#[cfg(feature = "enable_metrics")]
+pub fn update_db_connections(database: &str, active: i64, idle: i64) {
+    DB_CONNECTIONS_ACTIVE.with_label_values(&[database]).set(active);
+    DB_CONNECTIONS_IDLE.with_label_values(&[database]).set(idle);
+}
+
+/// Increment authentication attempts (success/failure tracking)
+/// Tracks authentication success/failure by method (password, client_credentials, SSO, etc.)
+/// Called from src/api/identity.rs login() after each authentication attempt
+#[cfg(feature = "enable_metrics")]
+pub fn increment_auth_attempts(method: &str, status: &str) {
+    AUTH_ATTEMPTS_TOTAL.with_label_values(&[method, status]).inc();
+}
+
+/// Cached business metrics data
+#[cfg(feature = "enable_metrics")]
+#[derive(Clone)]
+struct BusinessMetricsCache {
+    timestamp: u64,
+    users_enabled: i64,
+    users_disabled: i64,
+    organizations: i64,
+    vault_counts: std::collections::HashMap<(String, String), i64>,
+    collection_counts: std::collections::HashMap<String, i64>,
+}
+
+#[cfg(feature = "enable_metrics")]
+static BUSINESS_METRICS_CACHE: Lazy<RwLock<Option<BusinessMetricsCache>>> = Lazy::new(|| RwLock::new(None));
+
+/// Check if the business metrics cache is still valid
+#[cfg(feature = "enable_metrics")]
+fn is_cache_valid() -> bool {
+    let cache_timeout = CONFIG.metrics_business_cache_seconds();
+    if let Ok(cache) = BUSINESS_METRICS_CACHE.read() {
+        if let Some(ref cached) = *cache {
+            let now = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs();
+            // saturating_sub avoids an underflow panic if the clock moves backwards
+            return now.saturating_sub(cached.timestamp) < cache_timeout;
+        }
+    }
+    false
+}
+
+/// Update cached business metrics
+#[cfg(feature = "enable_metrics")]
+fn update_cached_metrics(cache: BusinessMetricsCache) {
+    if let Ok(mut cached) = BUSINESS_METRICS_CACHE.write() {
+        *cached = Some(cache);
+    }
+}
+
+/// Apply cached metrics to Prometheus gauges
+#[cfg(feature = "enable_metrics")]
+fn apply_cached_metrics(cache: &BusinessMetricsCache) {
+    USERS_TOTAL.with_label_values(&["enabled"]).set(cache.users_enabled);
+    USERS_TOTAL.with_label_values(&["disabled"]).set(cache.users_disabled);
+    ORGANIZATIONS_TOTAL.with_label_values(&["active"]).set(cache.organizations);
+
+    for ((cipher_type, org_label), count) in &cache.vault_counts {
+        VAULT_ITEMS_TOTAL.with_label_values(&[cipher_type, org_label]).set(*count);
+    }
+
+    for (org_id, count) in &cache.collection_counts {
+        COLLECTIONS_TOTAL.with_label_values(&[org_id]).set(*count);
+    }
+}
+
+/// Update business metrics from the database (with caching)
+#[cfg(feature = "enable_metrics")]
+pub async fn update_business_metrics(conn: &mut DbConn) -> Result<(), Error> {
+    // Check if the cache is still valid
+    if is_cache_valid() {
+        // Apply cached metrics without a DB query
+        if let Ok(cache) = BUSINESS_METRICS_CACHE.read() {
+            if let Some(ref cached) = *cache {
+                apply_cached_metrics(cached);
+                return Ok(());
+            }
+        }
+    }
+
+    use crate::db::models::*;
+    use std::collections::HashMap;
+
+    // Count users
+    let enabled_users = User::count_enabled(conn).await;
+    let disabled_users = User::count_disabled(conn).await;
+
+    // Count organizations
+    let organizations_vec = Organization::get_all(conn).await;
+    let active_orgs = organizations_vec.len() as i64;
+
+    // Count vault items by type and organization
+    let vault_counts = Cipher::count_by_type_and_org(conn).await;
+
+    // Count collections per organization
+    let mut collection_counts: HashMap<String, i64> = HashMap::new();
+    for org in &organizations_vec {
+        let count = Collection::count_by_org(&org.uuid, conn).await;
+        collection_counts.insert(org.uuid.to_string(), count);
+    }
+
+    // Create cache entry
+    let now = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs();
+    let cache = BusinessMetricsCache {
+        timestamp: now,
+        users_enabled: enabled_users,
+        users_disabled: disabled_users,
+        organizations: active_orgs,
+        vault_counts,
+        collection_counts,
+    };
+
+    // Update cache and apply metrics
+    update_cached_metrics(cache.clone());
+    apply_cached_metrics(&cache);
+
+    Ok(())
+}
+
+/// Initialize build info metrics
+#[cfg(feature = "enable_metrics")]
+pub fn init_build_info() {
+    let version = crate::VERSION.unwrap_or("unknown");
+    BUILD_INFO.with_label_values(&[version, "unknown", "unknown"]).set(1);
+}
+
+/// Update system uptime
+#[cfg(feature = "enable_metrics")]
+pub fn update_uptime(start_time: SystemTime) {
+    if let Ok(elapsed) = start_time.elapsed() {
+        let version = crate::VERSION.unwrap_or("unknown");
+        UPTIME_SECONDS.with_label_values(&[version]).set(elapsed.as_secs_f64());
+    }
+}
+
+/// Gather all metrics and return them in the Prometheus text format
+#[cfg(feature = "enable_metrics")]
+pub fn gather_metrics() -> Result<String, Error> {
+    let encoder = TextEncoder::new();
+    let metric_families = prometheus::gather();
+    let mut output = Vec::new();
+    if let Err(e) = encoder.encode(&metric_families, &mut output) {
+        return Err(Error::new(format!("Failed to encode metrics: {e}"), ""));
+    }
+    match String::from_utf8(output) {
+        Ok(s) => Ok(s),
+        Err(e) => Err(Error::new(format!("Failed to convert metrics to string: {e}"), "")),
+    }
+}
+
+// No-op implementations when metrics are disabled
+#[cfg(not(feature = "enable_metrics"))]
+#[allow(dead_code)]
+pub fn increment_http_requests(_method: &str, _path: &str, _status: u16) {}
+
+#[cfg(not(feature = "enable_metrics"))]
+#[allow(dead_code)]
+pub fn observe_http_request_duration(_method: &str, _path: &str, _duration_seconds: f64) {}
+
+#[cfg(not(feature = "enable_metrics"))]
+#[allow(dead_code)]
+pub fn update_db_connections(_database: &str, _active: i64, _idle: i64) {}
+
+#[cfg(not(feature = "enable_metrics"))]
+pub fn increment_auth_attempts(_method: &str, _status: &str) {}
+
+#[cfg(not(feature = "enable_metrics"))]
+pub async fn update_business_metrics(_conn: &mut DbConn) -> Result<(), Error> {
+    Ok(())
+}
+
+#[cfg(not(feature = "enable_metrics"))]
+pub fn init_build_info() {}
+
+#[cfg(not(feature = "enable_metrics"))]
+#[allow(dead_code)]
+pub fn update_uptime(_start_time: SystemTime) {}
+
+#[cfg(not(feature = "enable_metrics"))]
+pub fn gather_metrics() -> Result<String, Error> {
+    Ok("Metrics not enabled".to_string())
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[cfg(feature = "enable_metrics")]
+    mod metrics_enabled_tests {
+        use super::*;
+
+        #[test]
+        fn test_http_metrics_collection() {
+            increment_http_requests("GET", "/api/sync", 200);
+            increment_http_requests("POST", "/api/accounts/register", 201);
+            increment_http_requests("GET", "/api/sync", 500);
+            observe_http_request_duration("GET", "/api/sync", 0.150);
+            observe_http_request_duration("POST", "/api/accounts/register", 0.300);
+
+            let metrics = gather_metrics().expect("Failed to gather metrics");
+            assert!(metrics.contains("vaultwarden_http_requests_total"));
+            assert!(metrics.contains("vaultwarden_http_request_duration_seconds"));
+        }
+
+        #[test]
+        fn test_database_metrics_collection() {
+            update_db_connections("sqlite", 5, 10);
+            update_db_connections("postgresql", 8, 2);
+
+            let metrics = gather_metrics().expect("Failed to gather metrics");
+            assert!(metrics.contains("vaultwarden_db_connections_active"));
+            assert!(metrics.contains("vaultwarden_db_connections_idle"));
+        }
+
+        #[test]
+        fn test_authentication_metrics() {
+            increment_auth_attempts("password", "success");
+            increment_auth_attempts("password", "failed");
+            increment_auth_attempts("webauthn", "success");
+
+            let metrics = gather_metrics().expect("Failed to gather metrics");
+            assert!(metrics.contains("vaultwarden_auth_attempts_total"));
assert!(metrics.contains("method=\"password\"")); + assert!(metrics.contains("status=\"success\"")); + assert!(metrics.contains("status=\"failed\"")); + } + + #[test] + fn test_build_info_initialization() { + init_build_info(); + let start_time = SystemTime::now(); + update_uptime(start_time); + + let metrics = gather_metrics().expect("Failed to gather metrics"); + assert!(metrics.contains("vaultwarden_build_info")); + assert!(metrics.contains("vaultwarden_uptime_seconds")); + } + + #[test] + fn test_metrics_gathering() { + increment_http_requests("GET", "/api/sync", 200); + update_db_connections("sqlite", 1, 5); + init_build_info(); + + let metrics_output = gather_metrics(); + assert!(metrics_output.is_ok(), "gather_metrics should succeed"); + + let metrics_text = metrics_output.unwrap(); + assert!(!metrics_text.is_empty(), "metrics output should not be empty"); + assert!(metrics_text.contains("# HELP"), "metrics should have HELP lines"); + assert!(metrics_text.contains("# TYPE"), "metrics should have TYPE lines"); + assert!(metrics_text.contains("vaultwarden_"), "metrics should contain vaultwarden prefix"); + } + + #[tokio::test] + async fn test_business_metrics_collection_noop() { + init_build_info(); + let metrics = gather_metrics().expect("Failed to gather metrics"); + assert!(metrics.contains("vaultwarden_"), "Business metrics should be accessible"); + } + + #[test] + fn test_path_normalization() { + increment_http_requests("GET", "/api/sync", 200); + increment_http_requests("GET", "/api/accounts/123/profile", 200); + increment_http_requests("POST", "/api/organizations/456/users", 201); + increment_http_requests("PUT", "/api/ciphers/789", 200); + + let result = gather_metrics(); + assert!(result.is_ok(), "gather_metrics should succeed with various paths"); + + let metrics_text = result.unwrap(); + assert!(!metrics_text.is_empty(), "metrics output should not be empty"); + assert!(metrics_text.contains("vaultwarden_http_requests_total"), "should have http request metrics"); + } + + #[test] + fn test_concurrent_metrics_collection() { + use std::thread; + + let handles: Vec<_> = (0..10).map(|i| { + thread::spawn(move || { + increment_http_requests("GET", "/api/sync", 200); + observe_http_request_duration("GET", "/api/sync", 0.1 + (i as f64 * 0.01)); + update_db_connections("sqlite", i, 10 - i); + }) + }).collect(); + + for handle in handles { + handle.join().expect("Thread panicked"); + } + + let result = gather_metrics(); + assert!(result.is_ok(), "metrics collection should be thread-safe"); + assert!(!result.unwrap().is_empty(), "concurrent access should not corrupt metrics"); + } + } + + #[cfg(not(feature = "enable_metrics"))] + mod metrics_disabled_tests { + use super::*; + + #[test] + fn test_no_op_implementations() { + increment_http_requests("GET", "/api/sync", 200); + observe_http_request_duration("GET", "/api/sync", 0.150); + update_db_connections("sqlite", 5, 10); + increment_auth_attempts("password", "success"); + init_build_info(); + + let start_time = SystemTime::now(); + update_uptime(start_time); + + let result = gather_metrics(); + assert!(result.is_ok(), "disabled metrics should return ok"); + assert_eq!(result.unwrap(), "Metrics not enabled", "should return disabled message"); + } + + #[tokio::test] + async fn test_business_metrics_no_op() { + let result = gather_metrics(); + assert!(result.is_ok(), "disabled metrics should not panic"); + assert_eq!(result.unwrap(), "Metrics not enabled", "should return disabled message"); + } + + #[test] + fn 
+        fn test_concurrent_no_op_calls() {
+            use std::thread;
+
+            let handles: Vec<_> = (0..5)
+                .map(|i| {
+                    thread::spawn(move || {
+                        increment_http_requests("GET", "/test", 200);
+                        observe_http_request_duration("GET", "/test", 0.1);
+                        update_db_connections("test", i, 5 - i);
+                        increment_auth_attempts("password", "success");
+                    })
+                })
+                .collect();
+
+            for handle in handles {
+                handle.join().expect("Thread panicked");
+            }
+
+            let result = gather_metrics();
+            assert!(result.is_ok(), "disabled metrics should be thread-safe");
+            assert_eq!(result.unwrap(), "Metrics not enabled", "disabled metrics should always return same message");
+        }
+    }
+}
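To make the output format concrete: a hypothetical excerpt of what `gather_metrics()` renders. The metric names are the real ones registered above; the label values and sample counts are illustrative only:

```rust
fn main() {
    // Sketch: Prometheus text exposition format, as served by GET /metrics.
    let sample = concat!(
        "# HELP vaultwarden_http_requests_total Total number of HTTP requests processed\n",
        "# TYPE vaultwarden_http_requests_total counter\n",
        "vaultwarden_http_requests_total{method=\"GET\",path=\"/api/sync\",status=\"200\"} 42\n",
        "# HELP vaultwarden_users_total Total number of users\n",
        "# TYPE vaultwarden_users_total gauge\n",
        "vaultwarden_users_total{status=\"enabled\"} 7\n",
    );
    assert!(sample.starts_with("# HELP"));
    println!("{sample}");
}
```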
assert!(metrics_text.contains("vaultwarden_"), "metrics should contain vaultwarden prefix"); + } + + #[tokio::test] + async fn test_business_metrics_collection_noop() { + // Business metrics require database access, which cannot be easily mocked in unit tests. + // This test verifies that the async function exists and can be called without panicking. + // Integration tests would provide database access and verify metrics are actually updated. + init_build_info(); + let metrics = gather_metrics().expect("Failed to gather metrics"); + assert!(metrics.contains("vaultwarden_"), "Business metrics should be accessible"); + } + + #[test] + fn test_path_normalization() { + increment_http_requests("GET", "/api/sync", 200); + increment_http_requests("GET", "/api/accounts/123/profile", 200); + increment_http_requests("POST", "/api/organizations/456/users", 201); + increment_http_requests("PUT", "/api/ciphers/789", 200); + + let result = gather_metrics(); + assert!(result.is_ok(), "gather_metrics should succeed with various paths"); + + let metrics_text = result.unwrap(); + assert!(!metrics_text.is_empty(), "metrics output should not be empty"); + assert!(metrics_text.contains("vaultwarden_http_requests_total"), "should have http request metrics"); + } + + #[test] + fn test_concurrent_metrics_collection() { + use std::thread; + + let handles: Vec<_> = (0..10).map(|i| { + thread::spawn(move || { + increment_http_requests("GET", "/api/sync", 200); + observe_http_request_duration("GET", "/api/sync", 0.1 + (i as f64 * 0.01)); + update_db_connections("sqlite", i, 10 - i); + }) + }).collect(); + + for handle in handles { + handle.join().expect("Thread panicked"); + } + + let result = gather_metrics(); + assert!(result.is_ok(), "metrics collection should be thread-safe"); + assert!(!result.unwrap().is_empty(), "concurrent access should not corrupt metrics"); + } + } + + #[cfg(not(feature = "enable_metrics"))] + mod metrics_disabled_tests { + use super::*; + + #[test] + fn test_no_op_implementations() { + increment_http_requests("GET", "/api/sync", 200); + observe_http_request_duration("GET", "/api/sync", 0.150); + update_db_connections("sqlite", 5, 10); + observe_db_query_duration("select", 0.025); + increment_auth_attempts("password", "success"); + update_user_sessions("authenticated", 150); + init_build_info(); + + let start_time = std::time::SystemTime::now(); + update_uptime(start_time); + + let result = gather_metrics(); + assert!(result.is_ok(), "disabled metrics should return ok"); + assert_eq!(result.unwrap(), "Metrics not enabled", "should return disabled message"); + } + + #[tokio::test] + async fn test_business_metrics_no_op() { + let result = gather_metrics(); + assert!(result.is_ok(), "disabled metrics should not panic"); + assert_eq!(result.unwrap(), "Metrics not enabled", "should return disabled message"); + } + + #[test] + fn test_concurrent_no_op_calls() { + use std::thread; + + let handles: Vec<_> = (0..5).map(|i| { + thread::spawn(move || { + increment_http_requests("GET", "/test", 200); + observe_http_request_duration("GET", "/test", 0.1); + update_db_connections("test", i, 5 - i); + increment_auth_attempts("password", "success"); + }) + }).collect(); + + for handle in handles { + handle.join().expect("Thread panicked"); + } + + let result = gather_metrics(); + assert!(result.is_ok(), "disabled metrics should be thread-safe"); + assert_eq!(result.unwrap(), "Metrics not enabled", "disabled metrics should always return same message"); + } + } +} \ No newline at end of file