Browse Source

Merge 672a1e5c72 into cc80f689ed

pull/6202/merge
Ross Golder 2 days ago
committed by GitHub
parent
commit
7be09873b2
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
  1. 16
      Cargo.lock
  2. 8
      Cargo.toml
  3. 1
      README.md
  4. 2
      docker/Dockerfile.alpine
  5. 2
      docker/Dockerfile.debian
  6. 4
      docker/Dockerfile.j2
  7. 9
      src/api/identity.rs
  8. 98
      src/api/metrics.rs
  9. 176
      src/api/middleware.rs
  10. 4
      src/api/mod.rs
  11. 4
      src/api/web.rs
  12. 34
      src/config.rs
  13. 68
      src/db/models/cipher.rs
  14. 10
      src/db/models/organization.rs
  15. 22
      src/db/models/user.rs
  16. 34
      src/main.rs
  17. 456
      src/metrics.rs
  18. 177
      src/metrics_test.rs

16
Cargo.lock

@ -3875,6 +3875,20 @@ dependencies = [
"yansi", "yansi",
] ]
[[package]]
name = "prometheus"
version = "0.13.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3d33c28a30771f7f96db69893f78b857f7450d7e0237e9c8fc6427a81bae7ed1"
dependencies = [
"cfg-if",
"fnv",
"lazy_static",
"memchr",
"parking_lot",
"thiserror 1.0.69",
]
[[package]] [[package]]
name = "psl-types" name = "psl-types"
version = "2.0.11" version = "2.0.11"
@ -5824,12 +5838,14 @@ dependencies = [
"mini-moka", "mini-moka",
"num-derive", "num-derive",
"num-traits", "num-traits",
"once_cell",
"opendal", "opendal",
"openidconnect", "openidconnect",
"openssl", "openssl",
"pastey 0.2.1", "pastey 0.2.1",
"percent-encoding", "percent-encoding",
"pico-args", "pico-args",
"prometheus",
"rand 0.9.2", "rand 0.9.2",
"regex", "regex",
"reqsign", "reqsign",

8
Cargo.toml

@ -37,6 +37,8 @@ vendored_openssl = ["openssl/vendored"]
# Enable MiMalloc memory allocator to replace the default malloc # Enable MiMalloc memory allocator to replace the default malloc
# This can improve performance for Alpine builds # This can improve performance for Alpine builds
enable_mimalloc = ["dep:mimalloc"] enable_mimalloc = ["dep:mimalloc"]
# Enable Prometheus metrics endpoint
enable_metrics = ["dep:prometheus"]
s3 = ["opendal/services-s3", "dep:aws-config", "dep:aws-credential-types", "dep:aws-smithy-runtime-api", "dep:anyhow", "dep:http", "dep:reqsign"] s3 = ["opendal/services-s3", "dep:aws-config", "dep:aws-credential-types", "dep:aws-smithy-runtime-api", "dep:anyhow", "dep:http", "dep:reqsign"]
# OIDC specific features # OIDC specific features
@ -77,6 +79,9 @@ rmpv = "1.3.1" # MessagePack library
# Concurrent HashMap used for WebSocket messaging and favicons # Concurrent HashMap used for WebSocket messaging and favicons
dashmap = "6.1.0" dashmap = "6.1.0"
# Lazy static initialization
once_cell = "1.20.2"
# Async futures # Async futures
futures = "0.3.31" futures = "0.3.31"
tokio = { version = "1.49.0", features = ["rt-multi-thread", "fs", "io-util", "parking_lot", "time", "signal", "net"] } tokio = { version = "1.49.0", features = ["rt-multi-thread", "fs", "io-util", "parking_lot", "time", "signal", "net"] }
@ -182,6 +187,9 @@ semver = "1.0.27"
# Mainly used for the musl builds, since the default musl malloc is very slow # Mainly used for the musl builds, since the default musl malloc is very slow
mimalloc = { version = "0.1.48", features = ["secure"], default-features = false, optional = true } mimalloc = { version = "0.1.48", features = ["secure"], default-features = false, optional = true }
# Prometheus metrics
prometheus = { version = "0.13.1", default-features = false, optional = true }
which = "8.0.0" which = "8.0.0"
# Argon2 library with support for the PHC format # Argon2 library with support for the PHC format

1
README.md

@ -52,6 +52,7 @@ A nearly complete implementation of the Bitwarden Client API is provided, includ
[Duo](https://bitwarden.com/help/setup-two-step-login-duo/) [Duo](https://bitwarden.com/help/setup-two-step-login-duo/)
* [Emergency Access](https://bitwarden.com/help/emergency-access/) * [Emergency Access](https://bitwarden.com/help/emergency-access/)
* [Vaultwarden Admin Backend](https://github.com/dani-garcia/vaultwarden/wiki/Enabling-admin-page) * [Vaultwarden Admin Backend](https://github.com/dani-garcia/vaultwarden/wiki/Enabling-admin-page)
* [Prometheus Metrics](https://github.com/dani-garcia/vaultwarden/wiki/Metrics) - Optional monitoring and observability with secure endpoint
* [Modified Web Vault client](https://github.com/dani-garcia/bw_web_builds) (Bundled within our containers) * [Modified Web Vault client](https://github.com/dani-garcia/bw_web_builds) (Bundled within our containers)
<br> <br>

2
docker/Dockerfile.alpine

@ -82,7 +82,7 @@ ARG CARGO_PROFILE=release
# Configure the DB ARG as late as possible to not invalidate the cached layers above # Configure the DB ARG as late as possible to not invalidate the cached layers above
# Enable MiMalloc to improve performance on Alpine builds # Enable MiMalloc to improve performance on Alpine builds
ARG DB=sqlite,mysql,postgresql,enable_mimalloc ARG DB=sqlite,mysql,postgresql,enable_mimalloc,enable_metrics
# Builds your dependencies and removes the # Builds your dependencies and removes the
# dummy project, except the target folder # dummy project, except the target folder

2
docker/Dockerfile.debian

@ -116,7 +116,7 @@ COPY ./macros ./macros
ARG CARGO_PROFILE=release ARG CARGO_PROFILE=release
# Configure the DB ARG as late as possible to not invalidate the cached layers above # Configure the DB ARG as late as possible to not invalidate the cached layers above
ARG DB=sqlite,mysql,postgresql ARG DB=sqlite,mysql,postgresql,enable_metrics
# Builds your dependencies and removes the # Builds your dependencies and removes the
# dummy project, except the target folder # dummy project, except the target folder

4
docker/Dockerfile.j2

@ -144,10 +144,10 @@ ARG CARGO_PROFILE=release
# Configure the DB ARG as late as possible to not invalidate the cached layers above # Configure the DB ARG as late as possible to not invalidate the cached layers above
{% if base == "debian" %} {% if base == "debian" %}
ARG DB=sqlite,mysql,postgresql ARG DB=sqlite,mysql,postgresql,enable_metrics
{% elif base == "alpine" %} {% elif base == "alpine" %}
# Enable MiMalloc to improve performance on Alpine builds # Enable MiMalloc to improve performance on Alpine builds
ARG DB=sqlite,mysql,postgresql,enable_mimalloc ARG DB=sqlite,mysql,postgresql,enable_mimalloc,enable_metrics
{% endif %} {% endif %}
# Builds your dependencies and removes the # Builds your dependencies and removes the

9
src/api/identity.rs

@ -32,7 +32,7 @@ use crate::{
error::MapResult, error::MapResult,
mail, sso, mail, sso,
sso::{OIDCCode, OIDCCodeChallenge, OIDCCodeVerifier, OIDCState}, sso::{OIDCCode, OIDCCodeChallenge, OIDCCodeVerifier, OIDCState},
util, CONFIG, util, CONFIG, metrics,
}; };
pub fn routes() -> Vec<Route> { pub fn routes() -> Vec<Route> {
@ -60,7 +60,8 @@ async fn login(
let mut user_id: Option<UserId> = None; let mut user_id: Option<UserId> = None;
let login_result = match data.grant_type.as_ref() { let auth_method = data.grant_type.clone();
let login_result = match auth_method.as_ref() {
"refresh_token" => { "refresh_token" => {
_check_is_some(&data.refresh_token, "refresh_token cannot be blank")?; _check_is_some(&data.refresh_token, "refresh_token cannot be blank")?;
_refresh_login(data, &conn, &client_header.ip).await _refresh_login(data, &conn, &client_header.ip).await
@ -104,6 +105,10 @@ async fn login(
t => err!("Invalid type", t), t => err!("Invalid type", t),
}; };
// Record authentication metrics
let auth_status = if login_result.is_ok() { "success" } else { "failed" };
metrics::increment_auth_attempts(&auth_method, auth_status);
if let Some(user_id) = user_id { if let Some(user_id) = user_id {
match &login_result { match &login_result {
Ok(_) => { Ok(_) => {

98
src/api/metrics.rs

@ -0,0 +1,98 @@
use rocket::{
request::{FromRequest, Outcome, Request},
response::content::RawText,
Route,
};
use crate::{auth::ClientIp, db::DbConn, CONFIG};
pub fn routes() -> Vec<Route> {
if CONFIG.enable_metrics() {
routes![get_metrics]
} else {
Vec::new()
}
}
pub struct MetricsToken {
_ip: ClientIp,
}
#[rocket::async_trait]
impl<'r> FromRequest<'r> for MetricsToken {
type Error = &'static str;
async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
let ip = match ClientIp::from_request(request).await {
Outcome::Success(ip) => ip,
_ => err_handler!("Error getting Client IP"),
};
let Some(configured_token) = CONFIG.metrics_token() else {
return Outcome::Success(Self { _ip: ip });
};
let provided_token = request
.headers()
.get_one("Authorization")
.and_then(|auth| auth.strip_prefix("Bearer "))
.or_else(|| request.query_value::<&str>("token").and_then(|result| result.ok()));
match provided_token {
Some(token) => {
if validate_metrics_token(token, &configured_token) {
Outcome::Success(Self { _ip: ip })
} else {
err_handler!("Invalid metrics token")
}
}
None => err_handler!("Metrics token required"),
}
}
}
fn validate_metrics_token(provided: &str, configured: &str) -> bool {
if configured.starts_with("$argon2") {
use argon2::password_hash::PasswordVerifier;
match argon2::password_hash::PasswordHash::new(configured) {
Ok(hash) => argon2::Argon2::default().verify_password(provided.trim().as_bytes(), &hash).is_ok(),
Err(e) => {
error!("Invalid Argon2 PHC in METRICS_TOKEN: {e}");
false
}
}
} else {
crate::crypto::ct_eq(configured.trim(), provided.trim())
}
}
/// Prometheus metrics endpoint
#[get("/")]
async fn get_metrics(_token: MetricsToken, mut conn: DbConn) -> Result<RawText<String>, crate::error::Error> {
if let Err(e) = crate::metrics::update_business_metrics(&mut conn).await {
err!("Failed to update business metrics", e.to_string());
}
match crate::metrics::gather_metrics() {
Ok(metrics) => Ok(RawText(metrics)),
Err(e) => err!("Failed to gather metrics", e.to_string()),
}
}
/// Health check endpoint that also updates some basic metrics
#[cfg(feature = "enable_metrics")]
pub async fn update_health_metrics(_conn: &mut DbConn) {
// Update basic system metrics
use std::time::SystemTime;
static START_TIME: std::sync::OnceLock<SystemTime> = std::sync::OnceLock::new();
let start_time = *START_TIME.get_or_init(SystemTime::now);
crate::metrics::update_uptime(start_time);
// Update database connection metrics
// Note: This is a simplified version - in production you'd want to get actual pool stats
crate::metrics::update_db_connections("main", 1, 0);
}
#[cfg(not(feature = "enable_metrics"))]
pub async fn update_health_metrics(_conn: &mut DbConn) {}

176
src/api/middleware.rs

@ -0,0 +1,176 @@
/// Metrics middleware for automatic HTTP request instrumentation
use rocket::{
fairing::{Fairing, Info, Kind},
Data, Request, Response,
};
use std::time::Instant;
pub struct MetricsFairing;
#[rocket::async_trait]
impl Fairing for MetricsFairing {
fn info(&self) -> Info {
Info {
name: "Metrics Collection",
kind: Kind::Request | Kind::Response,
}
}
async fn on_request(&self, req: &mut Request<'_>, _: &mut Data<'_>) {
req.local_cache(|| RequestTimer {
start_time: Instant::now(),
});
}
async fn on_response<'r>(&self, req: &'r Request<'_>, res: &mut Response<'r>) {
let timer = req.local_cache(|| RequestTimer {
start_time: Instant::now(),
});
let duration = timer.start_time.elapsed();
let method = req.method().as_str();
let path = normalize_path(req.uri().path().as_str());
let status = res.status().code;
// Record metrics
crate::metrics::increment_http_requests(method, &path, status);
crate::metrics::observe_http_request_duration(method, &path, duration.as_secs_f64());
}
}
struct RequestTimer {
start_time: Instant,
}
/// Normalize paths to avoid high cardinality metrics
/// Convert dynamic segments to static labels
fn normalize_path(path: &str) -> String {
let segments: Vec<&str> = path.split('/').collect();
let mut normalized = Vec::new();
for segment in segments {
if segment.is_empty() {
continue;
}
let normalized_segment = if is_uuid(segment) {
"{id}"
} else if is_hex_hash(segment) {
"{hash}"
} else if segment.chars().all(|c| c.is_ascii_digit()) {
"{number}"
} else {
segment
};
normalized.push(normalized_segment);
}
if normalized.is_empty() {
"/".to_string()
} else {
format!("/{}", normalized.join("/"))
}
}
/// Check if a string is a hex hash (32+ hex chars, typical for SHA256, MD5, etc)
fn is_hex_hash(s: &str) -> bool {
s.len() >= 32 && s.chars().all(|c| c.is_ascii_hexdigit())
}
/// Check if a string looks like a UUID
fn is_uuid(s: &str) -> bool {
s.len() == 36
&& s.chars().enumerate().all(|(i, c)| match i {
8 | 13 | 18 | 23 => c == '-',
_ => c.is_ascii_hexdigit(),
})
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_normalize_path_preserves_static_routes() {
assert_eq!(normalize_path("/api/accounts"), "/api/accounts");
assert_eq!(normalize_path("/api/sync"), "/api/sync");
assert_eq!(normalize_path("/icons"), "/icons");
}
#[test]
fn test_normalize_path_replaces_uuid() {
let uuid = "12345678-1234-5678-9012-123456789012";
assert_eq!(
normalize_path(&format!("/api/accounts/{uuid}")),
"/api/accounts/{id}"
);
assert_eq!(
normalize_path(&format!("/ciphers/{uuid}")),
"/ciphers/{id}"
);
}
#[test]
fn test_normalize_path_replaces_sha256_hash() {
// SHA256 hashes are 64 hex characters
let sha256 = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef";
assert_eq!(
normalize_path(&format!("/attachments/{sha256}")),
"/attachments/{hash}"
);
}
#[test]
fn test_normalize_path_does_not_replace_short_hex() {
// Only consider 32+ char hex strings as hashes
assert_eq!(normalize_path("/api/hex123"), "/api/hex123");
assert_eq!(normalize_path("/test/abc"), "/test/abc");
assert_eq!(normalize_path("/api/abcdef1234567890"), "/api/abcdef1234567890"); // 16 chars
assert_eq!(normalize_path("/files/0123456789abcdef"), "/files/0123456789abcdef"); // 16 chars
}
#[test]
fn test_normalize_path_replaces_numbers() {
assert_eq!(normalize_path("/api/organizations/123"), "/api/organizations/{number}");
assert_eq!(normalize_path("/users/456/profile"), "/users/{number}/profile");
}
#[test]
fn test_normalize_path_root() {
assert_eq!(normalize_path("/"), "/");
}
#[test]
fn test_normalize_path_empty_segments() {
assert_eq!(normalize_path("//api//accounts"), "/api/accounts");
}
#[test]
fn test_is_uuid_valid() {
assert!(is_uuid("12345678-1234-5678-9012-123456789012"));
assert!(is_uuid("00000000-0000-0000-0000-000000000000"));
assert!(is_uuid("ffffffff-ffff-ffff-ffff-ffffffffffff"));
}
#[test]
fn test_is_uuid_invalid_format() {
assert!(!is_uuid("not-a-uuid"));
assert!(!is_uuid("12345678123456781234567812345678"));
assert!(!is_uuid("123"));
assert!(!is_uuid(""));
assert!(!is_uuid("12345678-1234-5678-9012-12345678901")); // Too short
assert!(!is_uuid("12345678-1234-5678-9012-1234567890123")); // Too long
}
#[test]
fn test_is_uuid_invalid_characters() {
assert!(!is_uuid("12345678-1234-5678-9012-12345678901z"));
assert!(!is_uuid("g2345678-1234-5678-9012-123456789012"));
}
#[test]
fn test_is_uuid_invalid_dash_positions() {
assert!(!is_uuid("12345678-1234-56789012-123456789012"));
assert!(!is_uuid("12345678-1234-5678-90121-23456789012"));
}
}

4
src/api/mod.rs

@ -2,6 +2,8 @@ mod admin;
pub mod core; pub mod core;
mod icons; mod icons;
mod identity; mod identity;
mod metrics;
mod middleware;
mod notifications; mod notifications;
mod push; mod push;
mod web; mod web;
@ -22,6 +24,8 @@ pub use crate::api::{
core::{event_cleanup_job, events_routes as core_events_routes}, core::{event_cleanup_job, events_routes as core_events_routes},
icons::routes as icons_routes, icons::routes as icons_routes,
identity::routes as identity_routes, identity::routes as identity_routes,
metrics::routes as metrics_routes,
middleware::MetricsFairing,
notifications::routes as notifications_routes, notifications::routes as notifications_routes,
notifications::{AnonymousNotify, Notify, UpdateType, WS_ANONYMOUS_SUBSCRIPTIONS, WS_USERS}, notifications::{AnonymousNotify, Notify, UpdateType, WS_ANONYMOUS_SUBSCRIPTIONS, WS_USERS},
push::{ push::{

4
src/api/web.rs

@ -179,7 +179,9 @@ async fn attachments(cipher_id: CipherId, file_id: AttachmentId, token: String)
// We use DbConn here to let the alive healthcheck also verify the database connection. // We use DbConn here to let the alive healthcheck also verify the database connection.
use crate::db::DbConn; use crate::db::DbConn;
#[get("/alive")] #[get("/alive")]
fn alive(_conn: DbConn) -> Json<String> { async fn alive(mut conn: DbConn) -> Json<String> {
// Update basic health metrics if metrics are enabled
let _ = crate::api::metrics::update_health_metrics(&mut conn).await;
now() now()
} }

34
src/config.rs

@ -918,6 +918,16 @@ make_config! {
/// Auto-enable 2FA (Know the risks!) |> Automatically setup email 2FA as fallback provider when needed /// Auto-enable 2FA (Know the risks!) |> Automatically setup email 2FA as fallback provider when needed
email_2fa_auto_fallback: bool, true, def, false; email_2fa_auto_fallback: bool, true, def, false;
}, },
/// Metrics Settings
metrics {
/// Enable metrics endpoint |> Enable Prometheus metrics endpoint at /metrics
enable_metrics: bool, true, def, false;
/// Metrics token |> Optional token to secure the /metrics endpoint. If not set, endpoint is public when enabled.
metrics_token: Pass, true, option;
/// Business metrics cache timeout |> Number of seconds to cache business metrics before refreshing from database
metrics_business_cache_seconds: u64, true, def, 300;
},
} }
fn validate_config(cfg: &ConfigItems) -> Result<(), Error> { fn validate_config(cfg: &ConfigItems) -> Result<(), Error> {
@ -1266,6 +1276,30 @@ fn validate_config(cfg: &ConfigItems) -> Result<(), Error> {
println!("[WARNING] Secure Note size limit is increased to 100_000!"); println!("[WARNING] Secure Note size limit is increased to 100_000!");
println!("[WARNING] This could cause issues with clients. Also exports will not work on Bitwarden servers!."); println!("[WARNING] This could cause issues with clients. Also exports will not work on Bitwarden servers!.");
} }
// Validate metrics configuration
if cfg.enable_metrics {
if let Some(ref token) = cfg.metrics_token {
if token.starts_with("$argon2") {
if let Err(e) = argon2::password_hash::PasswordHash::new(token) {
err!(format!("The configured Argon2 PHC in `METRICS_TOKEN` is invalid: '{e}'"))
}
} else if token.trim().is_empty() {
err!("`METRICS_TOKEN` cannot be empty when metrics are enabled");
} else {
println!(
"[NOTICE] You are using a plain text `METRICS_TOKEN` which is less secure.\n\
Please consider generating a secure Argon2 PHC string by using `vaultwarden hash`.\n"
);
}
} else {
println!(
"[WARNING] Metrics endpoint is enabled without authentication. This may expose sensitive information."
);
println!("[WARNING] Consider setting `METRICS_TOKEN` to secure the endpoint.");
}
}
Ok(()) Ok(())
} }

68
src/db/models/cipher.rs

@ -135,6 +135,24 @@ use crate::db::DbConn;
use crate::api::EmptyResult; use crate::api::EmptyResult;
use crate::error::MapResult; use crate::error::MapResult;
#[derive(QueryableByName)]
struct CipherCount {
#[diesel(sql_type = diesel::sql_types::Integer)]
atype: i32,
#[diesel(sql_type = diesel::sql_types::BigInt)]
count: i64,
}
#[derive(QueryableByName)]
struct CipherOrgCount {
#[diesel(sql_type = diesel::sql_types::Integer)]
atype: i32,
#[diesel(sql_type = diesel::sql_types::Text)]
organization_uuid: String,
#[diesel(sql_type = diesel::sql_types::BigInt)]
count: i64,
}
/// Database methods /// Database methods
impl Cipher { impl Cipher {
pub async fn to_json( pub async fn to_json(
@ -967,6 +985,56 @@ impl Cipher {
}} }}
} }
pub async fn count_by_type_and_org(conn: &DbConn) -> std::collections::HashMap<(String, String), i64> {
use std::collections::HashMap;
db_run! { conn: {
// Count personal ciphers (organization_uuid IS NULL)
let personal_results: Vec<CipherCount> = diesel::sql_query(
"SELECT atype, COUNT(*) as count FROM ciphers WHERE deleted_at IS NULL AND organization_uuid IS NULL GROUP BY atype"
)
.load(conn)
.expect("Error counting personal ciphers");
// Count organization ciphers (organization_uuid IS NOT NULL)
let org_results: Vec<CipherOrgCount> = diesel::sql_query(
"SELECT atype, organization_uuid, COUNT(*) as count FROM ciphers WHERE deleted_at IS NULL AND organization_uuid IS NOT NULL GROUP BY atype, organization_uuid"
)
.load(conn)
.expect("Error counting organization ciphers");
let mut counts = HashMap::new();
for result in personal_results {
let cipher_type = match result.atype {
1 => "login",
2 => "note",
3 => "card",
4 => "identity",
_ => "unknown",
};
counts.insert((cipher_type.to_string(), "personal".to_string()), result.count);
}
for result in org_results {
let cipher_type = match result.atype {
1 => "login",
2 => "note",
3 => "card",
4 => "identity",
_ => "unknown",
};
counts.insert((cipher_type.to_string(), result.organization_uuid), result.count);
}
counts
}}
}
pub async fn find_all(conn: &DbConn) -> Vec<Self> {
db_run! { conn: {
ciphers::table
.load::<Self>(conn)
.expect("Error loading ciphers")
}}
}
pub async fn get_collections(&self, user_uuid: UserId, conn: &DbConn) -> Vec<CollectionId> { pub async fn get_collections(&self, user_uuid: UserId, conn: &DbConn) -> Vec<CollectionId> {
if CONFIG.org_groups_enabled() { if CONFIG.org_groups_enabled() {
db_run! { conn: { db_run! { conn: {

10
src/db/models/organization.rs

@ -403,6 +403,16 @@ impl Organization {
}} }}
} }
pub async fn count(conn: &DbConn) -> i64 {
db_run! { conn: {
organizations::table
.count()
.first::<i64>(conn)
.ok()
.unwrap_or(0)
}}
}
pub async fn get_all(conn: &DbConn) -> Vec<Self> { pub async fn get_all(conn: &DbConn) -> Vec<Self> {
db_run! { conn: { db_run! { conn: {
organizations::table organizations::table

22
src/db/models/user.rs

@ -409,6 +409,28 @@ impl User {
None None
} }
pub async fn count_enabled(conn: &DbConn) -> i64 {
db_run! { conn: {
users::table
.filter(users::enabled.eq(true))
.count()
.first::<i64>(conn)
.ok()
.unwrap_or(0)
}}
}
pub async fn count_disabled(conn: &DbConn) -> i64 {
db_run! { conn: {
users::table
.filter(users::enabled.eq(false))
.count()
.first::<i64>(conn)
.ok()
.unwrap_or(0)
}}
}
pub async fn get_all(conn: &DbConn) -> Vec<(Self, Option<SsoUser>)> { pub async fn get_all(conn: &DbConn) -> Vec<(Self, Option<SsoUser>)> {
db_run! { conn: { db_run! { conn: {
users::table users::table

34
src/main.rs

@ -54,6 +54,7 @@ mod crypto;
mod db; mod db;
mod http_client; mod http_client;
mod mail; mod mail;
mod metrics;
mod ratelimit; mod ratelimit;
mod sso; mod sso;
mod sso_client; mod sso_client;
@ -89,6 +90,17 @@ async fn main() -> Result<(), Error> {
db::models::TwoFactor::migrate_u2f_to_webauthn(&pool.get().await.unwrap()).await.unwrap(); db::models::TwoFactor::migrate_u2f_to_webauthn(&pool.get().await.unwrap()).await.unwrap();
db::models::TwoFactor::migrate_credential_to_passkey(&pool.get().await.unwrap()).await.unwrap(); db::models::TwoFactor::migrate_credential_to_passkey(&pool.get().await.unwrap()).await.unwrap();
// Initialize metrics if enabled
if CONFIG.enable_metrics() {
metrics::init_build_info();
info!("Metrics endpoint enabled at /metrics");
if CONFIG.metrics_token().is_some() {
info!("Metrics endpoint secured with token");
} else {
warn!("Metrics endpoint is publicly accessible");
}
}
let extra_debug = matches!(level, log::LevelFilter::Trace | log::LevelFilter::Debug); let extra_debug = matches!(level, log::LevelFilter::Trace | log::LevelFilter::Debug);
launch_rocket(pool, extra_debug).await // Blocks until program termination. launch_rocket(pool, extra_debug).await // Blocks until program termination.
} }
@ -567,14 +579,21 @@ async fn launch_rocket(pool: db::DbPool, extra_debug: bool) -> Result<(), Error>
// If adding more paths here, consider also adding them to // If adding more paths here, consider also adding them to
// crate::utils::LOGGED_ROUTES to make sure they appear in the log // crate::utils::LOGGED_ROUTES to make sure they appear in the log
let instance = rocket::custom(config) let mut instance = rocket::custom(config)
.mount([basepath, "/"].concat(), api::web_routes()) .mount([basepath, "/"].concat(), api::web_routes())
.mount([basepath, "/api"].concat(), api::core_routes()) .mount([basepath, "/api"].concat(), api::core_routes())
.mount([basepath, "/admin"].concat(), api::admin_routes()) .mount([basepath, "/admin"].concat(), api::admin_routes())
.mount([basepath, "/events"].concat(), api::core_events_routes()) .mount([basepath, "/events"].concat(), api::core_events_routes())
.mount([basepath, "/identity"].concat(), api::identity_routes()) .mount([basepath, "/identity"].concat(), api::identity_routes())
.mount([basepath, "/icons"].concat(), api::icons_routes()) .mount([basepath, "/icons"].concat(), api::icons_routes())
.mount([basepath, "/notifications"].concat(), api::notifications_routes()) .mount([basepath, "/notifications"].concat(), api::notifications_routes());
// Conditionally mount metrics routes if enabled
if CONFIG.enable_metrics() {
instance = instance.mount([basepath, "/metrics"].concat(), api::metrics_routes());
}
let mut rocket_instance = instance
.register([basepath, "/"].concat(), api::web_catchers()) .register([basepath, "/"].concat(), api::web_catchers())
.register([basepath, "/api"].concat(), api::core_catchers()) .register([basepath, "/api"].concat(), api::core_catchers())
.register([basepath, "/admin"].concat(), api::admin_catchers()) .register([basepath, "/admin"].concat(), api::admin_catchers())
@ -583,9 +602,14 @@ async fn launch_rocket(pool: db::DbPool, extra_debug: bool) -> Result<(), Error>
.manage(Arc::clone(&WS_ANONYMOUS_SUBSCRIPTIONS)) .manage(Arc::clone(&WS_ANONYMOUS_SUBSCRIPTIONS))
.attach(util::AppHeaders()) .attach(util::AppHeaders())
.attach(util::Cors()) .attach(util::Cors())
.attach(util::BetterLogging(extra_debug)) .attach(util::BetterLogging(extra_debug));
.ignite()
.await?; // Attach metrics fairing if metrics are enabled
if CONFIG.enable_metrics() {
rocket_instance = rocket_instance.attach(api::MetricsFairing);
}
let instance = rocket_instance.ignite().await?;
CONFIG.set_rocket_shutdown_handle(instance.shutdown()); CONFIG.set_rocket_shutdown_handle(instance.shutdown());

456
src/metrics.rs

@ -0,0 +1,456 @@
use std::time::SystemTime;
#[cfg(feature = "enable_metrics")]
use once_cell::sync::Lazy;
#[cfg(feature = "enable_metrics")]
use prometheus::{
register_gauge_vec, register_histogram_vec, register_int_counter_vec, register_int_gauge_vec, Encoder, GaugeVec,
HistogramVec, IntCounterVec, IntGaugeVec, TextEncoder,
};
use crate::{db::DbConn, error::Error};
#[cfg(feature = "enable_metrics")]
use crate::CONFIG;
#[cfg(feature = "enable_metrics")]
use std::sync::RwLock;
#[cfg(feature = "enable_metrics")]
use std::time::UNIX_EPOCH;
// HTTP request metrics
#[cfg(feature = "enable_metrics")]
static HTTP_REQUESTS_TOTAL: Lazy<IntCounterVec> = Lazy::new(|| {
register_int_counter_vec!(
"vaultwarden_http_requests_total",
"Total number of HTTP requests processed",
&["method", "path", "status"]
)
.unwrap()
});
#[cfg(feature = "enable_metrics")]
static HTTP_REQUEST_DURATION_SECONDS: Lazy<HistogramVec> = Lazy::new(|| {
register_histogram_vec!(
"vaultwarden_http_request_duration_seconds",
"HTTP request duration in seconds",
&["method", "path"],
vec![0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0]
)
.unwrap()
});
// Database metrics
#[cfg(feature = "enable_metrics")]
static DB_CONNECTIONS_ACTIVE: Lazy<IntGaugeVec> = Lazy::new(|| {
register_int_gauge_vec!("vaultwarden_db_connections_active", "Number of active database connections", &["database"])
.unwrap()
});
#[cfg(feature = "enable_metrics")]
static DB_CONNECTIONS_IDLE: Lazy<IntGaugeVec> = Lazy::new(|| {
register_int_gauge_vec!("vaultwarden_db_connections_idle", "Number of idle database connections", &["database"])
.unwrap()
});
// Authentication metrics
#[cfg(feature = "enable_metrics")]
static AUTH_ATTEMPTS_TOTAL: Lazy<IntCounterVec> = Lazy::new(|| {
register_int_counter_vec!(
"vaultwarden_auth_attempts_total",
"Total number of authentication attempts",
&["method", "status"]
)
.unwrap()
});
// Business metrics
#[cfg(feature = "enable_metrics")]
static USERS_TOTAL: Lazy<IntGaugeVec> =
Lazy::new(|| register_int_gauge_vec!("vaultwarden_users_total", "Total number of users", &["status"]).unwrap());
#[cfg(feature = "enable_metrics")]
static ORGANIZATIONS_TOTAL: Lazy<IntGaugeVec> = Lazy::new(|| {
register_int_gauge_vec!("vaultwarden_organizations_total", "Total number of organizations", &["status"]).unwrap()
});
#[cfg(feature = "enable_metrics")]
static VAULT_ITEMS_TOTAL: Lazy<IntGaugeVec> = Lazy::new(|| {
register_int_gauge_vec!("vaultwarden_vault_items_total", "Total number of vault items", &["type", "organization"])
.unwrap()
});
#[cfg(feature = "enable_metrics")]
static COLLECTIONS_TOTAL: Lazy<IntGaugeVec> = Lazy::new(|| {
register_int_gauge_vec!("vaultwarden_collections_total", "Total number of collections", &["organization"]).unwrap()
});
// System metrics
#[cfg(feature = "enable_metrics")]
static UPTIME_SECONDS: Lazy<GaugeVec> =
Lazy::new(|| register_gauge_vec!("vaultwarden_uptime_seconds", "Uptime in seconds", &["version"]).unwrap());
#[cfg(feature = "enable_metrics")]
static BUILD_INFO: Lazy<IntGaugeVec> = Lazy::new(|| {
register_int_gauge_vec!("vaultwarden_build_info", "Build information", &["version", "revision", "branch"]).unwrap()
});
/// Increment HTTP request counter
#[cfg(feature = "enable_metrics")]
pub fn increment_http_requests(method: &str, path: &str, status: u16) {
HTTP_REQUESTS_TOTAL.with_label_values(&[method, path, &status.to_string()]).inc();
}
/// Observe HTTP request duration
#[cfg(feature = "enable_metrics")]
pub fn observe_http_request_duration(method: &str, path: &str, duration_seconds: f64) {
HTTP_REQUEST_DURATION_SECONDS.with_label_values(&[method, path]).observe(duration_seconds);
}
/// Update database connection metrics
#[cfg(feature = "enable_metrics")]
pub fn update_db_connections(database: &str, active: i64, idle: i64) {
DB_CONNECTIONS_ACTIVE.with_label_values(&[database]).set(active);
DB_CONNECTIONS_IDLE.with_label_values(&[database]).set(idle);
}
/// Increment authentication attempts (success/failure tracking)
/// Tracks authentication success/failure by method (password, client_credentials, SSO, etc.)
/// Called from src/api/identity.rs login() after each authentication attempt
#[cfg(feature = "enable_metrics")]
pub fn increment_auth_attempts(method: &str, status: &str) {
AUTH_ATTEMPTS_TOTAL.with_label_values(&[method, status]).inc();
}
/// Cached business metrics data
#[cfg(feature = "enable_metrics")]
#[derive(Clone)]
struct BusinessMetricsCache {
timestamp: u64,
users_enabled: i64,
users_disabled: i64,
organizations: i64,
vault_counts: std::collections::HashMap<(String, String), i64>,
collection_counts: std::collections::HashMap<String, i64>,
}
#[cfg(feature = "enable_metrics")]
static BUSINESS_METRICS_CACHE: Lazy<RwLock<Option<BusinessMetricsCache>>> = Lazy::new(|| RwLock::new(None));
/// Check if business metrics cache is still valid
#[cfg(feature = "enable_metrics")]
fn is_cache_valid() -> bool {
let cache_timeout = CONFIG.metrics_business_cache_seconds();
if let Ok(cache) = BUSINESS_METRICS_CACHE.read() {
if let Some(ref cached) = *cache {
let now = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs();
return now - cached.timestamp < cache_timeout;
}
}
false
}
/// Update cached business metrics
#[cfg(feature = "enable_metrics")]
fn update_cached_metrics(cache: BusinessMetricsCache) {
if let Ok(mut cached) = BUSINESS_METRICS_CACHE.write() {
*cached = Some(cache);
}
}
/// Apply cached metrics to Prometheus gauges
#[cfg(feature = "enable_metrics")]
fn apply_cached_metrics(cache: &BusinessMetricsCache) {
USERS_TOTAL.with_label_values(&["enabled"]).set(cache.users_enabled);
USERS_TOTAL.with_label_values(&["disabled"]).set(cache.users_disabled);
ORGANIZATIONS_TOTAL.with_label_values(&["active"]).set(cache.organizations);
for ((cipher_type, org_label), count) in &cache.vault_counts {
VAULT_ITEMS_TOTAL.with_label_values(&[cipher_type, org_label]).set(*count);
}
for (org_id, count) in &cache.collection_counts {
COLLECTIONS_TOTAL.with_label_values(&[org_id]).set(*count);
}
}
/// Update business metrics from database (with caching)
#[cfg(feature = "enable_metrics")]
pub async fn update_business_metrics(conn: &mut DbConn) -> Result<(), Error> {
// Check if cache is still valid
if is_cache_valid() {
// Apply cached metrics without DB query
if let Ok(cache) = BUSINESS_METRICS_CACHE.read() {
if let Some(ref cached) = *cache {
apply_cached_metrics(cached);
return Ok(());
}
}
}
use crate::db::models::*;
use std::collections::HashMap;
// Count users
let enabled_users = User::count_enabled(conn).await;
let disabled_users = User::count_disabled(conn).await;
// Count organizations
let organizations_vec = Organization::get_all(conn).await;
let active_orgs = organizations_vec.len() as i64;
// Count vault items by type and organization
let vault_counts = Cipher::count_by_type_and_org(conn).await;
// Count collections per organization
let mut collection_counts: HashMap<String, i64> = HashMap::new();
for org in &organizations_vec {
let count = Collection::count_by_org(&org.uuid, conn).await;
collection_counts.insert(org.uuid.to_string(), count);
}
// Create cache entry
let now = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs();
let cache = BusinessMetricsCache {
timestamp: now,
users_enabled: enabled_users,
users_disabled: disabled_users,
organizations: active_orgs,
vault_counts,
collection_counts,
};
// Update cache and apply metrics
update_cached_metrics(cache.clone());
apply_cached_metrics(&cache);
Ok(())
}
/// Initialize build info metrics
#[cfg(feature = "enable_metrics")]
pub fn init_build_info() {
let version = crate::VERSION.unwrap_or("unknown");
BUILD_INFO.with_label_values(&[version, "unknown", "unknown"]).set(1);
}
/// Update system uptime
#[cfg(feature = "enable_metrics")]
pub fn update_uptime(start_time: SystemTime) {
if let Ok(elapsed) = start_time.elapsed() {
let version = crate::VERSION.unwrap_or("unknown");
UPTIME_SECONDS.with_label_values(&[version]).set(elapsed.as_secs_f64());
}
}
/// Gather all metrics and return as Prometheus text format
#[cfg(feature = "enable_metrics")]
pub fn gather_metrics() -> Result<String, Error> {
let encoder = TextEncoder::new();
let metric_families = prometheus::gather();
let mut output = Vec::new();
if let Err(e) = encoder.encode(&metric_families, &mut output) {
return Err(Error::new(format!("Failed to encode metrics: {}", e), ""));
}
match String::from_utf8(output) {
Ok(s) => Ok(s),
Err(e) => Err(Error::new(format!("Failed to convert metrics to string: {}", e), "")),
}
}
// No-op implementations when metrics are disabled
#[cfg(not(feature = "enable_metrics"))]
#[allow(dead_code)]
pub fn increment_http_requests(_method: &str, _path: &str, _status: u16) {}
#[cfg(not(feature = "enable_metrics"))]
#[allow(dead_code)]
pub fn observe_http_request_duration(_method: &str, _path: &str, _duration_seconds: f64) {}
#[cfg(not(feature = "enable_metrics"))]
#[allow(dead_code)]
pub fn update_db_connections(_database: &str, _active: i64, _idle: i64) {}
#[cfg(not(feature = "enable_metrics"))]
pub fn increment_auth_attempts(_method: &str, _status: &str) {}
#[cfg(not(feature = "enable_metrics"))]
pub async fn update_business_metrics(_conn: &mut DbConn) -> Result<(), Error> {
Ok(())
}
#[cfg(not(feature = "enable_metrics"))]
pub fn init_build_info() {}
#[cfg(not(feature = "enable_metrics"))]
#[allow(dead_code)]
pub fn update_uptime(_start_time: SystemTime) {}
#[cfg(not(feature = "enable_metrics"))]
pub fn gather_metrics() -> Result<String, Error> {
Ok("Metrics not enabled".to_string())
}
#[cfg(test)]
mod tests {
use super::*;
#[cfg(feature = "enable_metrics")]
mod metrics_enabled_tests {
use super::*;
#[test]
fn test_http_metrics_collection() {
increment_http_requests("GET", "/api/sync", 200);
increment_http_requests("POST", "/api/accounts/register", 201);
increment_http_requests("GET", "/api/sync", 500);
observe_http_request_duration("GET", "/api/sync", 0.150);
observe_http_request_duration("POST", "/api/accounts/register", 0.300);
let metrics = gather_metrics().expect("Failed to gather metrics");
assert!(metrics.contains("vaultwarden_http_requests_total"));
assert!(metrics.contains("vaultwarden_http_request_duration_seconds"));
}
#[test]
fn test_database_metrics_collection() {
update_db_connections("sqlite", 5, 10);
update_db_connections("postgresql", 8, 2);
let metrics = gather_metrics().expect("Failed to gather metrics");
assert!(metrics.contains("vaultwarden_db_connections_active"));
assert!(metrics.contains("vaultwarden_db_connections_idle"));
}
#[test]
fn test_authentication_metrics() {
increment_auth_attempts("password", "success");
increment_auth_attempts("password", "failed");
increment_auth_attempts("webauthn", "success");
let metrics = gather_metrics().expect("Failed to gather metrics");
assert!(metrics.contains("vaultwarden_auth_attempts_total"));
assert!(metrics.contains("method=\"password\""));
assert!(metrics.contains("status=\"success\""));
assert!(metrics.contains("status=\"failed\""));
}
#[test]
fn test_build_info_initialization() {
init_build_info();
let start_time = SystemTime::now();
update_uptime(start_time);
let metrics = gather_metrics().expect("Failed to gather metrics");
assert!(metrics.contains("vaultwarden_build_info"));
assert!(metrics.contains("vaultwarden_uptime_seconds"));
}
#[test]
fn test_metrics_gathering() {
increment_http_requests("GET", "/api/sync", 200);
update_db_connections("sqlite", 1, 5);
init_build_info();
let metrics_output = gather_metrics();
assert!(metrics_output.is_ok(), "gather_metrics should succeed");
let metrics_text = metrics_output.unwrap();
assert!(!metrics_text.is_empty(), "metrics output should not be empty");
assert!(metrics_text.contains("# HELP"), "metrics should have HELP lines");
assert!(metrics_text.contains("# TYPE"), "metrics should have TYPE lines");
assert!(metrics_text.contains("vaultwarden_"), "metrics should contain vaultwarden prefix");
}
#[tokio::test]
async fn test_business_metrics_collection_noop() {
init_build_info();
let metrics = gather_metrics().expect("Failed to gather metrics");
assert!(metrics.contains("vaultwarden_"), "Business metrics should be accessible");
}
#[test]
fn test_path_normalization() {
increment_http_requests("GET", "/api/sync", 200);
increment_http_requests("GET", "/api/accounts/123/profile", 200);
increment_http_requests("POST", "/api/organizations/456/users", 201);
increment_http_requests("PUT", "/api/ciphers/789", 200);
let result = gather_metrics();
assert!(result.is_ok(), "gather_metrics should succeed with various paths");
let metrics_text = result.unwrap();
assert!(!metrics_text.is_empty(), "metrics output should not be empty");
assert!(metrics_text.contains("vaultwarden_http_requests_total"), "should have http request metrics");
}
#[test]
fn test_concurrent_metrics_collection() {
use std::thread;
let handles: Vec<_> = (0..10).map(|i| {
thread::spawn(move || {
increment_http_requests("GET", "/api/sync", 200);
observe_http_request_duration("GET", "/api/sync", 0.1 + (i as f64 * 0.01));
update_db_connections("sqlite", i, 10 - i);
})
}).collect();
for handle in handles {
handle.join().expect("Thread panicked");
}
let result = gather_metrics();
assert!(result.is_ok(), "metrics collection should be thread-safe");
assert!(!result.unwrap().is_empty(), "concurrent access should not corrupt metrics");
}
}
#[cfg(not(feature = "enable_metrics"))]
mod metrics_disabled_tests {
use super::*;
#[test]
fn test_no_op_implementations() {
increment_http_requests("GET", "/api/sync", 200);
observe_http_request_duration("GET", "/api/sync", 0.150);
update_db_connections("sqlite", 5, 10);
increment_auth_attempts("password", "success");
init_build_info();
let start_time = SystemTime::now();
update_uptime(start_time);
let result = gather_metrics();
assert!(result.is_ok(), "disabled metrics should return ok");
assert_eq!(result.unwrap(), "Metrics not enabled", "should return disabled message");
}
#[tokio::test]
async fn test_business_metrics_no_op() {
let result = gather_metrics();
assert!(result.is_ok(), "disabled metrics should not panic");
assert_eq!(result.unwrap(), "Metrics not enabled", "should return disabled message");
}
#[test]
fn test_concurrent_no_op_calls() {
use std::thread;
let handles: Vec<_> = (0..5).map(|i| {
thread::spawn(move || {
increment_http_requests("GET", "/test", 200);
observe_http_request_duration("GET", "/test", 0.1);
update_db_connections("test", i, 5 - i);
increment_auth_attempts("password", "success");
})
}).collect();
for handle in handles {
handle.join().expect("Thread panicked");
}
let result = gather_metrics();
assert!(result.is_ok(), "disabled metrics should be thread-safe");
assert_eq!(result.unwrap(), "Metrics not enabled", "disabled metrics should always return same message");
}
}
}

177
src/metrics_test.rs

@ -0,0 +1,177 @@
#[cfg(test)]
mod tests {
use super::*;
use std::time::Duration;
use tokio::time::sleep;
#[cfg(feature = "enable_metrics")]
mod metrics_enabled_tests {
use super::*;
#[test]
fn test_http_metrics_collection() {
increment_http_requests("GET", "/api/sync", 200);
increment_http_requests("POST", "/api/accounts/register", 201);
increment_http_requests("GET", "/api/sync", 500);
observe_http_request_duration("GET", "/api/sync", 0.150);
observe_http_request_duration("POST", "/api/accounts/register", 0.300);
let metrics = gather_metrics().expect("Failed to gather metrics");
assert!(metrics.contains("vaultwarden_http_requests_total"));
assert!(metrics.contains("vaultwarden_http_request_duration_seconds"));
}
#[test]
fn test_database_metrics_collection() {
update_db_connections("sqlite", 5, 10);
update_db_connections("postgresql", 8, 2);
observe_db_query_duration("select", 0.025);
observe_db_query_duration("insert", 0.045);
observe_db_query_duration("update", 0.030);
let metrics = gather_metrics().expect("Failed to gather metrics");
assert!(metrics.contains("vaultwarden_db_connections_active"));
assert!(metrics.contains("vaultwarden_db_connections_idle"));
assert!(metrics.contains("vaultwarden_db_query_duration_seconds"));
}
#[test]
fn test_authentication_metrics() {
increment_auth_attempts("password", "success");
increment_auth_attempts("password", "failed");
increment_auth_attempts("webauthn", "success");
increment_auth_attempts("2fa", "failed");
update_user_sessions("authenticated", 150);
update_user_sessions("anonymous", 5);
let metrics = gather_metrics().expect("Failed to gather metrics");
assert!(metrics.contains("vaultwarden_auth_attempts_total"));
assert!(metrics.contains("vaultwarden_user_sessions_active"));
}
#[test]
fn test_build_info_initialization() {
init_build_info();
let start_time = std::time::SystemTime::now();
update_uptime(start_time);
let metrics = gather_metrics().expect("Failed to gather metrics");
assert!(metrics.contains("vaultwarden_build_info"));
assert!(metrics.contains("vaultwarden_uptime_seconds"));
}
#[test]
fn test_metrics_gathering() {
increment_http_requests("GET", "/api/sync", 200);
update_db_connections("sqlite", 1, 5);
init_build_info();
let metrics_output = gather_metrics();
assert!(metrics_output.is_ok(), "gather_metrics should succeed");
let metrics_text = metrics_output.unwrap();
assert!(!metrics_text.is_empty(), "metrics output should not be empty");
assert!(metrics_text.contains("# HELP"), "metrics should have HELP lines");
assert!(metrics_text.contains("# TYPE"), "metrics should have TYPE lines");
assert!(metrics_text.contains("vaultwarden_"), "metrics should contain vaultwarden prefix");
}
#[tokio::test]
async fn test_business_metrics_collection_noop() {
// Business metrics require database access, which cannot be easily mocked in unit tests.
// This test verifies that the async function exists and can be called without panicking.
// Integration tests would provide database access and verify metrics are actually updated.
init_build_info();
let metrics = gather_metrics().expect("Failed to gather metrics");
assert!(metrics.contains("vaultwarden_"), "Business metrics should be accessible");
}
#[test]
fn test_path_normalization() {
increment_http_requests("GET", "/api/sync", 200);
increment_http_requests("GET", "/api/accounts/123/profile", 200);
increment_http_requests("POST", "/api/organizations/456/users", 201);
increment_http_requests("PUT", "/api/ciphers/789", 200);
let result = gather_metrics();
assert!(result.is_ok(), "gather_metrics should succeed with various paths");
let metrics_text = result.unwrap();
assert!(!metrics_text.is_empty(), "metrics output should not be empty");
assert!(metrics_text.contains("vaultwarden_http_requests_total"), "should have http request metrics");
}
#[test]
fn test_concurrent_metrics_collection() {
use std::thread;
let handles: Vec<_> = (0..10).map(|i| {
thread::spawn(move || {
increment_http_requests("GET", "/api/sync", 200);
observe_http_request_duration("GET", "/api/sync", 0.1 + (i as f64 * 0.01));
update_db_connections("sqlite", i, 10 - i);
})
}).collect();
for handle in handles {
handle.join().expect("Thread panicked");
}
let result = gather_metrics();
assert!(result.is_ok(), "metrics collection should be thread-safe");
assert!(!result.unwrap().is_empty(), "concurrent access should not corrupt metrics");
}
}
#[cfg(not(feature = "enable_metrics"))]
mod metrics_disabled_tests {
use super::*;
#[test]
fn test_no_op_implementations() {
increment_http_requests("GET", "/api/sync", 200);
observe_http_request_duration("GET", "/api/sync", 0.150);
update_db_connections("sqlite", 5, 10);
observe_db_query_duration("select", 0.025);
increment_auth_attempts("password", "success");
update_user_sessions("authenticated", 150);
init_build_info();
let start_time = std::time::SystemTime::now();
update_uptime(start_time);
let result = gather_metrics();
assert!(result.is_ok(), "disabled metrics should return ok");
assert_eq!(result.unwrap(), "Metrics not enabled", "should return disabled message");
}
#[tokio::test]
async fn test_business_metrics_no_op() {
let result = gather_metrics();
assert!(result.is_ok(), "disabled metrics should not panic");
assert_eq!(result.unwrap(), "Metrics not enabled", "should return disabled message");
}
#[test]
fn test_concurrent_no_op_calls() {
use std::thread;
let handles: Vec<_> = (0..5).map(|i| {
thread::spawn(move || {
increment_http_requests("GET", "/test", 200);
observe_http_request_duration("GET", "/test", 0.1);
update_db_connections("test", i, 5 - i);
increment_auth_attempts("password", "success");
})
}).collect();
for handle in handles {
handle.join().expect("Thread panicked");
}
let result = gather_metrics();
assert!(result.is_ok(), "disabled metrics should be thread-safe");
assert_eq!(result.unwrap(), "Metrics not enabled", "disabled metrics should always return same message");
}
}
}
Loading…
Cancel
Save