从零构建中间件:Tower 核心设计的来龙去脉
在《手把手搞懂 Service 特质:Tower 核心设计的来龙去脉》那篇内容里,我们已经搞懂了 Service 的设计初衷,以及它为什么是现在这个样子。之前我们也写过几个简单的中间件,但当时走了不少捷径。这次咱们不偷懒,完完整整地复现一遍当前 Tower 框架里 “Timeout 中间件” 的实现过程。
要写一个靠谱的中间件,得在异步 Rust 的底层层面开发 —— 这个层面会比你平时常用的层面稍深一点。不过别担心,这篇指南会把复杂的概念和逻辑讲明白,等你看完,不仅能自己写中间件,说不定还能给 Tower 生态贡献代码呢!
开始上手
我们要做的这个中间件,就是 Tower 里的tower::timeout::Timeout
。它的核心作用很简单:给内部 Service 的 “响应任务”(也就是 Future)设个最大执行时间。如果内部 Service 在规定时间内没返回结果,就直接返回一个 “超时错误”。这样客户端就不用一直等,要么重试请求,要么告诉用户出问题了。
首先,我们明确第一步:定义一个 Timeout 结构体。这个结构体要存两样东西 —— 被它包装的 “内部 Service”,以及请求的超时时长。代码如下:
use std::time::Duration;// 定义Timeout结构体:inner存被包装的内部Service,timeout存请求的超时时长
struct Timeout<S> {inner: S,timeout: Duration,
}
之前在《手把手搞懂 Service 特质:Tower 核心设计的来龙去脉》里提过一个关键点:Service 必须实现Clone
特征。为啥?因为有时候需要把Service::call
方法里的 “可变引用(&mut self)”,变成 “能转移所有权的 self”,再放进后续的 Future 里。所以,我们得给 Timeout 结构体加两个派生宏:#[derive(Debug)]
(方便调试看日志)和#[derive(Clone)]
(满足所有权转移需求):
// 派生Debug(调试时能打印结构体信息)和Clone(支持所有权转移)特征
#[derive(Debug, Clone)]
struct Timeout<S> {inner: S,timeout: Duration,
}
接下来,给 Timeout 写个 “构造函数”—— 就是一个能创建 Timeout 实例的方法:
impl<S> Timeout<S> {// 构造函数:接收“内部Service”和“超时时长”,返回Timeout实例pub fn new(inner: S, timeout: Duration) -> Self {Timeout { inner, timeout }}
}
这里有个小细节:虽然我们知道S
最终要实现Service
特征,但按照 Rust 的 API 规范,暂时不给S
加约束 —— 等后面需要的时候再加也不迟。
现在进入关键环节:给 Timeout 实现Service
特征。咱们先搭个基础框架,这个框架啥也不做,就把所有请求 “转发” 给内部 Service。先把架子立起来,后面再加超时逻辑:
use tower::Service;
use std::task::{Context, Poll};// 给Timeout<S>实现Service特征,约束:S必须是能处理Request的Service
impl<S, Request> Service<Request> for Timeout<S>
whereS: Service<Request>,
{type Response = S::Response; // 响应类型和内部Service保持一致type Error = S::Error; // 错误类型和内部Service保持一致type Future = S::Future; // 异步任务类型(Future)和内部Service保持一致// 轮询“是否就绪”:判断当前能不能接收新请求fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {// 咱们的中间件不关心“背压”(比如请求太多处理不过来),只要内部Service就绪,咱们就就绪self.inner.poll_ready(cx)}// 处理请求:把收到的请求直接传给内部Servicefn call(&mut self, request: Request) -> Self::Future {self.inner.call(request)}
}
对新手来说,先写这种 “转发框架” 很有用 —— 能帮你理清Service
特征的结构,后面加逻辑时不容易乱。
核心逻辑:怎么加超时?
要实现超时,核心思路其实很简单:
- 调用内部 Service 的
call
方法,拿到它返回的 “响应任务(Future)”; - 同时创建一个 “超时任务(Future)”—— 比如用
tokio::time::sleep
,等指定时长后就完成; - 同时盯着这两个任务:哪个先完成,就先处理哪个。如果 “超时任务” 先完成,就返回超时错误。
先试试写第一步:创建两个任务(响应任务和超时任务):
use tokio::time::sleep;fn call(&mut self, request: Request) -> Self::Future {// 1. 调用内部Service,拿到“处理请求的响应任务”let response_future = self.inner.call(request);// 2. 创建“超时任务”:等self.timeout这么久后就完成// 注意:Duration类型支持“复制”,不用clone,直接传就行let sleep = tokio::time::sleep(self.timeout);// 这里后面要写“怎么同时处理两个任务”的逻辑
}
这里有个小问题:如果直接返回 “装箱的 Future”(比如Pin<Box<dyn Future<...>>>
),会用到堆内存(Box),有额外开销。要是中间件嵌套很多层(比如 10 个、20 个),每个请求都要分配一次堆内存,性能会受影响。所以咱们得想个办法:不⽤ Box,自己定义一个 Future 类型。
自定义响应任务:ResponseFuture
咱们自己写一个ResponseFuture
结构体,专门用来 “包装两个任务”:内部 Service 的响应任务,和超时用的 sleep 任务。这个逻辑和 “用 Timeout 包装 Service” 很像,只不过这次包装的是 “Future(异步任务)”:
use tokio::time::Sleep;// 自定义的响应任务结构体:包装两个任务
pub struct ResponseFuture<F> {response_future: F, // 内部Service的“响应任务”sleep: Sleep, // 超时用的“睡眠任务”
}
这里的泛型F
,就是内部 Service 返回的 Future 类型。接下来,咱们更新 Timeout 的Service
实现 —— 把返回的 Future 类型,改成这个自定义的ResponseFuture
:
impl<S, Request> Service<Request> for Timeout<S>
whereS: Service<Request>,
{type Response = S::Response;type Error = S::Error;// 把Future类型改成自定义的ResponseFuture(用内部Service的Future当泛型参数)type Future = ResponseFuture<S::Future>;fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {self.inner.poll_ready(cx)}fn call(&mut self, request: Request) -> Self::Future {// 1. 拿到内部Service的响应任务let response_future = self.inner.call(request);// 2. 创建超时睡眠任务let sleep = tokio::time::sleep(self.timeout);// 3. 把两个任务包装成自定义的ResponseFuture,返回出去ResponseFuture {response_future,sleep,}}
}
这里要特别注意一个点:Rust 的 Future 是 “惰性的”。啥意思?就是调用inner.call(request)
时,不会立刻执行请求处理,只会返回一个 Future 对象;只有后面调用poll
(轮询)时,这个任务才会真正开始干活。
给 ResponseFuture 实现 Future 特征
要让ResponseFuture
能像普通 Future 一样被 “轮询”,就得给它实现Future
特征。咱们先搭个架子:
use std::{pin::Pin, future::Future};// 给ResponseFuture<F>实现Future特征
// 约束:F必须是返回“Result<响应, 错误>”的Future
impl<F, Response, Error> Future for ResponseFuture<F>
whereF: Future<Output = Result<Response, Error>>,
{type Output = Result<Response, Error>; // 输出类型和内部Future一致// 轮询逻辑:核心是“同时盯两个任务”fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {// 后面要写具体的轮询逻辑}
}
咱们想要的轮询逻辑很明确:
- 先看看 “响应任务”(response_future)有没有结果:有结果就直接返回;
- 如果响应任务还没好,再看看 “超时任务”(sleep)有没有完成:完成了就返回超时错误;
- 要是两个都没好,就告诉调用者 “还在等(Poll::Pending)”。
先试试写第一步:轮询响应任务。但直接写会报错:
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {// 尝试轮询响应任务——但这里会报错!match self.response_future.poll(cx) {Poll::Ready(result) => return Poll::Ready(result),Poll::Pending => {}}todo!()
}
报错原因是:self
是Pin<&mut Self>
(固定的可变引用),直接访问self.response_future
拿到的不是 “固定引用”,而调用poll
必须要Pin<&mut F>
类型。这就涉及到 Rust 里的 “Pin(固定)” 概念 —— 简单说,Pin 是为了防止某些异步任务被 “移动”,导致内存安全问题。
不过不用怕,有个叫pin-project
的库能帮我们解决这个问题。它能自动生成 “固定投影” 代码 —— 所谓 “固定投影”,就是从 “对整个结构体的固定引用”,安全地拿到 “对结构体里某个字段的固定引用”。
用 pin-project 解决固定引用问题
先给ResponseFuture
加#[pin_project]
派生宏,再给需要 “固定引用” 的字段加#[pin]
属性:
use pin_project::pin_project;// 加#[pin_project]:自动生成“固定投影”的代码
#[pin_project]
pub struct ResponseFuture<F> {#[pin] // 给response_future加#[pin]:需要固定引用response_future: F,#[pin] // 给sleep加#[pin]:也需要固定引用sleep: Sleep,
}
然后,在poll
方法里用self.project()
拿到 “带固定引用的字段”:
impl<F, Response, Error> Future for ResponseFuture<F>
whereF: Future<Output = Result<Response, Error>>,
{type Output = Result<Response, Error>;fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {// 调用project():拿到每个字段的“固定引用”(如果字段加了#[pin])let this = self.project();// this.response_future 现在是 Pin<&mut F>,能调用poll了let response_future: Pin<&mut F> = this.response_future;// this.sleep 现在是 Pin<&mut Sleep>,也能调用poll了let sleep: Pin<&mut Sleep> = this.sleep;// 后面写轮询逻辑}
}
有了固定引用,咱们就能完整实现轮询逻辑了:
impl<F, Response, Error> Future for ResponseFuture<F>
whereF: Future<Output = Result<Response, Error>>,
{type Output = Result<Response, Error>;fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {let this = self.project();// 第一步:先查响应任务有没有结果match this.response_future.poll(cx) {Poll::Ready(result) => {// 内部Service已经返回结果了,直接把结果传出去return Poll::Ready(result);}Poll::Pending => {// 响应任务还没好,继续查超时}}// 第二步:查超时任务有没有完成(也就是超时了没)match this.sleep.poll(cx) {Poll::Ready(()) => {// 超时时间到了!但这里有个问题:返回什么错误?todo!()}Poll::Pending => {// 还没超时,继续等}}// 第三步:两个任务都没好,返回“还在等”Poll::Pending}
}
现在卡在最后一个问题上:超时的时候,该返回什么类型的错误?
解决错误类型问题
目前,我们说好了 “Timeout 的错误类型和内部 Service 一致”,但内部 Service 的错误类型是泛型Error
—— 我们根本不知道这个Error
是什么,也没法创建一个 “超时错误” 的Error
实例。
咱们有三种解决方案,咱们一个个分析,最后选最适合 Tower 的方案。
方案 1:用 “装箱的错误特征对象”
就是返回Box<dyn std::error::Error + Send + Sync>
—— 简单说,不管是什么错误,都装到一个 “通用错误盒子” 里。这样不管内部 Service 返回什么错误,都能转成这个盒子类型,超时错误也能装进去。
方案 2:用枚举包两种错误
定义一个枚举,里面包含 “超时错误” 和 “内部服务错误” 两个选项:
enum TimeoutError<Error> {Timeout, // 超时错误Service(Error) // 内部服务错误
}
但这个方案有个大问题:如果中间件嵌套多层(比如 A 包装 B,B 包装 C),错误类型就会变成AError<BError<CError<MyError>>>
,写匹配逻辑时会非常麻烦,而且改中间件顺序会导致错误类型变样。
方案 3:要求内部错误能转成超时错误
定义一个TimeoutError
结构体,然后要求内部 Service 的Error
能从TimeoutError
转过来(比如TimeoutError: Into<Error>
)。但这样用户用自定义错误时,得手动写转换逻辑,很麻烦。
综合来看,方案 1 最适合 Tower—— 虽然需要一点堆内存(装箱),但胜在简单、灵活,嵌套多层也不怕。
实现方案 1:定义超时错误和通用错误类型
第一步:定义TimeoutError
结构体,实现 Rust 的标准错误特征(std::error::Error
):
use std::fmt;// 超时错误结构体:加个私有字段(()),防止外部随便创建
#[derive(Debug, Default)]
pub struct TimeoutError(());// 实现Display:错误信息的文字描述
impl fmt::Display for TimeoutError {fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {f.pad("request timed out") // 错误信息:“请求超时”}
}// 实现Error:标记这是一个标准错误类型
impl std::error::Error for TimeoutError {}
第二步:给 “通用错误盒子” 起个别名,省得每次都写一大串:
// 通用错误类型别名(Tower里已经有这个类型,叫tower::BoxError)
pub type BoxError = Box<dyn std::error::Error + Send + Sync>;
第三步:更新ResponseFuture
的Future
实现 —— 把错误类型改成BoxError
,同时要求内部 Service 的错误能转成BoxError
:
impl<F, Response, Error> Future for ResponseFuture<F>
whereF: Future<Output = Result<Response, Error>>,// 约束:内部Service的错误能转成BoxErrorError: Into<BoxError>,
{// 输出类型的错误改成BoxErrortype Output = Result<Response, BoxError>;fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {let this = self.project();// 轮询响应任务:把内部错误转成BoxErrormatch this.response_future.poll(cx) {Poll::Ready(result) => {// 用map_err把内部错误转成BoxErrorlet result = result.map_err(Into::into);return Poll::Ready(result);}Poll::Pending => {}}// 超时了:创建TimeoutError,装箱后返回match this.sleep.poll(cx) {Poll::Ready(()) => {let error = Box::new(TimeoutError(())); // 把超时错误装箱return Poll::Ready(Err(error));}Poll::Pending => {}}Poll::Pending}
}
最后,更新 Timeout 的Service
实现 —— 错误类型也要改成BoxError
,并且加上同样的约束:
impl<S, Request> Service<Request> for Timeout<S>
whereS: Service<Request>,// 和ResponseFuture保持一致:内部错误能转成BoxErrorS::Error: Into<BoxError>,
{type Response = S::Response;type Error = BoxError; // 错误类型改成BoxErrortype Future = ResponseFuture<S::Future>;fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {// 轮询就绪时,也要把内部错误转成BoxErrorself.inner.poll_ready(cx).map_err(Into::into)}fn call(&mut self, request: Request) -> Self::Future {let response_future = self.inner.call(request);let sleep = tokio::time::sleep(self.timeout);ResponseFuture {response_future,sleep,}}
}
总结
到这里,咱们就完整复现了 Tower 里 Timeout 中间件的实现!最终代码如下:
use pin_project::pin_project;
use std::time::Duration;
use std::{fmt,future::Future,pin::Pin,task::{Context, Poll},
};
use tokio::time::Sleep;
use tower::Service;// 超时中间件结构体:包装内部Service和超时时长,支持调试和克隆
#[derive(Debug, Clone)]
struct Timeout<S> {inner: S,timeout: Duration,
}impl<S> Timeout<S> {// 构造函数:接收内部Service和超时时长,返回Timeout实例fn new(inner: S, timeout: Duration) -> Self {Timeout { inner, timeout }}
}// 给Timeout<S>实现Service特征
impl<S, Request> Service<Request> for Timeout<S>
whereS: Service<Request>,S::Error: Into<BoxError>, // 约束:内部错误能转成BoxError
{type Response = S::Response;type Error = BoxError;type Future = ResponseFuture<S::Future>;// 轮询就绪状态:转发内部Service的状态,同时把错误转成BoxErrorfn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {self.inner.poll_ready(cx).map_err(Into::into)}// 处理请求:创建两个任务(响应任务+超时任务),包装成ResponseFuture返回fn call(&mut self, request: Request) -> Self::Future {let response_future = self.inner.call(request);let sleep = tokio::time::sleep(self.timeout);ResponseFuture {response_future,sleep,}}
}// 自定义响应任务:包装响应任务和超时任务,支持固定投影
#[pin_project]
struct ResponseFuture<F> {#[pin]response_future: F,#[pin]sleep: Sleep,
}// 给ResponseFuture实现Future特征
impl<F, Response, Error> Future for ResponseFuture<F>
whereF: Future<Output = Result<Response, Error>>,Error: Into<BoxError>, // 约束:内部错误能转成BoxError
{type Output = Result<Response, BoxError>;// 轮询逻辑:先查响应,再查超时,都没好就返回Pendingfn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {let this = self.project();// 先查响应任务match this.response_future.poll(cx) {Poll::Ready(result) => {let result = result.map_err(Into::into);return Poll::Ready(result);}Poll::Pending => {}}// 再查超时任务match this.sleep.poll(cx) {Poll::Ready(()) => {let error = Box::new(TimeoutError(()));return Poll::Ready(Err(error));}Poll::Pending => {}}Poll::Pending}
}// 超时错误结构体:私有字段防止外部构造,实现标准错误特征
#[derive(Debug, Default)]
struct TimeoutError(());impl fmt::Display for TimeoutError {fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {f.pad("request timed out")}
}impl std::error::Error for TimeoutError {}// 通用错误类型别名:简化“装箱错误特征对象”的写法
type BoxError = Box<dyn std::error::Error + Send + Sync>;
其实大多数 Tower 中间件,都是用这种 “包装 + 转发” 的思路实现的:
- 定义一个结构体,包装内部 Service;
- 给这个结构体实现
Service
特征,核心逻辑在call
里; - 自定义一个 Future,包装内部 Service 的 Future,实现
Future
特征处理异步逻辑。
除了 Timeout,还有几个常用的中间件也用了这个模式:
ConcurrencyLimit
:限制同时处理的最大请求数;LoadShed
:当内部 Service 忙不过来时,直接拒绝新请求(削峰);Steer
:把请求路由到不同的 Service(类似负载均衡)。
现在你已经掌握了写中间件的核心方法!如果想多练手,可以试试这几个小任务:
- 不用
tokio::time::sleep
,改用tokio::time::timeout
实现超时逻辑; - 写一个 “适配器”:用闭包修改请求、响应或错误(类似
Result::map
); - 实现
ConcurrencyLimit
(提示:需要用PollSemaphore
控制并发数)。