使用 Rust 写一个简易的信息网关

简介

因为毕设是一个基于 Rust 开发的 IM 分布式服务,本人之前并没有太多相关的经验,所以决定还是实现一个简易的信息服务器来练手。

服务设计

一个简单信息服务的功能性需要满足如下两点:

  1. 网关,与客户端保持长连接,用来接收/发送信息。
  2. 信息逻辑处理嘛,根据信息类型处理发送相关的逻辑。

根据这两点,我们将进行技术选型。

首先,针对长连接的需求,我们选择使用WebSocket协议。该协议的交互方式相对简单且技术成熟,非常适合我们进行实践和探索。

在用户通过协议连接到我们的服务后,我们需要能够区分不同的连接以及用户之间的关系。因此,我们需要维护一个映射表,用于关联用户ID和连接。考虑到用户数量可能较多,我们需要确保检索效率,因此我们选择使用HashMap来实现。通过用户到连接的索引,我们能够快速查找用户对应的连接并进行消息转发。

当用户的信息到达处理部分时,我们只需对信息进行合法性校验,并分析信息的目标,从而确定需要执行的操作。为此,我们需要一个健全的信息结构体来支撑这一过程。具体的实现细节将在项目的数据库设计部分中详细说明,这里仅提供一个示例以供展示:

// 媒体类型
{
  "_id": "019522a2-33ae-7d83-8083-36655138f65d",
  "author_id": "019522a2-33ae-7d83-8083-36754d94c284",
  "target_id": "019522a2-33ae-7d83-8083-368c5daea584",
  "status": "Sending",
  "message_content": {
    "medias": [
      {
        "type": "Image",
        "image_url": "http://example.com/image.png",
        "preview_url": "http://example.com/image.png",
        "width": 800,
        "height": 600
      }
    ],
    "caption": "A sample image"
  }
}

服务实现

这里本人使用 rust axum web 框架来进行网关服务构建,根据 axum 官方的 websocket example, axum 首先根据路由接收 websocket 请求,然后通过 ws handler 中的 WebSocketUpgrade 将该请求升级为WebSocket 连接,这时的连接将通过协程的方式来处理请求。

async fn new(config: &GatewayConfig, connection_manager: ConnectionManager) -> Self {
    let app = Router::new()
        .route("/chat/{token}", any(Self::ws_handler))
        .with_state(connection_manager);

    let listener =
        match TcpListener::bind(format!("{}:{}", config.listener.host, config.listener.port))
            .await
        {
            Ok(listener) => listener,
            Err(e) => {
                event!(tracing::Level::ERROR, "Failed to bind to address: {}", e);
                exit(1);
            }
        };

    Self { socket: listener, router: app }
}

async fn run(self) {
    event!(
        tracing::Level::INFO,
        "Listening on {} with protocol",
        self.socket.local_addr().unwrap(),
    );
    axum::serve(self.socket, self.router).await.unwrap();
}

// 接收 websocket 请求并将其升级为 websocket 连接
async fn ws_handler(
    ws: WebSocketUpgrade,
    State(connection_manager): State<ConnectionManager>,
    Path(token): Path<String>,
) -> impl IntoResponse {
    ws.on_upgrade(move |ws| Self::handle_socket(ws, token, connection_manager.clone()))
}

async fn handle_socket(mut ws: WebSocket, token: String, connection_manager: ConnectionManager) -> Result<()> {
  // 处理具体的信息收发
}

完善后的handle_socket 大概长成这个样子

debug!("handle_socket start");
// 初始化 ws 连接
let (mut sender, receiver) = ws.split();

// 用户鉴权和注册用户设备的逻辑
let (claims, rc) = match connection_manager.register_user_device(token).await {
    Ok(rc) => rc,
    Err(e) => {
        sender.send(Message::Text(Utf8Bytes::from(e.to_string()))).await.unwrap();
        return;
    }
};

let client = Client {
    sender: Arc::new(RwLock::new(sender)),
    user_id: claims.id,
    device_id: claims.device_id,
    platform: claims.platform,
};

event!(tracing::Level::DEBUG, "register user device: {}", client.device_id);

// 服务端向客户端发送心跳的协程
let span = tracing::info_span!("socket_tasks");

let cloned_client = client.clone();
let cloned_span = span.clone();
let mut ping_task = tokio::spawn(async move {
    Self::ping_task(cloned_client, cloned_span).await;
});

// 转发用户的信息的协程
let cloned_client = client.clone();
let cloned_span = span.clone();
let mut send_task = tokio::spawn(async move {
    Self::handle_send_task(cloned_client, rc, cloned_span).await;
});

// 接收用户发送的信息的协程
let recv_cloned_state = connection_manager.clone();
let cloned_client = client.clone();
let cloned_span = span.clone();
let mut recv_task = tokio::spawn(async move {
    Self::handle_recv_task(recv_cloned_state, cloned_client, receiver, cloned_span).await;
});

