|
|
@ -1,19 +1,41 @@ |
|
|
|
use std::sync::atomic::{AtomicBool, Ordering}; |
|
|
|
use std::{ |
|
|
|
net::SocketAddr, |
|
|
|
sync::{ |
|
|
|
atomic::{AtomicBool, Ordering}, |
|
|
|
Arc, |
|
|
|
}, |
|
|
|
time::Duration, |
|
|
|
}; |
|
|
|
|
|
|
|
use rocket::serde::json::Json; |
|
|
|
use rocket::Route; |
|
|
|
use chrono::NaiveDateTime; |
|
|
|
use futures::{SinkExt, StreamExt}; |
|
|
|
use rmpv::Value; |
|
|
|
use rocket::{serde::json::Json, Route}; |
|
|
|
use serde_json::Value as JsonValue; |
|
|
|
|
|
|
|
use crate::{api::EmptyResult, auth::Headers, Error, CONFIG}; |
|
|
|
use tokio::{ |
|
|
|
net::{TcpListener, TcpStream}, |
|
|
|
sync::mpsc::Sender, |
|
|
|
}; |
|
|
|
use tokio_tungstenite::{ |
|
|
|
accept_hdr_async, |
|
|
|
tungstenite::{handshake, Message}, |
|
|
|
}; |
|
|
|
|
|
|
|
use crate::{ |
|
|
|
api::EmptyResult, |
|
|
|
auth::Headers, |
|
|
|
db::models::{Cipher, Folder, Send, User}, |
|
|
|
Error, CONFIG, |
|
|
|
}; |
|
|
|
|
|
|
|
pub fn routes() -> Vec<Route> { |
|
|
|
routes![negotiate, websockets_err] |
|
|
|
} |
|
|
|
|
|
|
|
static SHOW_WEBSOCKETS_MSG: AtomicBool = AtomicBool::new(true); |
|
|
|
|
|
|
|
#[get("/hub")] |
|
|
|
fn websockets_err() -> EmptyResult { |
|
|
|
static SHOW_WEBSOCKETS_MSG: AtomicBool = AtomicBool::new(true); |
|
|
|
|
|
|
|
if CONFIG.websocket_enabled() |
|
|
|
&& SHOW_WEBSOCKETS_MSG.compare_exchange(true, false, Ordering::Relaxed, Ordering::Relaxed).is_ok() |
|
|
|
{ |
|
|
@ -55,19 +77,6 @@ fn negotiate(_headers: Headers) -> Json<JsonValue> { |
|
|
|
//
|
|
|
|
// Websockets server
|
|
|
|
//
|
|
|
|
use std::io; |
|
|
|
use std::sync::Arc; |
|
|
|
use std::thread; |
|
|
|
|
|
|
|
use ws::{self, util::Token, Factory, Handler, Handshake, Message, Sender}; |
|
|
|
|
|
|
|
use chashmap::CHashMap; |
|
|
|
use chrono::NaiveDateTime; |
|
|
|
use serde_json::from_str; |
|
|
|
|
|
|
|
use crate::db::models::{Cipher, Folder, Send, User}; |
|
|
|
|
|
|
|
use rmpv::Value; |
|
|
|
|
|
|
|
fn serialize(val: Value) -> Vec<u8> { |
|
|
|
use rmpv::encode::write_value; |
|
|
@ -118,192 +127,49 @@ fn convert_option<T: Into<Value>>(option: Option<T>) -> Value { |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
// Server WebSocket handler
|
|
|
|
pub struct WsHandler { |
|
|
|
out: Sender, |
|
|
|
user_uuid: Option<String>, |
|
|
|
users: WebSocketUsers, |
|
|
|
} |
|
|
|
|
|
|
|
const RECORD_SEPARATOR: u8 = 0x1e; |
|
|
|
const INITIAL_RESPONSE: [u8; 3] = [0x7b, 0x7d, RECORD_SEPARATOR]; // {, }, <RS>
|
|
|
|
|
|
|
|
#[derive(Deserialize)] |
|
|
|
struct InitialMessage { |
|
|
|
protocol: String, |
|
|
|
#[derive(Deserialize, Copy, Clone, Eq, PartialEq)] |
|
|
|
struct InitialMessage<'a> { |
|
|
|
protocol: &'a str, |
|
|
|
version: i32, |
|
|
|
} |
|
|
|
|
|
|
|
const PING_MS: u64 = 15_000; |
|
|
|
const PING: Token = Token(1); |
|
|
|
|
|
|
|
const ACCESS_TOKEN_KEY: &str = "access_token="; |
|
|
|
|
|
|
|
impl WsHandler { |
|
|
|
fn err(&self, msg: &'static str) -> ws::Result<()> { |
|
|
|
self.out.close(ws::CloseCode::Invalid)?; |
|
|
|
|
|
|
|
// We need to specifically return an IO error so ws closes the connection
|
|
|
|
let io_error = io::Error::from(io::ErrorKind::InvalidData); |
|
|
|
Err(ws::Error::new(ws::ErrorKind::Io(io_error), msg)) |
|
|
|
} |
|
|
|
|
|
|
|
fn get_request_token(&self, hs: Handshake) -> Option<String> { |
|
|
|
use std::str::from_utf8; |
|
|
|
|
|
|
|
// Verify we have a token header
|
|
|
|
if let Some(header_value) = hs.request.header("Authorization") { |
|
|
|
if let Ok(converted) = from_utf8(header_value) { |
|
|
|
if let Some(token_part) = converted.split("Bearer ").nth(1) { |
|
|
|
return Some(token_part.into()); |
|
|
|
} |
|
|
|
} |
|
|
|
}; |
|
|
|
|
|
|
|
// Otherwise verify the query parameter value
|
|
|
|
let path = hs.request.resource(); |
|
|
|
if let Some(params) = path.split('?').nth(1) { |
|
|
|
let params_iter = params.split('&').take(1); |
|
|
|
for val in params_iter { |
|
|
|
if let Some(stripped) = val.strip_prefix(ACCESS_TOKEN_KEY) { |
|
|
|
return Some(stripped.into()); |
|
|
|
} |
|
|
|
} |
|
|
|
}; |
|
|
|
|
|
|
|
None |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
impl Handler for WsHandler { |
|
|
|
fn on_open(&mut self, hs: Handshake) -> ws::Result<()> { |
|
|
|
// Path == "/notifications/hub?id=<id>==&access_token=<access_token>"
|
|
|
|
//
|
|
|
|
// We don't use `id`, and as of around 2020-03-25, the official clients
|
|
|
|
// no longer seem to pass `id` (only `access_token`).
|
|
|
|
|
|
|
|
// Get user token from header or query parameter
|
|
|
|
let access_token = match self.get_request_token(hs) { |
|
|
|
Some(token) => token, |
|
|
|
_ => return self.err("Missing access token"), |
|
|
|
}; |
|
|
|
|
|
|
|
// Validate the user
|
|
|
|
use crate::auth; |
|
|
|
let claims = match auth::decode_login(access_token.as_str()) { |
|
|
|
Ok(claims) => claims, |
|
|
|
Err(_) => return self.err("Invalid access token provided"), |
|
|
|
}; |
|
|
|
|
|
|
|
// Assign the user to the handler
|
|
|
|
let user_uuid = claims.sub; |
|
|
|
self.user_uuid = Some(user_uuid.clone()); |
|
|
|
|
|
|
|
// Add the current Sender to the user list
|
|
|
|
let handler_insert = self.out.clone(); |
|
|
|
let handler_update = self.out.clone(); |
|
|
|
|
|
|
|
self.users.map.upsert(user_uuid, || vec![handler_insert], |ref mut v| v.push(handler_update)); |
|
|
|
|
|
|
|
// Schedule a ping to keep the connection alive
|
|
|
|
self.out.timeout(PING_MS, PING) |
|
|
|
} |
|
|
|
|
|
|
|
fn on_message(&mut self, msg: Message) -> ws::Result<()> { |
|
|
|
if let Message::Text(text) = msg.clone() { |
|
|
|
let json = &text[..text.len() - 1]; // Remove last char
|
|
|
|
|
|
|
|
if let Ok(InitialMessage { |
|
|
|
protocol, |
|
|
|
version, |
|
|
|
}) = from_str::<InitialMessage>(json) |
|
|
|
{ |
|
|
|
if &protocol == "messagepack" && version == 1 { |
|
|
|
return self.out.send(&INITIAL_RESPONSE[..]); // Respond to initial message
|
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
// If it's not the initial message, just echo the message
|
|
|
|
self.out.send(msg) |
|
|
|
} |
|
|
|
|
|
|
|
fn on_timeout(&mut self, event: Token) -> ws::Result<()> { |
|
|
|
if event == PING { |
|
|
|
// send ping
|
|
|
|
self.out.send(create_ping())?; |
|
|
|
|
|
|
|
// reschedule the timeout
|
|
|
|
self.out.timeout(PING_MS, PING) |
|
|
|
} else { |
|
|
|
Ok(()) |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
struct WsFactory { |
|
|
|
pub users: WebSocketUsers, |
|
|
|
} |
|
|
|
|
|
|
|
impl WsFactory { |
|
|
|
pub fn init() -> Self { |
|
|
|
WsFactory { |
|
|
|
users: WebSocketUsers { |
|
|
|
map: Arc::new(CHashMap::new()), |
|
|
|
}, |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
impl Factory for WsFactory { |
|
|
|
type Handler = WsHandler; |
|
|
|
|
|
|
|
fn connection_made(&mut self, out: Sender) -> Self::Handler { |
|
|
|
WsHandler { |
|
|
|
out, |
|
|
|
user_uuid: None, |
|
|
|
users: self.users.clone(), |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
fn connection_lost(&mut self, handler: Self::Handler) { |
|
|
|
// Remove handler
|
|
|
|
if let Some(user_uuid) = &handler.user_uuid { |
|
|
|
if let Some(mut user_conn) = self.users.map.get_mut(user_uuid) { |
|
|
|
if let Some(pos) = user_conn.iter().position(|x| x == &handler.out) { |
|
|
|
user_conn.remove(pos); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
static INITIAL_MESSAGE: InitialMessage<'static> = InitialMessage { |
|
|
|
protocol: "messagepack", |
|
|
|
version: 1, |
|
|
|
}; |
|
|
|
|
|
|
|
// We attach the UUID to the sender so we can differentiate them when we need to remove them from the Vec
|
|
|
|
type UserSenders = (uuid::Uuid, Sender<Message>); |
|
|
|
#[derive(Clone)] |
|
|
|
pub struct WebSocketUsers { |
|
|
|
map: Arc<CHashMap<String, Vec<Sender>>>, |
|
|
|
map: Arc<dashmap::DashMap<String, Vec<UserSenders>>>, |
|
|
|
} |
|
|
|
|
|
|
|
impl WebSocketUsers { |
|
|
|
fn send_update(&self, user_uuid: &str, data: &[u8]) -> ws::Result<()> { |
|
|
|
if let Some(user) = self.map.get(user_uuid) { |
|
|
|
for sender in user.iter() { |
|
|
|
sender.send(data)?; |
|
|
|
async fn send_update(&self, user_uuid: &str, data: &[u8]) { |
|
|
|
if let Some(user) = self.map.get(user_uuid).map(|v| v.clone()) { |
|
|
|
for (_, sender) in user.iter() { |
|
|
|
if sender.send(Message::binary(data)).await.is_err() { |
|
|
|
// TODO: Delete from map here too?
|
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
Ok(()) |
|
|
|
} |
|
|
|
|
|
|
|
// NOTE: The last modified date needs to be updated before calling these methods
|
|
|
|
pub fn send_user_update(&self, ut: UpdateType, user: &User) { |
|
|
|
pub async fn send_user_update(&self, ut: UpdateType, user: &User) { |
|
|
|
let data = create_update( |
|
|
|
vec![("UserId".into(), user.uuid.clone().into()), ("Date".into(), serialize_date(user.updated_at))], |
|
|
|
ut, |
|
|
|
); |
|
|
|
|
|
|
|
self.send_update(&user.uuid, &data).ok(); |
|
|
|
self.send_update(&user.uuid, &data).await; |
|
|
|
} |
|
|
|
|
|
|
|
pub fn send_folder_update(&self, ut: UpdateType, folder: &Folder) { |
|
|
|
pub async fn send_folder_update(&self, ut: UpdateType, folder: &Folder) { |
|
|
|
let data = create_update( |
|
|
|
vec![ |
|
|
|
("Id".into(), folder.uuid.clone().into()), |
|
|
@ -313,10 +179,10 @@ impl WebSocketUsers { |
|
|
|
ut, |
|
|
|
); |
|
|
|
|
|
|
|
self.send_update(&folder.user_uuid, &data).ok(); |
|
|
|
self.send_update(&folder.user_uuid, &data).await; |
|
|
|
} |
|
|
|
|
|
|
|
pub fn send_cipher_update(&self, ut: UpdateType, cipher: &Cipher, user_uuids: &[String]) { |
|
|
|
pub async fn send_cipher_update(&self, ut: UpdateType, cipher: &Cipher, user_uuids: &[String]) { |
|
|
|
let user_uuid = convert_option(cipher.user_uuid.clone()); |
|
|
|
let org_uuid = convert_option(cipher.organization_uuid.clone()); |
|
|
|
|
|
|
@ -332,11 +198,11 @@ impl WebSocketUsers { |
|
|
|
); |
|
|
|
|
|
|
|
for uuid in user_uuids { |
|
|
|
self.send_update(uuid, &data).ok(); |
|
|
|
self.send_update(uuid, &data).await; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
pub fn send_send_update(&self, ut: UpdateType, send: &Send, user_uuids: &[String]) { |
|
|
|
pub async fn send_send_update(&self, ut: UpdateType, send: &Send, user_uuids: &[String]) { |
|
|
|
let user_uuid = convert_option(send.user_uuid.clone()); |
|
|
|
|
|
|
|
let data = create_update( |
|
|
@ -349,7 +215,7 @@ impl WebSocketUsers { |
|
|
|
); |
|
|
|
|
|
|
|
for uuid in user_uuids { |
|
|
|
self.send_update(uuid, &data).ok(); |
|
|
|
self.send_update(uuid, &data).await; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
@ -416,27 +282,129 @@ pub enum UpdateType { |
|
|
|
None = 100, |
|
|
|
} |
|
|
|
|
|
|
|
use rocket::State; |
|
|
|
pub type Notify<'a> = &'a State<WebSocketUsers>; |
|
|
|
pub type Notify<'a> = &'a rocket::State<WebSocketUsers>; |
|
|
|
|
|
|
|
pub fn start_notification_server() -> WebSocketUsers { |
|
|
|
let factory = WsFactory::init(); |
|
|
|
let users = factory.users.clone(); |
|
|
|
let users = WebSocketUsers { |
|
|
|
map: Arc::new(dashmap::DashMap::new()), |
|
|
|
}; |
|
|
|
|
|
|
|
if CONFIG.websocket_enabled() { |
|
|
|
thread::spawn(move || { |
|
|
|
let mut settings = ws::Settings::default(); |
|
|
|
settings.max_connections = 500; |
|
|
|
settings.queue_size = 2; |
|
|
|
settings.panic_on_internal = false; |
|
|
|
|
|
|
|
let ws = ws::Builder::new().with_settings(settings).build(factory).unwrap(); |
|
|
|
CONFIG.set_ws_shutdown_handle(ws.broadcaster()); |
|
|
|
ws.listen((CONFIG.websocket_address().as_str(), CONFIG.websocket_port())).unwrap(); |
|
|
|
let users2 = users.clone(); |
|
|
|
tokio::spawn(async move { |
|
|
|
let addr = (CONFIG.websocket_address(), CONFIG.websocket_port()); |
|
|
|
let listener = TcpListener::bind(addr).await.expect("Can't listen on websocket port"); |
|
|
|
|
|
|
|
let (shutdown_tx, mut shutdown_rx) = tokio::sync::oneshot::channel::<()>(); |
|
|
|
CONFIG.set_ws_shutdown_handle(shutdown_tx); |
|
|
|
|
|
|
|
loop { |
|
|
|
tokio::select! { |
|
|
|
Ok((stream, addr)) = listener.accept() => { |
|
|
|
tokio::spawn(handle_connection(stream, users2.clone(), addr)); |
|
|
|
} |
|
|
|
|
|
|
|
_ = &mut shutdown_rx => { |
|
|
|
break; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
warn!("WS Server stopped!"); |
|
|
|
info!("Shutting down WebSockets server!") |
|
|
|
}); |
|
|
|
} |
|
|
|
|
|
|
|
users |
|
|
|
} |
|
|
|
|
|
|
|
async fn handle_connection(stream: TcpStream, users: WebSocketUsers, _addr: SocketAddr) -> Result<(), Error> { |
|
|
|
let mut user_uuid: Option<String> = None; |
|
|
|
|
|
|
|
// Accept connection, do initial handshake, validate auth token and get the user ID
|
|
|
|
use handshake::server::{Request, Response}; |
|
|
|
let mut stream = accept_hdr_async(stream, |req: &Request, res: Response| { |
|
|
|
if let Some(token) = get_request_token(req) { |
|
|
|
if let Ok(claims) = crate::auth::decode_login(&token) { |
|
|
|
user_uuid = Some(claims.sub); |
|
|
|
return Ok(res); |
|
|
|
} |
|
|
|
} |
|
|
|
Err(Response::builder().status(401).body(None).unwrap()) |
|
|
|
}) |
|
|
|
.await?; |
|
|
|
|
|
|
|
let user_uuid = user_uuid.expect("User UUID should be set after the handshake"); |
|
|
|
|
|
|
|
// Add a channel to send messages to this client to the map
|
|
|
|
let entry_uuid = uuid::Uuid::new_v4(); |
|
|
|
let (tx, mut rx) = tokio::sync::mpsc::channel(100); |
|
|
|
users.map.entry(user_uuid.clone()).or_default().push((entry_uuid, tx)); |
|
|
|
|
|
|
|
let mut interval = tokio::time::interval(Duration::from_secs(15)); |
|
|
|
loop { |
|
|
|
tokio::select! { |
|
|
|
res = stream.next() => { |
|
|
|
match res { |
|
|
|
Some(Ok(message)) => { |
|
|
|
// We should receive an initial message with the protocol and version, and we will reply to it
|
|
|
|
if let Message::Text(ref message) = message { |
|
|
|
let msg = message.strip_suffix('\u{1e}').unwrap_or(message); |
|
|
|
|
|
|
|
if serde_json::from_str(msg).ok() == Some(INITIAL_MESSAGE) { |
|
|
|
stream.send(Message::binary(INITIAL_RESPONSE)).await?; |
|
|
|
continue; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
// Just echo anything else the client sends
|
|
|
|
if stream.send(message).await.is_err() { |
|
|
|
break; |
|
|
|
} |
|
|
|
} |
|
|
|
_ => break, |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
res = rx.recv() => { |
|
|
|
match res { |
|
|
|
Some(res) => { |
|
|
|
if stream.send(res).await.is_err() { |
|
|
|
break; |
|
|
|
} |
|
|
|
}, |
|
|
|
None => break, |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
_= interval.tick() => { |
|
|
|
if stream.send(Message::Binary(create_ping())).await.is_err() { |
|
|
|
break; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
// Delete from map
|
|
|
|
users.map.entry(user_uuid).or_default().retain(|(uuid, _)| uuid != &entry_uuid); |
|
|
|
Ok(()) |
|
|
|
} |
|
|
|
|
|
|
|
fn get_request_token(req: &handshake::server::Request) -> Option<String> { |
|
|
|
const ACCESS_TOKEN_KEY: &str = "access_token="; |
|
|
|
|
|
|
|
if let Some(Ok(auth)) = req.headers().get("Authorization").map(|a| a.to_str()) { |
|
|
|
if let Some(token_part) = auth.strip_prefix("Bearer ") { |
|
|
|
return Some(token_part.to_owned()); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
if let Some(params) = req.uri().query() { |
|
|
|
let params_iter = params.split('&').take(1); |
|
|
|
for val in params_iter { |
|
|
|
if let Some(stripped) = val.strip_prefix(ACCESS_TOKEN_KEY) { |
|
|
|
return Some(stripped.to_owned()); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
None |
|
|
|
} |
|
|
|