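Introduction
My graduation project is a distributed IM service written in Rust. Since I didn't have much experience in this area, I decided to build a simple messaging server first as practice.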
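Service Design
A simple messaging service needs to provide two functions:
- Gateway: keeps long-lived connections with clients to receive and send messages.
- Message processing: runs the delivery logic that matches each message's type.
Our technology choices follow from these two requirements.
First, for the long-lived connections we use the WebSocket protocol. Its interaction model is relatively simple and the technology is mature, which makes it a good fit for hands-on practice.
Once a user connects to the service, we need to tell connections apart and know which user each one belongs to. We therefore maintain a mapping table that associates user IDs with connections. Since there may be many users, lookups need to be fast, so we implement the table with a HashMap, which offers average O(1) lookup by key. With this user-to-connection index we can quickly find a user's connection and forward messages to it.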
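To make the user-to-connection index concrete, here is a minimal std-only sketch. It assumes that each online device is represented by the sending half of a channel (the gateway code later receives a `Receiver<String>` from `register_user_device`, which suggests a similar design) and that one user may be online on several devices. `ConnectionRegistry` and its methods are hypothetical names, not the project's `ConnectionManager` API.

```rust
use std::collections::HashMap;
use std::sync::mpsc::{channel, Receiver, Sender};
use std::sync::{Arc, RwLock};

/// Hypothetical index: user ID -> (device ID -> that device's outbound queue).
#[derive(Clone, Default)]
pub struct ConnectionRegistry {
    inner: Arc<RwLock<HashMap<String, HashMap<String, Sender<String>>>>>,
}

impl ConnectionRegistry {
    /// Registers a device and returns the receiving end of its queue, which a
    /// per-connection send task would drain into the WebSocket.
    pub fn register(&self, user_id: &str, device_id: &str) -> Receiver<String> {
        let (tx, rx) = channel();
        self.inner
            .write()
            .unwrap() // a poisoned lock means another thread panicked mid-update
            .entry(user_id.to_string())
            .or_default()
            .insert(device_id.to_string(), tx);
        rx
    }

    /// Removes one device and drops the user entry once no devices remain.
    pub fn unregister(&self, user_id: &str, device_id: &str) {
        let mut map = self.inner.write().unwrap();
        if let Some(devices) = map.get_mut(user_id) {
            devices.remove(device_id);
            if devices.is_empty() {
                map.remove(user_id);
            }
        }
    }

    /// Looks the user up by ID (average O(1) for a HashMap) and queues the
    /// message for every online device. Returns how many devices accepted it.
    pub fn send_to_user(&self, user_id: &str, msg: &str) -> usize {
        let map = self.inner.read().unwrap();
        let Some(devices) = map.get(user_id) else {
            return 0; // the user is offline
        };
        let mut delivered = 0;
        for tx in devices.values() {
            // send only fails if that device's Receiver has been dropped
            if tx.send(msg.to_string()).is_ok() {
                delivered += 1;
            }
        }
        delivered
    }
}
```

In the async gateway the same shape would use tokio's `mpsc` channel; std types are used here only so the sketch runs on its own.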
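When a user's message reaches the processing stage, we only need to validate it and inspect its target to decide what to do. That requires a well-defined message structure. The full details will be covered in the project's database design section; here is just an example:
// Media message type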
{
"_id": "019522a2-33ae-7d83-8083-36655138f65d",
"author_id": "019522a2-33ae-7d83-8083-36754d94c284",
"target_id": "019522a2-33ae-7d83-8083-368c5daea584",
"status": "Sending",
"message_content": {
"medias": [
{
"type": "Image",
"image_url": "http://example.com/image.png",
"preview_url": "http://example.com/image.png",
"width": 800,
"height": 600
}
],
"caption": "A sample image"
}
}
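The sample above can be mirrored by Rust types, which also gives the validation step something concrete to check. The sketch below is hypothetical: `ChatMessage`, `validate`, the assumed `Text` variant and every validation rule are illustrative, and the real project deserializes into its own `DbMessage` type with serde rather than building values by hand.

```rust
/// Hypothetical, simplified model of the sample message.
#[derive(Debug, Clone, PartialEq)]
pub enum Media {
    Image {
        image_url: String,
        preview_url: String,
        width: u32,
        height: u32,
    },
}

#[derive(Debug, Clone, PartialEq)]
pub enum MessageContent {
    /// Assumed plain-text variant; only the media variant appears in the sample.
    Text(String),
    Medias {
        medias: Vec<Media>,
        caption: Option<String>,
    },
}

#[derive(Debug, Clone, PartialEq)]
pub struct ChatMessage {
    pub id: String, // "_id" in the JSON
    pub author_id: String,
    pub target_id: String,
    pub status: String, // e.g. "Sending"
    pub content: MessageContent,
}

/// Validates a message and returns the target it should be routed to.
/// The individual rules are assumptions, not the project's actual checks.
pub fn validate(msg: &ChatMessage) -> Result<&str, String> {
    if msg.author_id.is_empty() || msg.target_id.is_empty() {
        return Err("author_id and target_id are required".to_string());
    }
    match &msg.content {
        MessageContent::Text(text) if text.trim().is_empty() => {
            return Err("text message is empty".to_string());
        }
        MessageContent::Text(_) => {}
        MessageContent::Medias { medias, .. } => {
            if medias.is_empty() {
                return Err("media message carries no media".to_string());
            }
            for media in medias {
                match media {
                    Media::Image { image_url, width, height, .. } => {
                        if !(image_url.starts_with("http://") || image_url.starts_with("https://")) {
                            return Err(format!("invalid image_url: {image_url}"));
                        }
                        if *width == 0 || *height == 0 {
                            return Err("image width and height must be positive".to_string());
                        }
                    }
                }
            }
        }
    }
    Ok(msg.target_id.as_str())
}
```

Returning the target ID on success means the caller gets both answers it needs (is the message valid, and where does it go) from a single call.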
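Service Implementation
I built the gateway service with the Rust web framework axum, following axum's official WebSocket example. axum first accepts the WebSocket request through a route; the ws handler then uses WebSocketUpgrade to upgrade the request to a WebSocket connection, and each connection is handled in its own asynchronous task.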
async fn new(config: &GatewayConfig, connection_manager: ConnectionManager) -> Self {
    let app = Router::new()
        .route("/chat/{token}", any(Self::ws_handler))
        .with_state(connection_manager);
    let listener =
        match TcpListener::bind(format!("{}:{}", config.listener.host, config.listener.port))
            .await
        {
            Ok(listener) => listener,
            Err(e) => {
                event!(tracing::Level::ERROR, "Failed to bind to address: {}", e);
                exit(1);
            }
        };
    Self { socket: listener, router: app }
}

async fn run(self) {
    event!(
        tracing::Level::INFO,
        "Listening on {}",
        self.socket.local_addr().unwrap(),
    );
    axum::serve(self.socket, self.router).await.unwrap();
}

// Accept the WebSocket request and upgrade it to a WebSocket connection
async fn ws_handler(
    ws: WebSocketUpgrade,
    State(connection_manager): State<ConnectionManager>,
    Path(token): Path<String>,
) -> impl IntoResponse {
    ws.on_upgrade(move |ws| Self::handle_socket(ws, token, connection_manager.clone()))
}

// `on_upgrade` expects a future whose output is `()`, so this function returns nothing
async fn handle_socket(ws: WebSocket, token: String, connection_manager: ConnectionManager) {
    // Handle sending and receiving messages
}
The fleshed-out body of handle_socket looks roughly like this:
debug!("handle_socket start");
// Split the WebSocket into a sending half (sink) and a receiving half (stream)
let (mut sender, receiver) = ws.split();
// Authenticate the user and register the device
let (claims, rc) = match connection_manager.register_user_device(token).await {
    Ok(rc) => rc,
    Err(e) => {
        // Best effort: tell the client why, without panicking if the socket is already gone
        let _ = sender.send(Message::Text(Utf8Bytes::from(e.to_string()))).await;
        return;
    }
};
let client = Client {
    sender: Arc::new(RwLock::new(sender)),
    user_id: claims.id,
    device_id: claims.device_id,
    platform: claims.platform,
};
event!(tracing::Level::DEBUG, "register user device: {}", client.device_id);
// Task that sends heartbeats from the server to the client
let span = tracing::info_span!("socket_tasks");
let cloned_client = client.clone();
let cloned_span = span.clone();
let mut ping_task = tokio::spawn(async move {
    Self::ping_task(cloned_client, cloned_span).await;
});
// Task that forwards queued messages to the user
let cloned_client = client.clone();
let cloned_span = span.clone();
let mut send_task = tokio::spawn(async move {
    Self::handle_send_task(cloned_client, rc, cloned_span).await;
});
// Task that receives the messages the user sends
let recv_cloned_state = connection_manager.clone();
let cloned_client = client.clone();
let cloned_span = span.clone();
let mut recv_task = tokio::spawn(async move {
    Self::handle_recv_task(recv_cloned_state, cloned_client, receiver, cloned_span).await;
});
// Supervisor: as soon as any task finishes, abort the other two and end the session
tokio::select! {
    _ = (&mut send_task) => {
        event!(tracing::Level::DEBUG, "Send task exit");
        ping_task.abort();
        recv_task.abort();
    }
    _ = (&mut ping_task) => {
        event!(tracing::Level::DEBUG, "Ping task exit");
        recv_task.abort();
        send_task.abort();
    }
    _ = (&mut recv_task) => {
        event!(tracing::Level::DEBUG, "Recv task exit");
        ping_task.abort();
        send_task.abort();
    }
}
debug!("Start unregister device");
// Unregister the user's device; only report success when it actually succeeded
match connection_manager
    .unregister_user_device(client.user_id.clone(), client.device_id.clone())
{
    Ok(_) => debug!("Unregister device success"),
    Err(e) => event!(tracing::Level::ERROR, "Failed to unregister user device: {}", e),
}
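The `tokio::select!` block implements a simple supervision rule: the first of the three tasks to finish ends the whole session. The std-only sketch below shows the same rule with threads. Because std threads cannot be aborted from outside, it swaps `JoinHandle::abort` for a cooperative stop flag that the remaining workers must poll; `Worker`, `run_until_first_exit` and `wait_for_stop` are hypothetical names.

```rust
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::mpsc;
use std::sync::Arc;
use std::thread;
use std::time::Duration;

/// A worker receives the shared stop flag and should return soon after it is set.
pub type Worker = Box<dyn FnOnce(Arc<AtomicBool>) + Send>;

/// Runs every worker on its own thread and returns the name of the first one
/// to finish. That exit raises the stop flag, the cooperative stand-in for
/// calling `abort` on the other tasks.
pub fn run_until_first_exit(workers: Vec<(&'static str, Worker)>) -> &'static str {
    let stop = Arc::new(AtomicBool::new(false));
    let (done_tx, done_rx) = mpsc::channel();
    let handles: Vec<_> = workers
        .into_iter()
        .map(|(name, work)| {
            let stop = Arc::clone(&stop);
            let done_tx = done_tx.clone();
            thread::spawn(move || {
                work(stop);
                let _ = done_tx.send(name); // report which worker exited
            })
        })
        .collect();
    drop(done_tx); // only the workers' clones remain
    let first = done_rx.recv().expect("at least one worker is required");
    stop.store(true, Ordering::SeqCst); // tell the remaining workers to wind down
    for handle in handles {
        let _ = handle.join();
    }
    first
}

/// Example worker: idles until the session is told to stop.
pub fn wait_for_stop(stop: Arc<AtomicBool>) {
    while !stop.load(Ordering::SeqCst) {
        thread::sleep(Duration::from_millis(1));
    }
}
```

With tokio, `abort` cancels a task the next time it yields at an `.await`, so the real tasks need no flag; the cooperative version trades that convenience for code that runs without an async runtime.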
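The three tasks send_task, recv_task and ping_task work as follows:
- ping_task: sends a ping to the client every HEART_BEAT_INTERVAL seconds (30 s). If a ping still cannot be sent after three attempts spaced 5 seconds apart, the task exits and the client goes offline. It only detects failed sends; it does not check whether a pong comes back.
- recv_task: parses the messages the user sends. If parsing fails, it replies with an error receipt; otherwise it passes the message on to handle_message.
- send_task: forwards the messages it receives from the mpsc channel to the user's client.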
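The retry policy of ping_task and the forwarding loop of send_task can be separated from the networking code. Below is a blocking, std-only sketch of both: the function names are hypothetical, `thread::sleep` stands in for `tokio::time::sleep`, and closures stand in for the WebSocket writes, which makes the termination rules easy to test.

```rust
use std::sync::mpsc::Receiver;
use std::thread;
use std::time::Duration;

/// ping_task's retry policy: try `send` up to `max_attempts` times, sleeping
/// `retry_wait` between failed attempts. Returns true on the first success,
/// false if every attempt failed (the caller then drops the connection).
pub fn send_with_retries<E, F>(mut send: F, max_attempts: u32, retry_wait: Duration) -> bool
where
    F: FnMut() -> Result<(), E>,
{
    for attempt in 1..=max_attempts {
        match send() {
            Ok(()) => return true,
            // Wait only if another attempt follows
            Err(_) if attempt < max_attempts => thread::sleep(retry_wait),
            Err(_) => {}
        }
    }
    false
}

/// send_task's forwarding loop: drain the per-device channel and write each
/// message with `write` (a stand-in for Client::send_text). Stops when every
/// Sender is gone or the first write fails; returns the number delivered.
pub fn forward_messages<E, W>(rx: Receiver<String>, mut write: W) -> usize
where
    W: FnMut(String) -> Result<(), E>,
{
    let mut delivered = 0;
    // recv() returns Err once all Senders have been dropped (device unregistered)
    while let Ok(msg) = rx.recv() {
        if write(msg).is_err() {
            break; // the socket is gone, so stop forwarding
        }
        delivered += 1;
    }
    delivered
}
```

Called with `3` and `Duration::from_secs(5)`, `send_with_retries` follows the same three-attempts-five-seconds-apart rule; the one difference is that the original loop also sleeps after the third failure before giving up, which this version skips.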
#[inline]
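async fn ping_task(client: Client, span: Span) {
    // Attach the span with `Instrument` (needs `use tracing::Instrument;`) instead of
    // `span.enter()`: the tracing docs warn that holding an `Entered` guard across
    // `.await` produces incorrect traces.
    async move {
        loop {
            let mut attempts = 0;
            let mut success = false;
            // Send a heartbeat; retry up to three times and disconnect if all three fail
            while attempts < 3 {
                // Bind the result first so the write lock is released before the retry sleep
                let result = client.sender.write().await.send(Message::Ping(Bytes::new())).await;
                match result {
                    Ok(()) => {
                        success = true;
                        break;
                    }
                    Err(e) => {
                        event!(tracing::Level::ERROR, "send ping error: {}", e);
                        attempts += 1;
                        tokio::time::sleep(Duration::from_secs(5)).await; // wait five seconds before retrying
                    }
                }
            }
            if !success {
                break; // all three attempts failed: leave the main loop
            }
            tokio::time::sleep(Duration::from_secs(HEART_BEAT_INTERVAL)).await;
        }
    }
    .instrument(span)
    .await
}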
#[inline]
async fn handle_send_task(client: Client, mut rc: Receiver<String>, span: Span) {
    // Instrument the future instead of holding `span.enter()` across `.await`
    async move {
        while let Some(msg) = rc.recv().await {
            event!(tracing::Level::DEBUG, "send message: {}", msg);
            if let Err(e) = client.send_text(msg).await {
                event!(tracing::Level::WARN, "Failed to send message to user: {}", e);
                break; // stop forwarding once a socket write fails
            }
        }
    }
    .instrument(span)
    .await
}
#[inline]
async fn handle_recv_task(
    connection_manager: ConnectionManager,
    client: Client,
    mut receiver: SplitStream<WebSocket>,
    span: Span,
) {
    // Instrument the future instead of holding `span.enter()` across `.await`
    async move {
        while let Some(recv_msg) = receiver.next().await {
            let recv_msg = match recv_msg {
                Ok(msg) => msg,
                Err(e) => {
                    event!(tracing::Level::ERROR, "Failed to receive message: {}", e);
                    break;
                }
            };
            // Text and binary frames carry the same JSON payload, so they share one path
            let parsed: serde_json::Result<DbMessage> = match recv_msg {
                Message::Text(text) => serde_json::from_str(text.as_str()),
                Message::Binary(binary) => serde_json::from_slice(&binary),
                Message::Ping(_) => {
                    event!(tracing::Level::DEBUG, "Received ping");
                    continue;
                }
                Message::Pong(_) => {
                    event!(tracing::Level::DEBUG, "Received pong");
                    continue;
                }
                Message::Close(frame) => {
                    if let Some(CloseFrame { code, reason }) = frame {
                        event!(tracing::Level::DEBUG, "Client disconnected: {:?}, {}", code, reason);
                    } else {
                        event!(tracing::Level::DEBUG, "Client disconnected with no reason");
                    }
                    break;
                }
            };
            let mut message = match parsed {
                Ok(msg) => msg,
                Err(e) => {
                    let error_msg = format!("Failed to deserialize message: {}", e);
                    event!(tracing::Level::ERROR, "{}", error_msg);
                    // Send the error receipt; stop if even that write fails
                    if client.send_text(error_msg).await.is_err() {
                        break;
                    }
                    continue;
                }
            };
            event!(tracing::Level::DEBUG, "Received message:\n{:#?}", message);
            if let Err(e) =
                MessageHandler::handle_message(connection_manager.clone(), &mut message).await
            {
                event!(tracing::Level::ERROR, "Failed to handle message: {}", e);
                if client.send_text(e.to_string()).await.is_err() {
                    break;
                }
            }
        }
    }
    .instrument(span)
    .await
}
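The receive loop boils down to one decision per frame. The sketch below models that decision with a hypothetical `Frame` enum standing in for axum's `Message`, which makes it easy to see why the text and binary branches can share one deserialization path. All names here are illustrative; the real loop works on `axum::extract::ws::Message` and deserializes into `DbMessage`.

```rust
/// Simplified stand-in for axum's WebSocket `Message` enum (hypothetical type).
#[derive(Debug)]
pub enum Frame {
    Text(String),
    Binary(Vec<u8>),
    Ping(Vec<u8>),
    Pong(Vec<u8>),
    Close(Option<(u16, String)>),
}

/// What the receive loop does with a single frame.
#[derive(Debug, PartialEq)]
pub enum Action {
    /// Deserialize these bytes as a JSON message and hand it to the handler.
    Handle(Vec<u8>),
    /// Control frame: only logged (axum documents that pings are answered automatically).
    Skip,
    /// The client closed the connection: leave the loop.
    Stop,
}

pub fn classify(frame: Frame) -> Action {
    match frame {
        // Text and binary frames carry the same JSON payload
        Frame::Text(text) => Action::Handle(text.into_bytes()),
        Frame::Binary(bytes) => Action::Handle(bytes),
        Frame::Ping(_) | Frame::Pong(_) => Action::Skip,
        Frame::Close(_) => Action::Stop,
    }
}

/// Feeds frames through `classify` the way the receive loop does and returns
/// the payloads that would reach deserialization; frames after Close are ignored.
pub fn payloads_until_close(frames: impl IntoIterator<Item = Frame>) -> Vec<Vec<u8>> {
    let mut payloads = Vec::new();
    for frame in frames {
        match classify(frame) {
            Action::Handle(bytes) => payloads.push(bytes),
            Action::Skip => {}
            Action::Stop => break,
        }
    }
    payloads
}
```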