// 协程控制器,当有任意协程异常就结束整个用户会话的协程
tokio::select! {
    _ =(&mut send_task) => {
        event!(tracing::Level::DEBUG, "Send task exit");
        ping_task.abort();
        recv_task.abort();
    }
    _ = (&mut ping_task) => {
        event!(tracing::Level::DEBUG, "Ping task exit");
        recv_task.abort();
        send_task.abort();
    }
    _ = (&mut recv_task) => {
        event!(tracing::Level::DEBUG, "Recv task exit");
        ping_task.abort();
        send_task.abort();
    }
}

debug!("Start unregister device");
// 注销用户设备的逻辑
if let Err(e) = connection_manager
    .unregister_user_device(client.user_id.clone(), client.device_id.clone())
{
    event!(tracing::Level::ERROR, "Failed to unregister user device: {}", e);
}
debug!("Unregister device success");

对于 send_task、recv_task、ping_task 这三个协程的作用和逻辑如下:

  • ping_task: 每30秒向客户端发起 ping 请求,如果没有反应则退出协程让客户下线。
  • recv_task: 对用户发出的消息进行解析,如果解析失败就返回失败回执,成功就继续handle_message的逻辑。
  • send_task: 对从 mspc 管道中收到的用户信息,发送到用户的终端。
#[inline]
async fn ping_task(client: Client, span: Span) {
    let _guard = span.enter();
    loop {
        let mut attempts = 0;
        let mut success = false;

        // 发送心跳,如果失败就重试三次,三次全失败就断开连接
        while attempts < 3 {
            if let Err(e) = client.sender.write().await.send(Message::Ping(Bytes::new())).await
            {
                event!(tracing::Level::ERROR, "send ping error: {}", e);
                attempts += 1;
                tokio::time::sleep(Duration::from_secs(5)).await; // 等待五秒后重试
            } else {
                success = true;
                break;
            }
        }

        if !success {
            break; // 三次全失败,退出主循环
        }

        tokio::time::sleep(Duration::from_secs(HEART_BEAT_INTERVAL)).await;
    }
}

#[inline]
async fn handle_send_task(client: Client, mut rc: Receiver<String>, span: Span) {
    let _guard = span.enter();
    while let Some(msg) = rc.recv().await {
        event!(tracing::Level::DEBUG, "send message: {}", msg);
        if let Err(e) = client.send_text(msg).await {
            event!(tracing::Level::WARN, "Failed to send message to user: {}", e);
            break; // 如果发送失败,退出循环
        }
    }
}

#[inline]
async fn handle_recv_task(
    connection_manager: ConnectionManager,
    client: Client,
    mut receiver: SplitStream<WebSocket>,
    span: Span,
) {
    let _guard: tracing::span::Entered<'_> = span.enter();
    while let Some(recv_msg) = receiver.next().await {
        let recv_msg = match recv_msg {
            Ok(msg) => msg,
            Err(_) => {
                event!(tracing::Level::ERROR, "Failed to receive message");
                break;
            }
        };

        let mut message: DbMessage = match recv_msg {
            Message::Text(text) => match serde_json::from_slice(&text.as_bytes()) {
                Ok(msg) => msg,
                Err(e) => {
                    let error_msg = format!("Failed to deserialize message: {}", e);
                    event!(tracing::Level::ERROR, "{}", error_msg);
                    client.send_text(error_msg).await.unwrap();
                    continue;
                }
            },
            Message::Binary(binary) => match serde_json::from_slice(&binary.to_vec()) {
                Ok(msg) => msg,
                Err(e) => {
                    let error_msg = format!("Failed to deserialize message: {}", e);
                    event!(tracing::Level::ERROR, "{}", error_msg);
                    client.send_text(error_msg).await.unwrap();
                    continue;
                }
            },
            Message::Ping(_) => {
                event!(tracing::Level::DEBUG, "Received ping");
                continue;
            }
            Message::Pong(_) => {
                event!(tracing::Level::DEBUG, "Received pong");
                continue;
            }
            Message::Close(frame) => {
                if let Some(CloseFrame { code, reason }) = frame {
                    event!(tracing::Level::DEBUG, "Client disconnected: {:?}, {}", code, reason)
                } else {
                    event!(tracing::Level::DEBUG, "Client disconnected with no reason");
                }
                break;
            }
        };

        event!(tracing::Level::DEBUG, "Received message:\n{:#?}", message);

        if let Err(e) =
            MessageHandler::handle_message(connection_manager.clone(), &mut message).await
        {
            event!(tracing::Level::ERROR, "Failed to handle message: {}", e);
            client.send_text(e.to_string()).await.unwrap();
        };
    }
}
LICENSED UNDER CC BY-NC-SA 4.0