
axum

Rust Rust-Crate
rust crate - This article is part of a series.
§ 10: this article

axum is an asynchronous HTTP server library built on hyper. A typical workflow:

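  1. Create a Router: use Router::route() to define a PATH and its associated Service. The Service is usually a MethodRouter, such as the MethodRouter returned by functions like get/post/patch(), whose arguments are Handlers.
  2. Handler: usually implemented as an async function or closure:
    1. Its inputs are Extractors, which pull the relevant information out of the request;
    2. It returns a value implementing the IntoResponse trait (a Result is not required). axum implements this trait for Rust primitive types and many other types;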

All three levels (Router, MethodRouter and Handler) can:

  1. add middleware via layer(), so that some processing happens before the Handler runs;
  2. attach a state object via with_state();
use axum::{Router, routing::get};

#[tokio::main]
async fn main() {
    let app = Router::new()
        .route("/", get(root))
        // get() handles GET requests and returns a MethodRouter (which implements Service);
        // calls can be chained
        .route("/foo", get(get_foo).post(post_foo))
        .route("/foo/bar", get(foo_bar));

    let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
    axum::serve(listener, app).await.unwrap();
}

// Async functions that implement the Handler trait
async fn root() {}
async fn get_foo() {}
async fn post_foo() {}
async fn foo_bar() {}

1 axum::serve

axum::serve() is the entry point of axum. Its second parameter, make_service, is a Service factory, i.e. a Service of Services:

  1. Outer level: Service<IncomingStream<'a>, Error = Infallible, Response = S>
  2. Inner level: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static

Here Request and Response are http::request::Request and http::response::Response. Their body type is the struct axum::body::Body, which implements the http_body::Body<Data = bytes::Bytes> trait.

Note: M is not bounded by the tower::MakeService trait directly, but the effect and semantics are the same.

pub fn serve<M, S>(tcp_listener: TcpListener, make_service: M) -> Serve<M, S>
where
    // Outer Service:
    // its Request is IncomingStream and its Response is another Service
    M: for<'a> Service<IncomingStream<'a>, Error = Infallible, Response = S>,
    // Inner Service:
    //   Request is the type alias axum::extract::Request<T = Body> = http::request::Request<T>
    //   Response is the type alias axum::response::Response<T = Body> = http::response::Response<T>
    // The Body in Request<T = Body> and Response<T = Body> is the struct axum::body::Body, which
    // implements the http_body::Body<Data = bytes::Bytes> trait.
    S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
    S::Future: Send,

Router and MethodRouter implement this Service factory directly, and a Handler can be turned into one (via into_make_service()). Any of them can therefore be passed to serve():

// Router implements Service<IncomingStream<'_>>; its Response is again a Router
impl Service<IncomingStream<'_>> for Router<()>
  type Response = Router
  type Error = Infallible
  type Future = Ready<Result<<Router as Service<IncomingStream<'_>>>::Response, <Router as Service<IncomingStream<'_>>>::Error>>

// Router implements Service<Request<B>>, where B in Request<B> implements
// http_body::Body<Data = bytes::Bytes>; the response type is the struct
// http::response::Response<axum::body::Body>
impl<B> Service<Request<B>> for Router<()>
where
    B: HttpBody<Data = Bytes> + Send + 'static,
    B::Error: Into<BoxError>
  type Response = Response<Body> // Body is axum::body::Body
  type Error = Infallible
  type Future = RouteFuture<Infallible>

// Examples:
// Serving a Router:
use axum::{Router, routing::get};
let router = Router::new().route("/", get(|| async { "Hello, World!" }));
let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
axum::serve(listener, router).await.unwrap();

// Serving a MethodRouter:
use axum::routing::get;
let router = get(|| async { "Hello, World!" });
let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
axum::serve(listener, router).await.unwrap();

// Serving a Handler: call the handler's into_make_service() method (from HandlerWithoutStateExt)
use axum::handler::HandlerWithoutStateExt;
async fn handler() -> &'static str { "Hello, World!"}
let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
axum::serve(listener, handler.into_make_service()).await.unwrap();

In short: Router implements Service<IncomingStream<'_>> with Router as its Response, and that Router in turn implements Service<Request<B>>. So Router satisfies the M bound of the make_service parameter of axum::serve().
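The sketch below makes the two levels concrete using only std. It is minimal and synchronous. The `Service` trait, `IncomingStream`, `Request`, `Response`, `Router`, `serve_one` and `demo` are simplified, hypothetical stand-ins written for illustration (no futures, no `poll_ready`). They are not the real tower or axum types. They only reproduce the bound shape of `axum::serve()` and a router that acts as its own factory.

```rust
use std::convert::Infallible;

// Hypothetical, synchronous stand-in for tower::Service (no poll_ready, no futures).
trait Service<Req> {
    type Response;
    type Error;
    fn call(&mut self, req: Req) -> Result<Self::Response, Self::Error>;
}

// Hypothetical stand-ins for axum's IncomingStream, Request and Response.
struct IncomingStream;
struct Request {
    path: String,
}
#[derive(Debug, PartialEq)]
struct Response {
    status: u16,
    body: String,
}

// Toy router: exact path -> fixed body.
#[derive(Clone)]
struct Router {
    routes: Vec<(String, String)>,
}

impl Router {
    fn new() -> Self {
        Router { routes: Vec::new() }
    }
    fn route(mut self, path: &str, body: &str) -> Self {
        self.routes.push((path.to_string(), body.to_string()));
        self
    }
}

// Outer level: for each new connection, the factory hands out a clone of itself.
impl<'a> Service<&'a IncomingStream> for Router {
    type Response = Router;
    type Error = Infallible;
    fn call(&mut self, _conn: &'a IncomingStream) -> Result<Router, Infallible> {
        Ok(self.clone())
    }
}

// Inner level: the per-connection Router answers individual requests.
impl Service<Request> for Router {
    type Response = Response;
    type Error = Infallible;
    fn call(&mut self, req: Request) -> Result<Response, Infallible> {
        let found = self.routes.iter().find(|(p, _)| *p == req.path);
        Ok(match found {
            Some((_, body)) => Response { status: 200, body: body.clone() },
            None => Response { status: 404, body: String::new() },
        })
    }
}

// Same bound shape as axum::serve: M makes an S per connection, S handles the requests.
fn serve_one<M, S>(make_service: &mut M, conn: &IncomingStream, reqs: Vec<Request>) -> Vec<Response>
where
    M: for<'a> Service<&'a IncomingStream, Error = Infallible, Response = S>,
    S: Service<Request, Response = Response, Error = Infallible> + Clone,
{
    let mut svc = make_service.call(conn).unwrap(); // Error is Infallible, so this cannot panic
    reqs.into_iter().map(|r| svc.call(r).unwrap()).collect()
}

// Usage: one connection, two requests.
fn demo() -> Vec<Response> {
    let mut app = Router::new().route("/", "Hello, World!");
    let requests = vec![
        Request { path: "/".to_string() },
        Request { path: "/missing".to_string() },
    ];
    serve_one(&mut app, &IncomingStream, requests)
}
```

In this model `serve_one` type-checks with `M = S = Router`. The outer `call` hands out a clone of the router per connection, and the clone answers each request, returning 404 for an unknown path. This mirrors why a real `Router` can be passed straight to `axum::serve`.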

2 Router

Router<S = ()> defines PATHs and their handling logic.

  • S is the state type; the state can be any custom type that implements Clone.
// S is the Router's State type and defaults to ()
pub struct Router<S = ()> { /* private fields */ }

There are several kinds of handling logic:

  1. route(): uses a MethodRouter<S>, usually created with the per-HTTP-method functions of the axum::routing::method_routing module, for example:

    • axum::routing::method_routing::get(handler: H) with H: Handler<T, S>; the Handler is usually an async function or closure;
    • axum::routing::method_routing::get_service(svc: T) with T: Service<Request>; svc is usually built from a closure with tower::service_fn();
  2. route_service(): uses a tower::Service. It is usually built from a closure with tower::service_fn(), or it can be a ready-made Service from the tower_http crate such as tower_http::services::ServeFile

  3. layer(): uses a Layer<Route>. It can be created from a closure with axum::middleware::from_fn/from_fn_with_state(), or it can be a ready-made Layer from the tower_http crate such as tower_http::trace::TraceLayer

In short, every kind of handling logic (Handler, Service or Layer) can be implemented with a closure.
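The claim that Handler, Service and Layer can all be closures can be sketched without dependencies. This model is deliberately reduced and is an assumption, not tower's real API. A "service" is a boxed closure from a request string to a response string. The hypothetical `layer_from_fn` plays roughly the role of `axum::middleware::from_fn`: it turns a middleware closure into a layer that wraps an inner service, which the middleware calls as `next`.

```rust
// Hypothetical simplification: a "service" is a boxed closure from request to response.
type BoxService = Box<dyn FnMut(String) -> String>;

// Hypothetical counterpart of axum::middleware::from_fn: turns a middleware closure into a
// "layer", i.e. a function that wraps an inner service and returns a new service.
fn layer_from_fn<F>(mut middleware: F) -> impl FnOnce(BoxService) -> BoxService
where
    F: FnMut(String, &mut BoxService) -> String + 'static,
{
    move |mut inner: BoxService| -> BoxService {
        // The new service hands every request to the middleware, together with the inner service.
        Box::new(move |req: String| middleware(req, &mut inner))
    }
}

// Usage: wrap a closure "handler" with a closure "middleware".
fn demo() -> String {
    let handler: BoxService = Box::new(|req: String| format!("handled {req}"));
    let trace = layer_from_fn(|req: String, next: &mut BoxService| {
        let resp = next(req); // call the wrapped (inner) service
        format!("[trace] {resp}")
    });
    let mut app = trace(handler);
    app("GET /".to_string())
}
```

Calling `demo()` yields `"[trace] handled GET /"`. The middleware receives the request first, delegates to the inner handler through `next`, and then post-processes the response. That is the same wrapping order a real `Layer` gives around a route.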

// Methods implemented by Router<S = ()>:

// Add a MethodRouter as the handling logic for path
pub fn route(self, path: &str, method_router: MethodRouter<S>) -> Self

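// Add a Service as the handling logic for path. It can be built from a closure with
// tower::service_fn(), or it can be one of the tower_http services.
//
// Note: the Service must have Error = Infallible. If the Result returned by the
// tower::service_fn() closure has a different Err type, the Service does not satisfy this
// bound and must be converted with axum::error_handling::HandleError.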
pub fn route_service<T>(self, path: &str, service: T) -> Self
where
    T: Service<Request, Error = Infallible> + Clone + Send + 'static,
    T::Response: IntoResponse,
    T::Future: Send + 'static

// Nest a Router (or a Service) under path
pub fn nest(self, path: &str, router: Router<S>) -> Self
pub fn nest_service<T>(self, path: &str, service: T) -> Self
where
    T: Service<Request, Error = Infallible> + Clone + Send + 'static,
    T::Response: IntoResponse,
    T::Future: Send + 'static

// Merge another Router into this one
pub fn merge<R>(self, other: R) -> Self where R: Into<Router<S>>

// Add a layer middleware to all routes of the Router.
// layer() takes ownership and returns a new Router<S>, so calls can be chained
pub fn layer<L>(self, layer: L) -> Router<S>
where
    L: Layer<Route> + Clone + Send + 'static,
    L::Service: Service<Request> + Clone + Send + 'static,
    <L::Service as Service<Request>>::Response: IntoResponse + 'static,
    <L::Service as Service<Request>>::Error: Into<Infallible> + 'static,
    <L::Service as Service<Request>>::Future: Send + 'static

// Add a layer that runs only for requests that match a route
pub fn route_layer<L>(self, layer: L) -> Self
where
    L: Layer<Route> + Clone + Send + 'static,
    L::Service: Service<Request> + Clone + Send + 'static,
    <L::Service as Service<Request>>::Response: IntoResponse + 'static,
    <L::Service as Service<Request>>::Error: Into<Infallible> + 'static,
    <L::Service as Service<Request>>::Future: Send + 'static

// Handler (or Service) to fall back to when no route matches
pub fn fallback<H, T>(self, handler: H) -> Self where H: Handler<T, S>, T: 'static
pub fn fallback_service<T>(self, service: T) -> Self
where
    T: Service<Request, Error = Infallible> + Clone + Send + 'static,
    T::Response: IntoResponse,
    T::Future: Send + 'static

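// Provide the Router's state; handlers obtain it later through the State extractor.
// The state can be any custom type that implements Clone.
// It supplies global data to every request of the Router; it is not meant for
// request-specific data (use Extension for that).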
pub fn with_state<S2>(self, state: S) -> Router<S2>

// Calling tower::ServiceExt methods on the Router directly fails (type inference cannot determine B); call this method first.
pub fn as_service<B>(&mut self) -> RouterAsService<'_, B, S>

// Convert the Router into a MakeService, i.e. a Service that creates another Service; its main use
// is as the argument of axum::serve.
pub fn into_make_service(self) -> IntoMakeService<Self>

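Router implements Service<Request<B>> (that is, Service<http::request::Request<B>> with B: http_body::Body<Data = bytes::Bytes>). Even so, calling tower::ServiceExt methods on it directly fails to compile, because B cannot be inferred. The fix is Router's pub fn as_service<B>(&mut self) -> RouterAsService<'_, B, S> method: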

use axum::{
    Router,
    routing::get,
    http::Request,
    body::Body,
};
use tower::{Service, ServiceExt};

let mut router = Router::new().route("/", get(|| async {}));
let request = Request::new(Body::empty());

// let response = router.ready().await?.call(request).await?;
//                       ^^^^^ cannot infer type for type parameter `B`

// OK
let response = router.as_service().ready().await?.call(request).await?;

Router provides methods such as into_make_service()/into_make_service_with_connect_info() that create a type implementing the MakeService trait:

// Returns a Service factory type
pub fn into_make_service(self) -> IntoMakeService<Self> // note: the type parameter Self is Router

// With S = Router (because this is Router::into_make_service()), Response is also Router,
// and Router implements Service<Request>. So IntoMakeService<Router> implements Service and,
// through tower's blanket impl, the MakeService trait as well.
impl<S, T> Service<T> for IntoMakeService<S> where S: Clone
  type Response = S
  type Error = Infallible
  type Future = IntoMakeServiceFuture<S>


// Example
use axum::{
    routing::get,
    Router,
};
let app = Router::new().route("/", get(|| async { "Hi!" }));
let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
// Passing `app` directly works as well: axum::serve(listener, app).await.unwrap();
axum::serve(listener, app.into_make_service()).await.unwrap();

Examples:

use axum::{
    Router,
    body::Body,
    routing::{get, delete, any_service, get_service, MethodFilter, on_service},
    extract::{Request, Path, State},
    http::{StatusCode, Uri},
    error_handling::HandleErrorLayer,
};

use tower::{Service, ServiceExt};
use tower::service_fn;

use tower_http::services::ServeFile;
use tower_http::trace::TraceLayer;
use tower_http::validate_request::ValidateRequestHeaderLayer;

use http::Response;
use std::{convert::Infallible, io};

let app = Router::new()
    .route("/", get(root))
    // get() returns a MethodRouter, so calls can be chained
    .route("/users", get(list_users).post(create_user))
    .route("/users/:id", get(show_user))
    .route("/api/:version/users/:id/action", delete(do_users_action))
    .route("/assets/*path", get(serve_asset));

async fn root() {}
async fn list_users() {}
async fn create_user() {}
async fn show_user(Path(id): Path<u64>) {}
// Several Path parameters are extracted as a tuple
async fn do_users_action(Path((version, id)): Path<(String, u64)>) {}
async fn serve_asset(Path(path): Path<String>) {}

let app = Router::new()
    .route("/", any_service(
        // Build a Service from a closure. The Err type of the returned Result must be
        // Infallible to satisfy the bounds; otherwise use axum::error_handling::HandleError
        // to convert the error into Infallible.
        service_fn(|_: Request| async {
            // First create an axum::body::Body, then wrap it in an http::Response.
            let res = Response::new(Body::from("Hi from `GET /`"));
            Ok::<_, Infallible>(res)
        })))
    .route_service("/foo", service_fn(|req: Request| async move {
        // axum::body::Body implements the http_body::Body trait
        let body = Body::from(format!("Hi from `{} /foo`", req.method()));
        let res = Response::new(body);
        Ok::<_, Infallible>(res)
    }))
    .route_service("/static/Cargo.toml", ServeFile::new("Cargo.toml"));

// on_service() is the generic routing function; the HTTP method must be given explicitly
let service = tower::service_fn(|_request: Request| async {
    Ok::<_, Infallible>(Response::new(Body::empty()))
});
let app = Router::new().route("/", on_service(MethodFilter::DELETE, service));

let user_routes = Router::new().route("/users", get(users_list)).route("/users/:id", get(users_show));
let team_routes = Router::new().route("/teams", get(teams_list));
let app = Router::new().merge(user_routes).merge(team_routes);
async fn users_list() {}
async fn users_show() {}
async fn teams_list() {}

let app = Router::new()
    .route("/foo", get(|| async {}))
    .route("/bar", get(|| async {}))
    .layer(TraceLayer::new_for_http());

let app = Router::new()
    .route("/foo", get(|| async {}))
    .route_layer(ValidateRequestHeaderLayer::bearer("password"));
// `GET /foo` with a valid token will receive `200 OK`
// `GET /foo` with an invalid token will receive `401 Unauthorized`
// `GET /not-found` with an invalid token will receive `404 Not Found`

let app = Router::new()
    .route("/foo", get(|| async { /* ... */ }))
    .fallback(fallback);
async fn fallback(uri: Uri) -> (StatusCode, String) {
    (StatusCode::NOT_FOUND, format!("No route for {uri}"))
}

#[derive(Clone)]
struct AppState {}
let routes = Router::new()
    // Use axum::extract::State to access the global state inside a request handler
    .route("/", get(|State(state): State<AppState>| async {
        // use state
    })).with_state(AppState {});


3 ConnectInfo/IncomingStream

The main use of Router's into_make_service_with_connect_info<C>() method is to supply a value for a Handler's ConnectInfo extractor (through the HTTP request extensions). The handler can then read client connection information such as the socket address:

pub fn into_make_service_with_connect_info<C>(self) -> IntoMakeServiceWithConnectInfo<Self, C>

// Example
use axum::{
    extract::ConnectInfo,
    routing::get,
    Router,
};
use std::net::SocketAddr; // supported out of the box

let app = Router::new().route("/", get(handler));
async fn handler(ConnectInfo(addr): ConnectInfo<SocketAddr>) -> String {
    format!("Hello {addr}")
}

let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
axum::serve(listener, app.into_make_service_with_connect_info::<SocketAddr>()).await.unwrap();

A custom ConnectInfo type can be defined by implementing the Connected<IncomingStream<'_>> trait:

  • axum implements Connected<IncomingStream<'_>> for std::net::SocketAddr, so SocketAddr can be used directly;

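The struct axum::serve::IncomingStream carries the connection's local and remote addresses (local_addr()/remote_addr()). It is the request type that MethodRouter, HandlerService and similar types use for their outer Service implementation, e.g. impl Service<IncomingStream<'_>> for MethodRouter<()>.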

use axum::{
    extract::connect_info::{ConnectInfo, Connected},
    routing::get,
    serve::IncomingStream,
    Router,
};

let app = Router::new().route("/", get(handler));

async fn handler(ConnectInfo(my_connect_info): ConnectInfo<MyConnectInfo>) -> String {
    format!("Hello {my_connect_info:?}")
}

#[derive(Clone, Debug)]
struct MyConnectInfo {
    // ...
}

impl Connected<IncomingStream<'_>> for MyConnectInfo {
    fn connect_info(target: IncomingStream<'_>) -> Self {
        MyConnectInfo {
            // ...
        }
    }
}

let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
axum::serve(listener, app.into_make_service_with_connect_info::<MyConnectInfo>()).await.unwrap();
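The mechanism described above can be modelled with std only. Connection info is computed once per connection, stored in the request extensions, and then read back by the ConnectInfo extractor. Every type and function here (`Extensions`, `IncomingStream`, `Request`, `Connected`, `ConnectInfo`, `tag_request`, `extract_connect_info`) is a simplified, hypothetical stand-in written for this sketch, not axum's actual definition.

```rust
use std::any::{Any, TypeId};
use std::collections::HashMap;

// Hypothetical stand-in for http::Extensions: a map keyed by type.
#[derive(Default)]
struct Extensions {
    map: HashMap<TypeId, Box<dyn Any>>,
}

impl Extensions {
    fn insert<T: 'static>(&mut self, value: T) {
        self.map.insert(TypeId::of::<T>(), Box::new(value));
    }
    fn get<T: 'static>(&self) -> Option<&T> {
        self.map.get(&TypeId::of::<T>()).and_then(|b| b.downcast_ref::<T>())
    }
}

// Hypothetical stand-ins for IncomingStream and Request.
struct IncomingStream {
    remote_addr: String,
}
struct Request {
    extensions: Extensions,
}

// Plays the role of axum's Connected trait: build Self from the incoming connection.
trait Connected<T> {
    fn connect_info(target: T) -> Self;
}

#[derive(Clone, Debug, PartialEq)]
struct MyConnectInfo {
    peer: String,
}

impl Connected<&IncomingStream> for MyConnectInfo {
    fn connect_info(target: &IncomingStream) -> Self {
        MyConnectInfo { peer: target.remote_addr.clone() }
    }
}

// Wrapper type the handler asks for, like axum's ConnectInfo<T>.
#[derive(Clone, Debug, PartialEq)]
struct ConnectInfo<T>(T);

// Per-connection step: compute C from the connection and tag the request with it.
fn tag_request<C>(conn: &IncomingStream, mut req: Request) -> Request
where
    C: for<'a> Connected<&'a IncomingStream> + Clone + 'static,
{
    req.extensions.insert(ConnectInfo(C::connect_info(conn)));
    req
}

// Extractor step: read ConnectInfo<C> back out of the request extensions.
fn extract_connect_info<C: Clone + 'static>(req: &Request) -> Option<ConnectInfo<C>> {
    req.extensions.get::<ConnectInfo<C>>().cloned()
}
```

With `remote_addr` set to `"127.0.0.1:52000"`, `extract_connect_info::<MyConnectInfo>` returns the tagged value. Asking for a type that was never inserted yields `None`. In axum, that missing-extension case surfaces as an extractor rejection instead.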

4 body

A Handler must return a value implementing the IntoResponse trait. http::Response<B> is generic over its body type B: http_body::Body, so a concrete body type has to be chosen.

axum defines its own struct axum::body::Body, which implements the http_body::Body<Data = bytes::Bytes> trait. axum uses this struct as the body type of Handler responses.

  • This Body type can be used both as an extractor and as a Handler's response.
// Struct axum::body::Body
pub struct Body(/* private fields */);

// Create an axum::body::Body from any value implementing http_body::Body
pub fn new<B>(body: B) -> Body
where
    B: Body<Data = Bytes> + Send + 'static, // Body here is http_body::Body
    <B as Body>::Error: Into<Box<dyn Error + Sync + Send>>

// Other convenient ways to create an axum::body::Body
impl From<&'static [u8]> for Body
impl From<&'static str> for Body
impl From<()> for Body
impl From<Bytes> for Body
impl From<Cow<'static, [u8]>> for Body
impl From<Cow<'static, str>> for Body
impl From<String> for Body
impl From<Vec<u8>> for Body

// axum::body::Body implements FromRequest, so it can be used as an extractor
impl<S> FromRequest<S> for Body where S: Send + Sync
  type Rejection = Infallible
  fn from_request<'life0, 'async_trait>(
      req: Request<Body>,
      _: &'life0 S
  ) -> Pin<Box<dyn Future<Output = Result<Body, <Body as FromRequest<S>>::Rejection>> + Send + 'async_trait>>
  where
      'life0: 'async_trait,
      Body: 'async_trait

// axum::body::Body implements IntoResponse, so it can be returned from a Handler
impl IntoResponse for Body
    fn into_response(self) -> Response<Body>


// Example
let app = Router::new()
    .route("/", any_service(
        // Build a Service from a closure. The Err type of the returned Result must be
        // Infallible to satisfy the bounds; otherwise use axum::error_handling::HandleError
        // to convert the error into Infallible.
        service_fn(|_: Request| async {
            // First create an axum::body::Body, then wrap it in an http::Response
            let res = Response::new(Body::from("Hi from `GET /`"));
            Ok::<_, Infallible>(res)
        })));

5 MethodRouter

MethodRouter is the parameter type of Router::route(path, method_router) and provides the handling logic for path.

MethodRouter<S, Infallible> bundles HTTP methods with their Handlers. Its methods can be chained, so that requests are dispatched to different Handlers depending on their method.
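Method-based dispatch with chaining can be illustrated with a small std-only model. `MethodRouter`, `Handler`, `list_users` and `create_user` here are hypothetical simplifications: handlers are plain functions returning a status and body, and methods are strings. The 405 reply for an unregistered method mirrors what axum's MethodRouter does, but the code is not axum's implementation.

```rust
use std::collections::HashMap;

// Hypothetical simplification: a handler is a plain function returning (status, body).
type Handler = fn() -> (u16, String);

// Toy MethodRouter: one handler per HTTP method, built by chaining.
#[derive(Default)]
struct MethodRouter {
    handlers: HashMap<&'static str, Handler>,
}

impl MethodRouter {
    // Generic registration, playing the role that on() plays in axum.
    fn on(mut self, method: &'static str, handler: Handler) -> Self {
        self.handlers.insert(method, handler);
        self
    }
    // Convenience wrappers around on(), like get()/post().
    fn get(self, handler: Handler) -> Self {
        self.on("GET", handler)
    }
    fn post(self, handler: Handler) -> Self {
        self.on("POST", handler)
    }
    // Dispatch by method; an unregistered method yields 405 Method Not Allowed.
    fn call(&self, method: &str) -> (u16, String) {
        match self.handlers.get(method) {
            Some(handler) => handler(),
            None => (405, "Method Not Allowed".to_string()),
        }
    }
}

fn list_users() -> (u16, String) {
    (200, "list users".to_string())
}

fn create_user() -> (u16, String) {
    (201, "user created".to_string())
}
```

`MethodRouter::default().get(list_users).post(create_user)` answers GET with 200, POST with 201, and any other method with 405.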

impl<S> MethodRouter<S, Infallible> where S: Clone // S is the State type, usually passed down from Router<S>

// Methods of MethodRouter:

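// on()/on_service() are the generic methods on which the others, such as get()/delete(),
// are built. MethodFilter names the standard HTTP methods via associated constants (MethodFilter::GET, ...).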
pub fn on<H, T>(self, filter: MethodFilter, handler: H) -> Self
where
    H: Handler<T, S>,
    T: 'static,
    S: Send + Sync + 'static

pub fn on_service<T>(self, filter: MethodFilter, svc: T) -> Self
where
    T: Service<Request, Error = E> + Clone + Send + 'static,
    T::Response: IntoResponse + 'static,
    T::Future: Send + 'static

// Other methods returning MethodRouter; they wrap on() and can be chained
pub fn delete<H, T>(self, handler: H) -> Self
where
    H: Handler<T, S>,
    T: 'static,
    S: Send + Sync + 'static

pub fn get<H, T>(self, handler: H) -> Self
where
    H: Handler<T, S>,
    T: 'static,
    S: Send + Sync + 'static

pub fn head<H, T>(self, handler: H) -> Self
where
    H: Handler<T, S>,
    T: 'static,
    S: Send + Sync + 'static

//...

The axum::routing module provides shortcut functions such as get()/get_service()/delete()/delete_service()/put()/post(). Each one directly creates a MethodRouter for the corresponding HTTP method:

// Re-exports
pub use self::method_routing::any;
pub use self::method_routing::any_service;
pub use self::method_routing::delete;
pub use self::method_routing::delete_service;
pub use self::method_routing::get;
pub use self::method_routing::get_service;
pub use self::method_routing::head;
pub use self::method_routing::head_service;
pub use self::method_routing::on;
pub use self::method_routing::on_service;
pub use self::method_routing::options;
pub use self::method_routing::options_service;
pub use self::method_routing::patch;
pub use self::method_routing::patch_service;
pub use self::method_routing::post;
pub use self::method_routing::post_service;
pub use self::method_routing::put;
pub use self::method_routing::put_service;
pub use self::method_routing::trace;
pub use self::method_routing::trace_service;
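pub use self::method_routing::MethodRouter;
// MethodFilter, the method type used by on()/on_service(), is also available as
// axum::routing::MethodFilter; methods are named by associated constants such as MethodFilter::GET.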

Taking get()/get_service() as examples:

  • get(): handles requests with a Handler;
  • get_service(): handles requests with a Service;
pub fn get<H, T, S>(handler: H) -> MethodRouter<S, Infallible>
where
    H: Handler<T, S>,
    T: 'static,
    S: Clone + Send + Sync + 'static

pub fn get_service<T, S>(svc: T) -> MethodRouter<S, T::Error>
where
    T: Service<Request> + Clone + Send + 'static,
    T::Response: IntoResponse + 'static,
    T::Future: Send + 'static,
    S: Clone,

// Example:
use axum::{
    routing::{get, get_service, MethodFilter},
    extract::Request,
    body::Body,
    http::Response,
    Router,
};
use std::convert::Infallible;

async fn handler() {}
async fn other_handler() {}

let service = tower::service_fn(|_request: Request| async {
    // The Err type of the returned Result must be Infallible
    Ok::<_, Infallible>(Response::new(Body::empty()))
});

let app = Router::new()
    .route("/", get(handler))
    .route("/svc", get_service(service).on(MethodFilter::DELETE, other_handler));

MethodRouter also supports state and layers, which apply only to that MethodRouter's Handlers:

// Attach State
pub fn with_state<S2>(self, state: S) -> MethodRouter<S2, E>
// Attach a Layer; it runs before the Handler
pub fn route_layer<L>(self, layer: L) -> MethodRouter<S, E>
where
    L: Layer<Route<E>> + Clone + Send + 'static,
    L::Service: Service<Request, Error = E> + Clone + Send + 'static,
    <L::Service as Service<Request>>::Response: IntoResponse + 'static,
    <L::Service as Service<Request>>::Future: Send + 'static,
    E: 'static,
    S: 'static,

// Example
use axum::{ routing::get, Router, };
use tower_http::validate_request::ValidateRequestHeaderLayer;
let app = Router::new().route(
    "/foo",
    get(|| async {}).route_layer(ValidateRequestHeaderLayer::bearer("password"))
);
// `GET /foo` with a valid token will receive `200 OK`
// `GET /foo` with an invalid token will receive `401 Unauthorized`
// `POST /foo` with an invalid token will receive `405 Method Not Allowed`

Whether a MethodRouter implements the tower::Service<Request> trait depends on its State. If the state type is not (), the MethodRouter implements Service only after a state value of that type has been supplied with with_state():

use tower::Service;
use axum::{routing::get, extract::{State, Request}, body::Body};

// this `MethodRouter` doesn't require any state, i.e. the state is `()`,
let method_router = get(|| async {});
// and thus it implements `Service`
assert_service(method_router);

// this requires a `String` and doesn't implement `Service`
let method_router = get(|_: State<String>| async {});

// until you provide the `String` with `.with_state(...)`
let method_router_with_state = method_router.with_state(String::new());

// and then it implements `Service`
assert_service(method_router_with_state);

// helper to check that a value implements `Service`
fn assert_service<S>(service: S) where S: Service<Request>, {}
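A stateful MethodRouter is not a Service until with_state() is called. This can be modelled with a type parameter that records the state still missing. In this hypothetical std-only sketch, `MethodRouter<S>`, `Service`, `assert_service` and `demo` are simplified stand-ins. `with_state` here always returns `MethodRouter<()>`, whereas axum's real `with_state` is generic over the resulting state type `S2`.

```rust
// Hypothetical, synchronous stand-in for tower::Service.
trait Service<Req> {
    type Response;
    fn call(&mut self, req: Req) -> Self::Response;
}

// Toy MethodRouter<S>: S is the state type the handler still needs.
struct MethodRouter<S> {
    handler: Box<dyn Fn(&S, String) -> String>,
}

impl<S: 'static> MethodRouter<S> {
    fn new(handler: impl Fn(&S, String) -> String + 'static) -> Self {
        MethodRouter { handler: Box::new(handler) }
    }

    // Supplying the state removes the requirement: the result needs nothing, i.e. S = ().
    fn with_state(self, state: S) -> MethodRouter<()> {
        let handler = self.handler;
        MethodRouter {
            handler: Box::new(move |_: &(), req: String| handler(&state, req)),
        }
    }
}

// Only a router that needs no further state is a Service.
impl Service<String> for MethodRouter<()> {
    type Response = String;
    fn call(&mut self, req: String) -> String {
        (self.handler)(&(), req)
    }
}

// Compile-time check, like assert_service in the example above.
fn assert_service<T: Service<String>>(_: &T) {}

fn demo() -> String {
    let needs_state = MethodRouter::new(|state: &String, req: String| format!("{state}:{req}"));
    // assert_service(&needs_state); // would not compile: MethodRouter<String> is not a Service
    let mut ready = needs_state.with_state("app".to_string());
    assert_service(&ready);
    ready.call("GET /".to_string())
}
```

Because `Service` is implemented only for `MethodRouter<()>`, the type system rejects the commented-out line. After `with_state(...)`, `demo()` returns `"app:GET /"`.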

6 Handler

The Handler trait is the handling logic passed to MethodRouter methods such as on/get/delete/put/post(). It is usually implemented by async functions or closures.

pub trait Handler<T, S>: Clone + Send + Sized + 'static {
    type Future: Future<Output = Response> + Send + 'static;

    // Required method
    fn call(self, req: Request, state: S) -> Self::Future;

    // Provided methods; the Layered returned by layer() also implements Handler
    fn layer<L>(self, layer: L) -> Layered<L, Self, T, S>
       where L: Layer<HandlerService<Self, T, S>> + Clone,
             L::Service: Service<Request> { ... }

    fn with_state(self, state: S) -> HandlerService<Self, T, S> { ... }
}

// An adapter that makes a Handler into a Service.
// Created with Handler::with_state or HandlerWithoutStateExt::into_service.
pub struct HandlerService<H, T, S> { /* private fields */ }

impl<H, T, S> HandlerService<H, T, S>
    pub fn state(&self) -> &S
    pub fn into_make_service(self) -> IntoMakeService<HandlerService<H, T, S>>
    pub fn into_make_service_with_connect_info<C>(self)->IntoMakeServiceWithConnectInfo<HandlerService<H,T,S>, C>

// HandlerService is also a Service-of-Service factory and can be passed directly to axum::serve().
impl<H, T, S> Service<IncomingStream<'_>> for HandlerService<H, T, S>
where
    H: Clone,
    S: Clone
  type Response = HandlerService<H, T, S> // returns its own type
  type Error = Infallible

impl<H, T, S, B> Service<Request<B>> for HandlerService<H, T, S>
where
    H: Handler<T, S> + Clone + Send + 'static,
    B: HttpBody<Data = Bytes> + Send + 'static,
    B::Error: Into<BoxError>,
    S: Clone + Send + Sync
  type Response = Response<Body>
  type Error = Infallible

// The Layered returned by layer() implements Handler
impl<H, S, T, L> Handler<T, S> for Layered<L, H, T, S>
where
    L: Layer<HandlerService<H, T, S>> + Clone + Send + 'static,
    H: Handler<T, S>,
    L::Service: Service<Request, Error = Infallible> + Clone + Send + 'static,
    <L::Service as Service<Request>>::Response: IntoResponse,
    <L::Service as Service<Request>>::Future: Send,
    T: 'static,
    S: 'static

// MethodRouter also implements Handler
impl<S> Handler<(), S> for MethodRouter<S> where S: Clone + 'static

// Example:
// Serving a Handler: call the handler's into_make_service() method (from HandlerWithoutStateExt)
use axum::handler::HandlerWithoutStateExt;
async fn handler() -> &'static str { "Hello, World!"}
let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
axum::serve(listener, handler.into_make_service()).await.unwrap();

By default axum implements the Handler trait for async FnOnce functions and closures with up to 16 arguments:

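  • The arguments are extractors. There can be several, but every argument except the last must implement FromRequestParts, while the last one may implement FromRequest;
  • The closure itself must be Clone + Send + 'static and its Future must be Send, while the Output only has to implement IntoResponse. Such closures therefore usually do not hold borrows; they use move to take ownership of what they capture;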

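These async functions or closures return a Future whose Output implements IntoResponse. There is no Result with a dedicated error channel, so Handler has no separate mechanism for returning errors. To report an error, return a type that implements IntoResponse for both outcomes. The usual choice is Result<T, E> with a custom error type E that implements IntoResponse.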
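The custom-error approach can be sketched with simplified, hypothetical stand-ins (`Response`, `IntoResponse`, `AppError`, `lookup`, `show_user`) instead of axum's types, keeping everything synchronous. `Result<T, E>` is a response whenever both sides are, so a handler body can use `?` and let the error type decide the status code.

```rust
// Hypothetical, simplified stand-ins for axum's Response and IntoResponse.
#[derive(Debug, PartialEq)]
struct Response {
    status: u16,
    body: String,
}

trait IntoResponse {
    fn into_response(self) -> Response;
}

// A plain String is a 200 OK text response.
impl IntoResponse for String {
    fn into_response(self) -> Response {
        Response { status: 200, body: self }
    }
}

// As in axum: Result<T, E> is a response whenever both T and E are.
impl<T: IntoResponse, E: IntoResponse> IntoResponse for Result<T, E> {
    fn into_response(self) -> Response {
        match self {
            Ok(value) => value.into_response(),
            Err(err) => err.into_response(),
        }
    }
}

// Hypothetical application error type that knows how to render itself.
enum AppError {
    NotFound(u64),
    Internal(String),
}

impl IntoResponse for AppError {
    fn into_response(self) -> Response {
        match self {
            AppError::NotFound(id) => Response { status: 404, body: format!("user {id} not found") },
            AppError::Internal(msg) => Response { status: 500, body: format!("internal error: {msg}") },
        }
    }
}

// Hypothetical data source.
fn lookup(id: u64) -> Option<&'static str> {
    if id == 1 { Some("alice") } else { None }
}

// Handler body (synchronous here): `?` turns a missing user into AppError::NotFound.
fn show_user(id: u64) -> Result<String, AppError> {
    if id == 0 {
        return Err(AppError::Internal("id 0 is reserved".to_string()));
    }
    let name = lookup(id).ok_or(AppError::NotFound(id))?;
    Ok(format!("user {id}: {name}"))
}
```

Here `show_user(1)` renders as 200 with `"user 1: alice"`, `show_user(7)` as 404, and `show_user(0)` as 500. In axum the same pattern is an async handler returning `Result<String, AppError>` together with `impl IntoResponse for AppError`.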

// Some of the Handler impls for closures and functions
impl<F, Fut, Res, S> Handler<((),), S> for F
where
    F: FnOnce() -> Fut + Clone + Send + 'static, // the closure or function
    Fut: Future<Output = Res> + Send,
    Res: IntoResponse

impl<F, Fut, S, Res, M, T1> Handler<(M, T1), S> for F
where
    F: FnOnce(T1) -> Fut + Clone + Send + 'static, // the closure or function
    Fut: Future<Output = Res> + Send,
    S: Send + Sync + 'static,
    Res: IntoResponse,
    T1: FromRequest<S, M> + Send

impl<F, Fut, S, Res, M, T1, T2> Handler<(M, T1, T2), S> for F
where
    F: FnOnce(T1, T2) -> Fut + Clone + Send + 'static,
    Fut: Future<Output = Res> + Send,
    S: Send + Sync + 'static,
    Res: IntoResponse,
    T1: FromRequestParts<S> + Send,
    T2: FromRequest<S, M> + Send

// ... and so on, up to closures with 16 arguments
impl<F, Fut, S, Res, M, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16> Handler<(M, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16), S> for F
where
    F: FnOnce(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16) -> Fut + Clone + Send + 'static,
    Fut: Future<Output = Res> + Send,
    S: Send + Sync + 'static,
    Res: IntoResponse,
    T1: FromRequestParts<S> + Send,
    T2: FromRequestParts<S> + Send,
    T3: FromRequestParts<S> + Send,
    T4: FromRequestParts<S> + Send,
    T5: FromRequestParts<S> + Send,
    T6: FromRequestParts<S> + Send,
    T7: FromRequestParts<S> + Send,
    T8: FromRequestParts<S> + Send,
    T9: FromRequestParts<S> + Send,
    T10: FromRequestParts<S> + Send,
    T11: FromRequestParts<S> + Send,
    T12: FromRequestParts<S> + Send,
    T13: FromRequestParts<S> + Send,
    T14: FromRequestParts<S> + Send,
    T15: FromRequestParts<S> + Send,
    T16: FromRequest<S, M> + Send

Examples of async Handlers:

use axum::{body::Bytes, http::StatusCode};

// An empty return value means 200 OK with an empty body
async fn unit_handler() {}

// Returning a String means 200 OK with a plain-text body
async fn string_handler() -> String {
    "Hello, World!".to_string()
}

// Bytes implements FromRequest, so it is an extractor that yields the whole body.
// String and StatusCode both implement IntoResponse, so Result<String, StatusCode> does too.
async fn echo(body: Bytes) -> Result<String, StatusCode> {
    if let Ok(string) = String::from_utf8(body.to_vec()) {
        Ok(string)
    } else {
        Err(StatusCode::BAD_REQUEST)
    }
}

Handler also provides layer() and with_state() methods to add middleware and state to that Handler:

  • The State must be Clone + Send + Sync + 'static. To satisfy 'static, it usually owns its data (for example through an Arc) instead of holding borrows.
// layer() example
use axum::{routing::get, handler::Handler, Router, };
use tower::limit::{ConcurrencyLimitLayer, ConcurrencyLimit};

async fn handler() { /* ... */ }
let layered_handler = handler.layer(ConcurrencyLimitLayer::new(64));
let app = Router::new().route("/", get(layered_handler));

// with_state() example
use axum::{
    handler::Handler,
    response::IntoResponse,
    extract::{ConnectInfo, State},
};
use std::net::SocketAddr;

// The State must implement Clone
#[derive(Clone)]
struct AppState {}

async fn handler(
    ConnectInfo(addr): ConnectInfo<SocketAddr>,
    State(state): State<AppState>, // extract the State
) -> String {
    format!("Hello {addr}")
}

let app = handler.with_state(AppState {}); // attach the State to the handler
let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
axum::serve(listener, app.into_make_service_with_connect_info::<SocketAddr>(),).await.unwrap();

7 IntoResponse

The return value of a Handler function or closure must implement IntoResponse:

impl<F, Fut, S, Res, M, T1, T2> Handler<(M, T1, T2), S> for F
where
    F: FnOnce(T1, T2) -> Fut + Clone + Send + 'static, // the closure or function
    Fut: Future<Output = Res> + Send,
    S: Send + Sync + 'static,
    Res: IntoResponse, // must implement IntoResponse
    T1: FromRequestParts<S> + Send,
    T2: FromRequest<S, M> + Send

axum implements the IntoResponse trait for common types out of the box:

pub trait IntoResponse {
    // Required method
    fn into_response(self) -> Response<Body>; // Body is the struct axum::body::Body
}

impl<const N: usize> IntoResponse for &'static [u8; N]
impl IntoResponse for &'static [u8]
impl<const N: usize> IntoResponse for [u8; N]
impl IntoResponse for &'static str
impl IntoResponse for Cow<'static, str>
impl IntoResponse for Cow<'static, [u8]>
impl IntoResponse for Infallible
impl IntoResponse for ()
impl IntoResponse for Box<str>
impl IntoResponse for Box<[u8]>
impl IntoResponse for String
impl IntoResponse for Vec<u8>
impl IntoResponse for Bytes
impl IntoResponse for BytesMut
impl IntoResponse for Extensions
impl IntoResponse for HeaderMap
impl IntoResponse for Parts
impl IntoResponse for StatusCode

// Result also implements IntoResponse, provided both Ok and Err implement IntoResponse
impl<T, E> IntoResponse for Result<T, E> where T: IntoResponse, E: IntoResponse

// struct axum::response::ErrorResponse implements From<T> for every T: IntoResponse
impl<T> IntoResponse for Result<T, ErrorResponse> where T: IntoResponse

impl<T, U> IntoResponse for Chain<T, U> where
    T: Buf + Unpin + Send + 'static,
    U: Buf + Unpin + Send + 'static,

// http::response::Response implements IntoResponse
impl<B> IntoResponse for Response<B>
where
    B: Body<Data = Bytes> + Send + 'static,
    <B as Body>::Error: Into<Box<dyn Error + Send + Sync>>

// An array of (K, V) pairs implements IntoResponse (as response headers); K must convert into HeaderName and V into HeaderValue
impl<K, V, const N: usize> IntoResponse for [(K, V); N]
where
    K: TryInto<HeaderName>,
    <K as TryInto<HeaderName>>::Error: Display,
    V: TryInto<HeaderValue>,
    <V as TryInto<HeaderValue>>::Error: Display

// Special tuple types implement IntoResponse; the T1 position stands for a list of 1 to 16 IntoResponseParts elements (T1..T16)
impl<R> IntoResponse for (Parts, R) where R: IntoResponse
impl<R, T1> IntoResponse for (Parts, T1, R) where T1: IntoResponseParts, R: IntoResponse

impl<R> IntoResponse for (Response<()>, R) where R: IntoResponse
impl<R, T1> IntoResponse for (Response<()>, T1, R) where T1: IntoResponseParts, R: IntoResponse

impl<R> IntoResponse for (StatusCode, R) where R: IntoResponse
impl<R, T1> IntoResponse for (StatusCode, T1, R) where T1: IntoResponseParts, R: IntoResponse

impl<R, T1> IntoResponse for (T1, R) where T1: IntoResponseParts, R: IntoResponse
impl<R> IntoResponse for (R,) where R: IntoResponse
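The tuple impls compose a response in layers. The last element produces the body, the middle elements (IntoResponseParts) only add headers or extensions, and a leading StatusCode overrides the status. The sketch below models that composition with hypothetical, simplified types (`Response`, `StatusCode`, `IntoResponse`, `IntoResponseParts`). It is not axum's source, though the `text/plain; charset=utf-8` content type for `&'static str` matches what axum sets.

```rust
// Hypothetical, simplified stand-ins; not axum's real definitions.
#[derive(Debug, PartialEq)]
struct Response {
    status: u16,
    headers: Vec<(String, String)>,
    body: String,
}

// Stand-in for http::StatusCode.
struct StatusCode(u16);

// Produces a complete response (status, headers, body).
trait IntoResponse {
    fn into_response(self) -> Response;
}

// Can only modify an existing response (e.g. add headers); cannot produce a body.
trait IntoResponseParts {
    fn into_response_parts(self, res: &mut Response);
}

impl IntoResponse for &'static str {
    fn into_response(self) -> Response {
        Response {
            status: 200,
            headers: vec![("content-type".to_string(), "text/plain; charset=utf-8".to_string())],
            body: self.to_string(),
        }
    }
}

// An array of header pairs is a "part": it only appends headers.
impl<const N: usize> IntoResponseParts for [(&'static str, &'static str); N] {
    fn into_response_parts(self, res: &mut Response) {
        for (name, value) in self {
            res.headers.push((name.to_string(), value.to_string()));
        }
    }
}

// (StatusCode, R): build the body response, then override the status.
impl<R: IntoResponse> IntoResponse for (StatusCode, R) {
    fn into_response(self) -> Response {
        let (status, body) = self;
        let mut res = body.into_response();
        res.status = status.0;
        res
    }
}

// (StatusCode, T1, R): build the body, apply the parts in T1, then set the status.
impl<T1: IntoResponseParts, R: IntoResponse> IntoResponse for (StatusCode, T1, R) {
    fn into_response(self) -> Response {
        let (status, parts, body) = self;
        let mut res = body.into_response();
        parts.into_response_parts(&mut res);
        res.status = status.0;
        res
    }
}
```

For example, `(StatusCode(201), [("x-request-id", "abc")], "created").into_response()` yields status 201 and body `"created"`. Its headers are the content-type set by the body plus `x-request-id`.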

Other types implementing IntoResponse, such as Json/Form/Html/Body:

// The Rejection types that extractors use as their associated Rejection type
impl IntoResponse for MultipartRejection
impl IntoResponse for BytesRejection
impl IntoResponse for ExtensionRejection
impl IntoResponse for FailedToBufferBody
// ...

// Other types defined by axum
impl IntoResponse for Body // struct axum::body::Body
impl IntoResponse for Redirect
impl<T> IntoResponse for Extension<T> where T: Clone + Send + Sync + 'static,
impl<T> IntoResponse for Form<T> where T: Serialize
impl<T> IntoResponse for Json<T> where T: Serialize // Struct axum::Json
impl<T> IntoResponse for Html<T> where T: Into<Body>

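The documentation does not mention it, but every type that implements IntoResponse (and is Clone + Send + 'static) also implements Handler. Such a value can therefore be passed directly as the argument of MethodRouter methods such as on/get/delete/put/post():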

mod private {
    // Marker type for `impl<T: IntoResponse> Handler for T`
    #[allow(missing_debug_implementations)]
    pub enum IntoResponseHandler {}
}

impl<T, S> Handler<private::IntoResponseHandler, S> for T
where
    T: IntoResponse + Clone + Send + 'static,
{
    type Future = std::future::Ready<Response>;

    fn call(self, _req: Request, _state: S) -> Self::Future {
        std::future::ready(self.into_response())
    }
}

// Example: using IntoResponse values (such as a tuple) as Handlers:
use axum::{
    Router,
    routing::{get, post},
    Json,
    http::StatusCode,
};
use serde_json::json;

let app = Router::new()
    .route("/", get("Hello, World!")) // &'static str implements IntoResponse
    .route("/users", post(
        // (StatusCode, R) implements IntoResponse because R (here Json) does
        (StatusCode::CREATED, Json(json!({ "id": 1, "username": "alice" })),)));

8 extractor

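An extractor is a value whose type implements the FromRequest or FromRequestParts trait. Extractors are the input arguments of a Handler function or closure, and they pull the relevant information out of the request for the Handler to use.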

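Consider a Handler with several extractor arguments, such as FnOnce(T1, T2, T3). All arguments except the last must implement the FromRequestParts trait, and the last one may implement the FromRequest trait. The reason is that FromRequestParts does not consume the body, while FromRequest does, and the body can only be consumed once.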

use axum::{
    // Request/Json/Path/Extension/Query are all extractors
    extract::{Request, Json, Path, Extension, Query},
    routing::post,
    http::header::HeaderMap,
    body::{Bytes, Body},
    Router,
};
use serde_json::Value;
use serde::Deserialize;
use std::collections::HashMap;

#[derive(Deserialize)]
struct CreateUser {
    email: String,
    password: String,
}

// 函数传参本质上是模式匹配赋值,所以类似于 Path((user_id, user_name)) 的
// user_id/user_name 是解构后的内容。

// Path 从请求中提取路径字段(多个字段用 tuple 表示)
async fn path(Path(user_id): Path<u32>) {}
async fn path(Path((user_id, user_name)): Path<(u32, String)>) {}

// Query 从请求参数中生成对应类型
async fn query(Query(params): Query<HashMap<String, String>>) {}

// HeaderMap 包含所有请求 HTTP Headers
async fn headers(headers: HeaderMap) {}

// String 包含请求 body 的内容,确保是有效的 UTF-8
async fn string(body: String) {}

// Bytes: the raw request body
async fn bytes(body: Bytes) {}

// Json deserializes the request body into the given type (serde_json::Value for arbitrary JSON)
async fn json(Json(payload): Json<Value>) {}

// Json implements both FromRequest (as an extractor) and IntoResponse, so it can also be
// used as a Handler return type.
async fn create_user(Json(payload): Json<CreateUser>) {}

// Request: the entire request
async fn request(request: Request) {}

// Extension extracts data from the request extensions, commonly used for shared state
async fn extension(Extension(state): Extension<State>) {}

#[derive(Clone)]
struct State { /* ... */ }
let app = Router::new()
    .route("/path/:user_id/:user_name", post(path))
    .route("/query", post(query))
    .route("/string", post(string))
    .route("/bytes", post(bytes))
    .route("/json", post(json))
    .route("/request", post(request))
    .route("/extension", post(extension))
    .route("/users", post(create_user));

Every extractor parameter of a Handler is required by default (a failed extraction rejects the request); wrap it in Option to make it optional:

use axum::{
    extract::{Path, Query},
    routing::get,
    Router,
};
use uuid::Uuid;
use serde::Deserialize;

let app = Router::new().route("/users/:id/things", get(get_user_things));

#[derive(Deserialize)]
struct Pagination {
    page: usize,
    per_page: usize,
}

impl Default for Pagination {
    fn default() -> Self {
        Self { page: 1, per_page: 30 }
    }
}

// A Handler can use several extractors at once
async fn get_user_things(
    Path(user_id): Path<Uuid>, // required
    pagination: Option<Query<Pagination>>, // optional: deserialized from the query string into Pagination
) {
    let Query(pagination) = pagination.unwrap_or_default();
    // ...
}
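A std-only sketch of the optional-extractor behaviour. parse_pagination is a hypothetical stand-in for Query<Pagination>: it yields None on any problem (missing field, malformed pair, non-numeric value). That mirrors how Option<T> behaves in the axum 0.7 API used in this article, where a failed inner extraction becomes None instead of a rejection, and unwrap_or_default() then falls back to page 1 with 30 items per page.

```rust
#[derive(Debug, PartialEq)]
struct Pagination {
    page: usize,
    per_page: usize,
}

impl Default for Pagination {
    fn default() -> Self {
        Self { page: 1, per_page: 30 }
    }
}

// Hypothetical stand-in for the Query<Pagination> extractor: any parse
// problem is reported as None, like an optional extractor that failed.
fn parse_pagination(query: &str) -> Option<Pagination> {
    let mut page = None;
    let mut per_page = None;
    for pair in query.split('&').filter(|p| !p.is_empty()) {
        let (key, value) = pair.split_once('=')?;
        match key {
            "page" => page = Some(value.parse::<usize>().ok()?),
            "per_page" => per_page = Some(value.parse::<usize>().ok()?),
            _ => {} // unknown keys are ignored
        }
    }
    Some(Pagination { page: page?, per_page: per_page? })
}

// Mirrors `pagination.unwrap_or_default()` in the handler above.
fn pagination_or_default(query: &str) -> Pagination {
    parse_pagination(query).unwrap_or_default()
}
```

Note that a malformed value silently yields the default here; if the client should be told about bad input, use Result<Query<_>, _> as shown next instead of Option.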

Use Result to find out why an extractor failed. The error type is the Rejection associated type declared by the extractor's FromRequestParts or FromRequest implementation:

use axum::{
    extract::{Json, rejection::JsonRejection},
    routing::post,
    Router,
};
use serde_json::Value;

// Every extractor defines its own Rejection type; Result::Err holds a value of that type.
async fn create_user(payload: Result<Json<Value>, JsonRejection>) {
    match payload {
        Ok(payload) => {
            // We got a valid JSON payload
        }
        Err(JsonRejection::MissingJsonContentType(_)) => {
            // Request didn't have `Content-Type: application/json` header
        }
        Err(JsonRejection::JsonDataError(_)) => {
            // Couldn't deserialize the body into the target type
        }
        Err(JsonRejection::JsonSyntaxError(_)) => {
            // Syntax error in the body
        }
        Err(JsonRejection::BytesRejection(_)) => {
            // Failed to extract the request body
        }
        Err(_) => {
            // `JsonRejection` is marked `#[non_exhaustive]` so match must include a
            // catch-all case.
        }
    }
}

let app = Router::new().route("/users", post(create_user));

A more complete example of inspecting an extractor's error:

use std::error::Error;
use axum::{
    extract::{Json, rejection::JsonRejection},
    response::IntoResponse,
    http::StatusCode,
};
use serde_json::{json, Value};

async fn handler(result: Result<Json<Value>, JsonRejection>,) -> Result<Json<Value>, (StatusCode, String)> {
    match result {
        // if the client sent valid JSON then we're good
        Ok(Json(payload)) => Ok(Json(json!({ "payload": payload }))),

        Err(err) => match err {
            JsonRejection::JsonDataError(err) => {
                Err(serde_json_error_response(err))
            }
            JsonRejection::JsonSyntaxError(err) => {
                Err(serde_json_error_response(err))
            }
            // handle other rejections from the `Json` extractor
            JsonRejection::MissingJsonContentType(_) => Err((
                StatusCode::BAD_REQUEST,
                "Missing `Content-Type: application/json` header".to_string(),
            )),
            JsonRejection::BytesRejection(_) => Err((
                StatusCode::INTERNAL_SERVER_ERROR,
                "Failed to buffer request body".to_string(),
            )),
            // we must provide a catch-all case since `JsonRejection` is marked
            // `#[non_exhaustive]`
            _ => Err((
                StatusCode::INTERNAL_SERVER_ERROR,
                "Unknown error".to_string(),
            )),
        },
    }
}

// attempt to extract the inner `serde_path_to_error::Error<serde_json::Error>`, if
// that succeeds we can provide a more specific error.
//
// `Json` uses `serde_path_to_error` so the error will be wrapped in
// `serde_path_to_error::Error`.
fn serde_json_error_response<E>(err: E) -> (StatusCode, String)
where
    E: Error + 'static,
{
    if let Some(err) = find_error_source::<serde_path_to_error::Error<serde_json::Error>>(&err) {
        let serde_json_err = err.inner();
        (
            StatusCode::BAD_REQUEST,
            format!(
                "Invalid JSON at line {} column {}",
                serde_json_err.line(),
                serde_json_err.column()
            ),
        )
    } else {
        (StatusCode::BAD_REQUEST, "Unknown error".to_string())
    }
}

// attempt to downcast `err` into a `T` and if that fails recursively try and
// downcast `err`'s source
fn find_error_source<'a, T>(err: &'a (dyn Error + 'static)) -> Option<&'a T>
where
    T: Error + 'static,
{
    if let Some(err) = err.downcast_ref::<T>() {
        Some(err)
    } else if let Some(source) = err.source() {
        find_error_source(source)
    } else {
        None
    }
}
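find_error_source depends only on std::error::Error, so its behaviour can be checked without axum. In the sketch below SyntaxError and Rejection are hypothetical types standing in for serde_json::Error and a rejection that wraps it. What matters is that the wrapper exposes its cause through source(); that is what lets the recursive downcast find it.

```rust
use std::error::Error;
use std::fmt;

// Hypothetical low-level error, standing in for serde_json::Error.
#[derive(Debug)]
struct SyntaxError {
    line: usize,
    column: usize,
}

impl fmt::Display for SyntaxError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "syntax error at {}:{}", self.line, self.column)
    }
}

impl Error for SyntaxError {}

// Hypothetical wrapper, standing in for a rejection that wraps the real cause.
#[derive(Debug)]
struct Rejection {
    cause: SyntaxError,
}

impl fmt::Display for Rejection {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "failed to parse the request body")
    }
}

impl Error for Rejection {
    // Exposing the cause is what makes the chain walkable.
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        Some(&self.cause)
    }
}

// Same algorithm as above: try to downcast, otherwise recurse into source().
fn find_error_source<'a, T>(err: &'a (dyn Error + 'static)) -> Option<&'a T>
where
    T: Error + 'static,
{
    if let Some(err) = err.downcast_ref::<T>() {
        Some(err)
    } else if let Some(source) = err.source() {
        find_error_source(source)
    } else {
        None
    }
}

fn describe(err: &(dyn Error + 'static)) -> String {
    match find_error_source::<SyntaxError>(err) {
        Some(e) => format!("Invalid JSON at line {} column {}", e.line, e.column),
        None => "Unknown error".to_string(),
    }
}
```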

Custom extractor: implementing FromRequestParts

use axum::{
    async_trait,
    extract::FromRequestParts,
    routing::get,
    Router,
    http::{
        StatusCode,
        header::{HeaderValue, USER_AGENT},
        request::Parts,
    },
};

// By convention an extractor is a tuple struct (often a generic one)
struct ExtractUserAgent(HeaderValue);

// S is the State type
#[async_trait]
impl<S> FromRequestParts<S> for ExtractUserAgent where S: Send + Sync,
{
    // The error type returned when extraction fails.
    type Rejection = (StatusCode, &'static str);

    // Returns a Result; Err carries the rejection value.
    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
        if let Some(user_agent) = parts.headers.get(USER_AGENT) {
            Ok(ExtractUserAgent(user_agent.clone()))
        } else {
            Err((StatusCode::BAD_REQUEST, "`User-Agent` header is missing"))
        }
    }
}

// Parameters are bound by pattern matching, so user_agent holds the destructured HeaderValue.
async fn handler(ExtractUserAgent(user_agent): ExtractUserAgent) {
    // use user_agent
}

let app = Router::new().route("/foo", get(handler));

Custom extractor: implementing FromRequest:

use axum::{
    async_trait,
    extract::{Request, FromRequest},
    response::{Response, IntoResponse},
    body::{Bytes, Body},
    routing::get,
    Router,
    http::{
        StatusCode,
        header::{HeaderValue, USER_AGENT},
    },
};

struct ValidatedBody(Bytes);

#[async_trait]
impl<S> FromRequest<S> for ValidatedBody
where
    Bytes: FromRequest<S>,
    S: Send + Sync,
{
    type Rejection = Response;

    // On failure, return a value of the Rejection type
    async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
        let body = Bytes::from_request(req, state)
            .await
            .map_err(IntoResponse::into_response)?;

        // do validation...

        Ok(Self(body))
    }
}

async fn handler(ValidatedBody(body): ValidatedBody) {
    // use the body data
}

let app = Router::new().route("/foo", get(handler));

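Custom extractors are written by implementing FromRequestParts or FromRequest, but a custom type should implement only one of the two, unless it is a wrapper around another extractor. In that case implement both, each bounded on the inner type (T: FromRequestParts<S> for one impl and T: FromRequest<S> for the other), so the wrapper works in whichever position the inner extractor can be used: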

use axum::{
    Router,
    body::Body,
    routing::get,
    extract::{Request, FromRequest, FromRequestParts},
    http::{HeaderMap, request::Parts},
    async_trait,
};
use std::time::{Instant, Duration};

// an extractor that wraps another and measures how long time it takes to run
struct Timing<E> {
    extractor: E,
    duration: Duration,
}

// we must implement both `FromRequestParts`
#[async_trait]
impl<S, T> FromRequestParts<S> for Timing<T>
where
    S: Send + Sync,
    T: FromRequestParts<S>,
{
    type Rejection = T::Rejection;

    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
        let start = Instant::now();
        let extractor = T::from_request_parts(parts, state).await?;
        let duration = start.elapsed();
        Ok(Timing {
            extractor,
            duration,
        })
    }
}

// and `FromRequest`
#[async_trait]
impl<S, T> FromRequest<S> for Timing<T>
where
    S: Send + Sync,
    T: FromRequest<S>,
{
    type Rejection = T::Rejection;

    async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
        let start = Instant::now();
        let extractor = T::from_request(req, state).await?;
        let duration = start.elapsed();
        Ok(Timing {
            extractor,
            duration,
        })
    }
}

async fn handler(
    // this uses the `FromRequestParts` impl
    _: Timing<HeaderMap>,
    // this uses the `FromRequest` impl
    _: Timing<String>,
) {}

When an extractor reads the request body, its size is limited to 2 MB by default; use DefaultBodyLimit::max(size) to change it:

use axum::{
    Router,
    routing::post,
    body::Body,
    extract::{Request, DefaultBodyLimit},
};

let app = Router::new()
    .route("/", post(|request: Request| async {}))
    // change the default limit
    .layer(DefaultBodyLimit::max(1024));
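Conceptually, the limit is enforced while the body is being buffered. The sketch below is a hypothetical std-only model, not axum's implementation: buffer_body appends chunks and gives up with status 413 (Payload Too Large, the status axum reports when a body exceeds the limit) as soon as the running total would pass the limit. DEFAULT_LIMIT is the 2 MB default mentioned above.

```rust
// axum's default request body limit: 2 MB.
const DEFAULT_LIMIT: usize = 2 * 1024 * 1024;

// Hypothetical model of buffering a streamed body under a limit: chunks are
// appended until the running total would exceed `limit`, then the request
// is rejected with 413 Payload Too Large.
fn buffer_body(chunks: &[&[u8]], limit: usize) -> Result<Vec<u8>, u16> {
    let mut buf = Vec::new();
    for chunk in chunks {
        if buf.len() + chunk.len() > limit {
            return Err(413);
        }
        buf.extend_from_slice(chunk);
    }
    Ok(buf)
}
```

Checking before each append means an oversized body is rejected without ever being held in memory in full.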

To log extraction rejections automatically, enable axum's tracing feature (on by default) and set the environment variable:

RUST_LOG=info,axum::rejection=trace

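The axum::extract module provides a number of commonly used extractors:

  • Json
  • Form
  • Request
  • HeaderMap
  • Extension: extracts values from the request extensions (inserted, e.g., by .layer(Extension(..)) or by a middleware); values passed via with_state() are extracted with State instead.
  • ConnectInfo
  • Host
  • MatchedPath
  • Multipart
  • NestedPath
  • OriginalUri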
  • Path
  • Query
  • RawForm
  • RawPathParams
  • RawQuery
  • State

Json: implements FromRequest and IntoResponse, so it can be both an input and an output type of a Handler:

pub struct Json<T>(pub T);

// Extractor example
use axum::{
    extract,
    routing::post,
    Router,
};
use serde::Deserialize;

#[derive(Deserialize)]
struct CreateUser {
    email: String,
    password: String,
}

async fn create_user(extract::Json(payload): extract::Json<CreateUser>) {
    // payload is a `CreateUser`
}

let app = Router::new().route("/users", post(create_user));

// Response example
use axum::{
    extract::Path,
    routing::get,
    Router,
    Json,
};
use serde::Serialize;
use uuid::Uuid;

#[derive(Serialize)]
struct User {
    id: Uuid,
    username: String,
}

async fn get_user(Path(user_id) : Path<Uuid>) -> Json<User> {
    let user = find_user(user_id).await;
    Json(user)
}

async fn find_user(user_id: Uuid) -> User {
    // ...
}

let app = Router::new().route("/users/:id", get(get_user));

Form: implements FromRequest and IntoResponse, so it can be both an input and an output type of a Handler:

pub struct Form<T>(pub T);  // implements FromRequest and IntoResponse

use axum::Form;
use serde::Deserialize;

#[derive(Deserialize)]
struct SignUp {
    username: String,
    password: String,
}

async fn accept_form(Form(sign_up): Form<SignUp>) {
    // ...
}


// Response
use axum::Form;
use serde::Serialize;

#[derive(Serialize)]
struct Payload {
    value: String,
}

async fn handler() -> Form<Payload> {
    Form(Payload { value: "foo".to_owned() })
}

Request: extracts the whole Request object, giving the handler full control:

async fn request(request: Request) {}

HeaderMap: contains all request headers

async fn headers(headers: HeaderMap) {}

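Extension: a mechanism for passing state to Handlers through HTTP request extensions.

  • Extension implements FromRequestParts, Layer<S> and IntoResponse, so it can be used as an extractor, as a layer middleware, and as part of a response;

The main use case for Extension is writing layer middleware: the middleware stores data in the request extensions, and the handler then extracts it.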

async fn extension(Extension(state): Extension<State>) {}

// As an extractor: commonly used to share state between handlers
use axum::{
    Router,
    Extension,
    routing::get,
};
use std::sync::Arc;

// Some shared state used throughout our application
struct State {
    // ...
}
async fn handler(state: Extension<Arc<State>>) {
    // ...
}
let state = Arc::new(State { /* ... */ });
let app = Router::new().route("/", get(handler))
    .layer(Extension(state)); // Router level: applies to all of its Handlers


// As part of a response
use axum::{
    Extension,
    response::IntoResponse,
};
async fn handler() -> (Extension<Foo>, &'static str) {
    (
        Extension(Foo("foo")),
        "Hello, World!"
    )
}
#[derive(Clone)]
struct Foo(&'static str);

// Passing state from middleware to handlers
// State can be passed from middleware to handlers using request extensions:
use axum::{
    Router,
    http::StatusCode,
    routing::get,
    response::{IntoResponse, Response},
    middleware::{self, Next},
    extract::{Request, Extension},
};

#[derive(Clone)]
struct CurrentUser { /* ... */ }

async fn auth(mut req: Request, next: Next) -> Result<Response, StatusCode> {
    let auth_header = req.headers()
        .get(axum::http::header::AUTHORIZATION)
        .and_then(|header| header.to_str().ok());

    let auth_header = if let Some(auth_header) = auth_header {
        auth_header
    } else {
        return Err(StatusCode::UNAUTHORIZED);
    };

    if let Some(current_user) = authorize_current_user(auth_header).await {
        // insert the current user into a request extension so the handler can extract it
        req.extensions_mut().insert(current_user);
        Ok(next.run(req).await)
    } else {
        Err(StatusCode::UNAUTHORIZED)
    }
}

async fn authorize_current_user(auth_token: &str) -> Option<CurrentUser> {
    // ...
}

async fn handler(
    // extract the current user, set by the middleware
    Extension(current_user): Extension<CurrentUser>,
) {
    // ...
}

let app = Router::new()
    .route("/", get(handler))
    .route_layer(middleware::from_fn(auth));
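Request extensions are essentially a map keyed by type that holds at most one value per type. The std-only sketch below models that idea; Extensions, auth and handler are hypothetical simplifications of http::Extensions, the middleware and the handler above (synchronous, status codes as u16). As in axum, a missing extension is reported as 500 Internal Server Error, since it usually means the middleware that should have inserted it did not run.

```rust
use std::any::{Any, TypeId};
use std::collections::HashMap;

// A tiny type-keyed map, similar in spirit to http::Extensions:
// at most one value per type, looked up by the type itself.
#[derive(Default)]
struct Extensions {
    map: HashMap<TypeId, Box<dyn Any + Send + Sync>>,
}

impl Extensions {
    // Returns the previous value of the same type, if any.
    fn insert<T: Any + Send + Sync>(&mut self, value: T) -> Option<T> {
        self.map
            .insert(TypeId::of::<T>(), Box::new(value))
            .and_then(|old| old.downcast::<T>().ok().map(|boxed| *boxed))
    }

    fn get<T: Any + Send + Sync>(&self) -> Option<&T> {
        self.map
            .get(&TypeId::of::<T>())
            .and_then(|boxed| boxed.downcast_ref::<T>())
    }
}

#[derive(Clone, Debug, PartialEq)]
struct CurrentUser {
    name: String,
}

// Hypothetical auth middleware: on success it stores a CurrentUser for the
// handler, otherwise it rejects the request with 401.
fn auth(token: Option<&str>, ext: &mut Extensions) -> Result<(), u16> {
    match token {
        Some("secret-token") => {
            ext.insert(CurrentUser { name: "alice".to_string() });
            Ok(())
        }
        _ => Err(401),
    }
}

// Hypothetical handler, the analogue of Extension(user): Extension<CurrentUser>.
fn handler(ext: &Extensions) -> Result<String, u16> {
    ext.get::<CurrentUser>()
        .map(|user| format!("hello {}", user.name))
        .ok_or(500) // missing extension: the middleware did not run
}
```

Because lookup is by type, inserting a second CurrentUser replaces the first; wrap values in distinct newtypes if you need several of the same underlying type.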

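ConnectInfo: extracts information about the client connection, such as the client address. It must be used together with Router::into_make_service_with_connect_info(). You can also customize the ConnectInfo value by implementing the Connected<IncomingStream<'_>> trait.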

use axum::{
    extract::connect_info::{ConnectInfo, Connected},
    routing::get,
    serve::IncomingStream,
    Router,
};

let app = Router::new().route("/", get(handler));

async fn handler(
    ConnectInfo(my_connect_info): ConnectInfo<MyConnectInfo>,
) -> String {
    format!("Hello {my_connect_info:?}")
}

// Implementing the Connected<IncomingStream<'_>> trait lets you customize the ConnectInfo value.
#[derive(Clone, Debug)]
struct MyConnectInfo {
    // ...
}

impl Connected<IncomingStream<'_>> for MyConnectInfo {
    fn connect_info(target: IncomingStream<'_>) -> Self {
        MyConnectInfo {
            // ...
        }
    }
}

let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
axum::serve(listener, app.into_make_service_with_connect_info::<MyConnectInfo>()).await.unwrap();

DefaultBodyLimit: Layer for configuring the default request body limit.

  • Despite living in axum::extract, it is a layer, not an extractor. Router::layer() only applies to routes added before the call, so a layer added to an empty Router does not affect routes added afterwards; to give one route its own limit, add the layer on its MethodRouter:

use axum::{
    Router,
    routing::post,
    extract::{Request, DefaultBodyLimit},
};

let app = Router::new()
    // this route has a custom limit of 1024 bytes
    .route("/", post(|request: Request| async {}).layer(DefaultBodyLimit::max(1024)))
    // this route still has the default limit (2 MB)
    .route("/foo", post(|request: Request| async {}));

Host: Extractor that resolves the hostname of the request. Implements FromRequestParts.

The hostname is resolved through the following, in order:

  1. Forwarded header
  2. X-Forwarded-Host header
  3. Host header
  4. request target / URI

MatchedPath: Access the path in the router that matches the request. The returned path is the route pattern as registered on the Router (e.g. "/users/:id"), not the concrete request path.

use axum::{
    Router,
    extract::MatchedPath,
    routing::get,
};

let app = Router::new().route(
    "/users/:id",
    get(|path: MatchedPath| async move {
        let path = path.as_str();
        // `path` will be "/users/:id"
    })
);

Multipart: Extractor that parses multipart/form-data requests (commonly used with file uploads). It implements FromRequest<S> and consumes the body, so it must be the last parameter of the handler and can appear only once:

use axum::{
    extract::Multipart,
    routing::post,
    Router,
};
use futures_util::stream::StreamExt;

async fn upload(mut multipart: Multipart) {
    while let Some(mut field) = multipart.next_field().await.unwrap() {
        let name = field.name().unwrap().to_string();
        let data = field.bytes().await.unwrap();

        println!("Length of `{}` is {} bytes", name, data.len());
    }
}

let app = Router::new().route("/upload", post(upload));

NestedPath: Access the path at which the matched route is nested. Implements FromRequestParts:

use axum::{
    Router,
    extract::NestedPath,
    routing::get,
};

let api = Router::new().route(
    "/users",
    get(|path: NestedPath| async move {
        // `path` will be "/api" because that's what this router is nested at when we build `app`
        let path = path.as_str();
    })
);

let app = Router::new().nest("/api", api);

OriginalUri: Extractor that gets the original request URI regardless of nesting.

use axum::{
    routing::get,
    Router,
    extract::OriginalUri,
    http::Uri
};

let api_routes = Router::new()
    .route(
        "/users",
        get(|uri: Uri, OriginalUri(original_uri): OriginalUri| async {
            // `uri` is `/users`
            // `original_uri` is `/api/users`
        }),
    );

let app = Router::new().nest("/api", api_routes);

Path: Extractor that will get captures from the URL and parse them using serde. Use a tuple type for multiple path parameters.

use axum::{
    extract::Path,
    routing::get,
    Router,
};
use uuid::Uuid;

async fn users_teams_show(
    Path((user_id, team_id)): Path<(Uuid, Uuid)>,
) {
    // ...
}

let app = Router::new().route("/users/:user_id/team/:team_id", get(users_teams_show));

Query: Extractor that deserializes query strings into some type, typically a struct. If a query parameter is optional, declare the corresponding field as an Option.

use axum::{
    extract::Query,
    routing::get,
    Router,
};
use serde::Deserialize;

#[derive(Deserialize)]
struct Pagination {
    page: usize,
    per_page: usize,
}

// This will parse query strings like `?page=2&per_page=30` into `Pagination`
// structs.
async fn list_things(pagination: Query<Pagination>) {
    let pagination: Pagination = pagination.0;

    // ...
}

let app = Router::new().route("/list_things", get(list_things));

RawForm: Extractor that extracts raw form requests. Implements FromRequest.

use axum::{
    extract::RawForm,
    routing::get,
    Router
};

async fn handler(RawForm(form): RawForm) {}

let app = Router::new().route("/", get(handler));

RawPathParams: Extractor that will get captures from the URL without deserializing them.

pub struct RawPathParams(/* private fields */);

impl<'a> IntoIterator for &'a RawPathParams {
    type Item = (&'a str, &'a str);
    // ...
}

use axum::{
    extract::RawPathParams,
    routing::get,
    Router,
};

async fn users_teams_show(params: RawPathParams) {
    for (key, value) in &params {
        println!("{key:?} = {value:?}");
    }
}

let app = Router::new().route("/users/:user_id/team/:team_id", get(users_teams_show));
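How do ":name" segments turn into captures? The hypothetical match_route below is a std-only sketch of the idea; axum's real router is a radix-tree matcher and supports more syntax, such as wildcards. It walks pattern and path segment by segment and returns the raw (name, value) pairs, roughly what RawPathParams exposes. The pattern string itself is what MatchedPath reports, and Path<T> additionally deserializes the values.

```rust
// Hypothetical matcher for route patterns such as "/users/:user_id/team/:team_id".
// Returns the captured (name, value) pairs in order, or None if the path
// does not match the pattern.
fn match_route<'a>(pattern: &'a str, path: &'a str) -> Option<Vec<(&'a str, &'a str)>> {
    let mut pattern_segments = pattern.split('/');
    let mut path_segments = path.split('/');
    let mut captures = Vec::new();
    loop {
        match (pattern_segments.next(), path_segments.next()) {
            // both exhausted at the same time: full match
            (None, None) => return Some(captures),
            (Some(p), Some(s)) => {
                if let Some(name) = p.strip_prefix(':') {
                    // a capture must not be empty
                    if s.is_empty() {
                        return None;
                    }
                    captures.push((name, s));
                } else if p != s {
                    // literal segment mismatch
                    return None;
                }
            }
            // one side has more segments than the other
            _ => return None,
        }
    }
}
```

Typed extraction like Path<(u32, u32)> then amounts to parsing each captured value, for example caps.iter().map(|(_, v)| v.parse::<u32>()).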

RawQuery:Extractor that extracts the raw query string, without parsing it.

pub struct RawQuery(pub Option<String>);

// Example

use axum::{
    extract::RawQuery,
    routing::get,
    Router,
};

async fn handler(RawQuery(query): RawQuery) {
    // ...
}

let app = Router::new().route("/users", get(handler));

State:Extractor for state. As state is global within a Router you can’t directly get a mutable reference to the state. The most basic solution is to use an Arc<Mutex<_>> . Which kind of mutex you need depends on your use case. See the tokio docs for more details.

use axum::{Router, routing::get, extract::State};

// the application state
//
// here you can put configuration, database connection pools, or whatever
// state you need
//
// see "When states need to implement `Clone`" for more details on why we need
// `#[derive(Clone)]` here.
#[derive(Clone)]
struct AppState {}

let state = AppState {};

// create a `Router` that holds our state
let app = Router::new()
    .route("/", get(handler))
    // provide the state so the router can access it
    .with_state(state);

async fn handler(
    // access the state via the `State` extractor
    // extracting a state of the wrong type results in a compile error
    State(state): State<AppState>,
) {
    // use `state`...
}

Substates: State only allows a single state type, but you can use FromRef to extract "substates":

use axum::{Router, routing::get, extract::{State, FromRef}};

// the application state
#[derive(Clone)]
struct AppState {
    // that holds some api specific state
    api_state: ApiState,
}

// the api specific state
#[derive(Clone)]
struct ApiState {}

// support converting an `AppState` in an `ApiState`
impl FromRef<AppState> for ApiState {
    fn from_ref(app_state: &AppState) -> ApiState {
        app_state.api_state.clone()
    }
}

let state = AppState {
    api_state: ApiState {},
};

let app = Router::new()
    .route("/", get(handler))
    .route("/api/users", get(api_users))
    .with_state(state);

async fn api_users(
    // access the api specific state
    State(api_state): State<ApiState>,
) {
}

async fn handler(
    // we can still access to top level state
    State(state): State<AppState>,
) {
}
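The FromRef mechanism boils down to a trait with a blanket impl. The std-only sketch below is a simplified, hypothetical copy (axum's FromRef has the same shape, including the blanket impl for Clone types); extract_state stands in for what State<T> does with the router's single state value.

```rust
// Simplified version of the FromRef trait.
trait FromRef<T> {
    fn from_ref(input: &T) -> Self;
}

// Every Clone type can be "extracted" from itself.
impl<T: Clone> FromRef<T> for T {
    fn from_ref(input: &T) -> Self {
        input.clone()
    }
}

#[derive(Clone, Debug, PartialEq)]
struct ApiState {
    api_key: String,
}

#[derive(Clone, Debug, PartialEq)]
struct AppState {
    api_state: ApiState,
    name: String,
}

// A substate: ApiState can be derived from AppState.
impl FromRef<AppState> for ApiState {
    fn from_ref(app_state: &AppState) -> ApiState {
        app_state.api_state.clone()
    }
}

// Stand-in for State<T> extraction: the router stores a single S, and each
// handler asks for any T that can be derived from it.
fn extract_state<S, T: FromRef<S>>(router_state: &S) -> T {
    T::from_ref(router_state)
}
```

The blanket impl and the substate impl never overlap because they implement FromRef for different type parameters (FromRef<ApiState> vs FromRef<AppState>), so the compiler always finds exactly one way to build the requested T.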

WebSocketUpgrade: Extractor for establishing WebSocket connections. Implements FromRequestParts.

use axum::{
    extract::ws::{WebSocketUpgrade, WebSocket},
    routing::get,
    response::{IntoResponse, Response},
    Router,
};

let app = Router::new().route("/ws", get(handler));

async fn handler(ws: WebSocketUpgrade) -> Response {
    ws.protocols(["graphql-ws", "graphql-transport-ws"])
        .on_upgrade(|socket| async {
            // ...
        })
}


use axum::{
    extract::ws::{WebSocketUpgrade, WebSocket},
    routing::get,
    response::{IntoResponse, Response},
    Router,
};

let app = Router::new().route("/ws", get(handler));

async fn handler(ws: WebSocketUpgrade) -> Response {
    ws.on_upgrade(handle_socket)
}

async fn handle_socket(mut socket: WebSocket) {
    while let Some(msg) = socket.recv().await {
        let msg = if let Ok(msg) = msg {
            msg
        } else {
            // client disconnected
            return;
        };

        if socket.send(msg).await.is_err() {
            // client disconnected
            return;
        }
    }
}

// If you need to read and write concurrently from a WebSocket you can use StreamExt::split:
use axum::{Error, extract::ws::{WebSocket, Message}};
use futures_util::{sink::SinkExt, stream::{StreamExt, SplitSink, SplitStream}};

async fn handle_socket(mut socket: WebSocket) {
    let (mut sender, mut receiver) = socket.split();

    tokio::spawn(write(sender));
    tokio::spawn(read(receiver));
}

async fn read(receiver: SplitStream<WebSocket>) {
    // ...
}

async fn write(sender: SplitSink<WebSocket, Message>) {
    // ...
}

9 state

State lets Handlers share values such as a database connection pool or other clients.

There are three ways to share state:

  1. the State extractor;
  2. request extensions;
  3. closure captures.

9.1 State Extractor

State can be added at three levels, Router, MethodRouter and Handler, with their with_state() methods.

// Router:
// pub fn with_state<S2>(self, state: S) -> Router<S2>

use axum::{
    extract::State,
    routing::get,
    Router,
};
use std::sync::Arc;

struct AppState {
    // ...
}

let shared_state = Arc::new(AppState { /* ... */ });

let app = Router::new()
    .route("/", get(handler))
    .with_state(shared_state);

async fn handler(
    State(state): State<Arc<AppState>>,
) {
    // ...
}

When a function returns a Router, it is generally recommended not to set the State inside it, but to provide it just before running the server:

use axum::{Router, routing::get, extract::State};

#[derive(Clone)]
struct AppState {}

// Don't call `Router::with_state` here
fn routes() -> Router<AppState> {
    Router::new()
        .route("/", get(|_: State<AppState>| async {}))
}

// Instead do it before you run the server
let routes = routes().with_state(AppState {});

let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
axum::serve(listener, routes).await.unwrap();

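If you really need to set the State inside the function that returns the Router, declare the return type as plain Router without a type argument. Because the state parameter defaults to (), this is Router<()>, which can be passed straight to axum::serve():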

// Don't return `Router<AppState>`
fn routes(state: AppState) -> Router {
    Router::new()
        .route("/", get(|_: State<AppState>| async {}))
        .with_state(state)
}

let routes = routes(AppState {});

let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
axum::serve(listener, routes).await.unwrap();

This is because Router::into_make_service() (and likewise serving the router with axum::serve()) is only available on Router<()>, not on Router<AppState>.

The Router state type defaults to (), so Router is equivalent to Router<()>.

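If the returned Router will be passed to nest(), make the function generic over an unbounded state parameter S and let the compiler infer it (for example, when the outer router is passed to axum::serve(), S is inferred as ()):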

fn routes<S>(state: AppState) -> Router<S> {
    Router::new()
        .route("/", get(|_: State<AppState>| async {}))
        .with_state(state)
}

let routes = Router::new().nest("/api", routes(AppState {}));

let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
axum::serve(listener, routes).await.unwrap();

What does the S in Router<S> mean?

pub struct Router<S = ()> { /* private fields */ }

// Methods of Router<S>; S must satisfy Clone + Send + Sync + 'static
impl<S> Router<S> where S: Clone + Send + Sync + 'static
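  // with_state<S2> is a generic method: if S2 is not given explicitly, Rust infers it from context,
  // e.g. when the Router is later served with axum::serve() or into_make_service() is called, S2 is inferred as ();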
  pub fn with_state<S2>(self, state: S) -> Router<S2>

// Router is the same as Router<()> (S defaults to ())
impl Router
  pub fn into_make_service(self) -> IntoMakeService<Self>
  pub fn into_make_service_with_connect_info<C>(self) -> IntoMakeServiceWithConnectInfo<Self, C>

Router<S> means a router that is missing a state of type S to be able to handle requests. It does not mean a Router that has a state of type S .

For example:


// A router that _needs_ an `AppState` to handle requests
let router: Router<AppState> = Router::new()
    .route("/", get(|_: State<AppState>| async {}));

// Once we call `Router::with_state` the router isn't missing the state anymore, because we just
// provided it
//
// Therefore the router type becomes `Router<()>`, i.e a router that is not missing any state
let router: Router<()> = router.with_state(AppState {});

// Only `Router<()>` has the `into_make_service` method.
//
// You cannot call `into_make_service` on a `Router<AppState>` because it is still missing an
// `AppState`.
let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
axum::serve(listener, router).await.unwrap();

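Router::with_state() does not always return Router<()>; you choose the next missing state type. Router<()> is only required once you call into_make_service() (or hand the router to axum::serve()).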

let router: Router<AppState> = Router::new()
    .route("/", get(|_: State<AppState>| async {}));

// When we call `with_state` we're able to pick what the next missing state type is.
// Here we pick `String`.
let string_router: Router<String> = router.with_state(AppState {});

// That allows us to add new routes that uses `String` as the state type
let string_router = string_router
    .route("/needs-string", get(|_: State<String>| async {}));

// Provide the `String` and choose `()` as the new missing state.
let final_router: Router<()> = string_router.with_state("foo".to_owned());

// Since we have a `Router<()>` we can run it.
let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
axum::serve(listener, final_router).await.unwrap();
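The "missing state" reading of Router<S> can be modelled without axum. In the hypothetical sketch below a route is a boxed function that still needs a &S; with_state captures a clone of the state in every route and re-types the router as Router<S2> for any S2 the caller picks, and only Router<()> gets a call method. This mirrors the behaviour described above, not axum's actual internals.

```rust
// Hypothetical model: each route is a handler that still needs a &S.
struct Router<S> {
    routes: Vec<(String, Box<dyn Fn(&S) -> String>)>,
}

impl<S: Clone + 'static> Router<S> {
    fn new() -> Self {
        Router { routes: Vec::new() }
    }

    fn route(mut self, path: &str, handler: impl Fn(&S) -> String + 'static) -> Self {
        self.routes.push((path.to_string(), Box::new(handler)));
        self
    }

    // Supplying the missing S turns every handler into one that needs an
    // arbitrary, caller-chosen S2 (which it simply ignores), like with_state<S2>.
    fn with_state<S2: 'static>(self, state: S) -> Router<S2> {
        let routes = self
            .routes
            .into_iter()
            .map(|(path, handler)| {
                let state = state.clone();
                let wrapped: Box<dyn Fn(&S2) -> String> =
                    Box::new(move |_: &S2| handler(&state));
                (path, wrapped)
            })
            .collect();
        Router { routes }
    }
}

// Only a router that is missing nothing, Router<()>, can answer requests.
impl Router<()> {
    fn call(&self, path: &str) -> Option<String> {
        self.routes
            .iter()
            .find(|(p, _)| p == path)
            .map(|(_, handler)| handler(&()))
    }
}

#[derive(Clone)]
struct AppState {
    greeting: String,
}
```

Trying to call .call() on a Router<AppState> in this model fails to compile for the same reason axum::serve() rejects a Router<AppState>: the method simply does not exist for that type.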

The following example fails to compile because Router<AppState> does not provide Router::into_make_service():

// This wont work because we're returning a `Router<AppState>`
// i.e. we're saying we're still missing an `AppState`
fn routes(state: AppState) -> Router<AppState> {
    Router::new()
        .route("/", get(|_: State<AppState>| async {}))
        .with_state(state)
}

let app = routes(AppState {});

// We can only call `Router::into_make_service` on a `Router<()>`
// but `app` is a `Router<AppState>`
let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
axum::serve(listener, app).await.unwrap();

Fix: return Router<()>:

// We've provided all the state necessary so return `Router<()>`
fn routes(state: AppState) -> Router<()> {
    Router::new()
        .route("/", get(|_: State<AppState>| async {}))
        .with_state(state)
}

let app = routes(AppState {});

// We can now call `Router::into_make_service`
let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
axum::serve(listener, app).await.unwrap();

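Performance tip: if you need a Router that implements Service but requires no state, it is still recommended to call .with_state(()). This lets axum update some of its internal state and reduce memory allocations, which improves performance: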

use axum::{Router, routing::get};

let app = Router::new()
    .route("/", get(|| async { /* ... */ }))
    // even though we don't need any state, call `with_state(())` anyway
    .with_state(());

9.2 Request Extensions

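State added at the Router level is shared by every request handled by that Router. To produce request-specific state derived from the Request, such as authentication data generated by a middleware, use Extension instead.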

use axum::{
    Router,
    http::StatusCode,
    routing::get,
    response::{IntoResponse, Response},
    middleware::{self, Next},
    extract::{Request, Extension},
};
use std::sync::Arc;

#[derive(Clone)]
struct CurrentUser { /* ... */ }

struct AppState {
    // ...
}

let shared_state = Arc::new(AppState { /* ... */ });

async fn auth(mut req: Request, next: Next) -> Result<Response, StatusCode> {
    let auth_header = req.headers()
        .get(axum::http::header::AUTHORIZATION)
        .and_then(|header| header.to_str().ok());

    let auth_header = if let Some(auth_header) = auth_header {
        auth_header
    } else {
        return Err(StatusCode::UNAUTHORIZED);
    };

    if let Some(current_user) = authorize_current_user(auth_header).await {
        // Insert the current user into a request extension so that later Handlers can extract it with Extension
        req.extensions_mut().insert(current_user);
        Ok(next.run(req).await)
    } else {
        Err(StatusCode::UNAUTHORIZED)
    }
}

async fn authorize_current_user(auth_token: &str) -> Option<CurrentUser> {
    // ...
}

async fn handler(
    // extract the current user, set by the middleware
    Extension(current_user): Extension<CurrentUser>,
) {
    // ...
}

let app = Router::new()
    .route("/", get(handler))
    .route_layer(middleware::from_fn(auth))
    .layer(Extension(shared_state));

9.3 closure captures

State can also be captured directly by a Handler closure:

use axum::{
    Json,
    extract::{Extension, Path},
    routing::{get, post},
    Router,
};
use std::sync::Arc;
use serde::Deserialize;

struct AppState {
    // ...
}

let shared_state = Arc::new(AppState { /* ... */ });

let app = Router::new()
    .route(
        "/users",
        post({
            let shared_state = Arc::clone(&shared_state);
            move |body| create_user(body, shared_state)
        }),
    )
    .route(
        "/users/:id",
        get({
            let shared_state = Arc::clone(&shared_state);
            move |path| get_user(path, shared_state)
        }),
    );

async fn get_user(Path(user_id): Path<String>, state: Arc<AppState>) {
    // ...
}

async fn create_user(Json(payload): Json<CreateUserPayload>, state: Arc<AppState>) {
    // ...
}

#[derive(Deserialize)]
struct CreateUserPayload {
    // ...
}
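The clone-then-move pattern above is plain Rust, so it can be checked in isolation. In this hypothetical std-only sketch each route's handler is a boxed closure that owns its own Arc<AppState>; AppState, BoxedHandler and build_routes are illustrative names, and the handlers take an already extracted &str instead of axum extractors. The strong count shows one reference per captured clone plus the original.

```rust
use std::sync::Arc;

// Hypothetical shared state.
struct AppState {
    prefix: String,
}

// A boxed handler that receives the (already extracted) input as a &str.
type BoxedHandler = Box<dyn Fn(&str) -> String + Send + Sync>;

fn create_user(body: &str, state: &AppState) -> String {
    format!("{}created {}", state.prefix, body)
}

fn get_user(user_id: &str, state: &AppState) -> String {
    format!("{}user {}", state.prefix, user_id)
}

// Each route gets its own Arc clone, moved into its closure.
fn build_routes(shared_state: Arc<AppState>) -> Vec<(&'static str, BoxedHandler)> {
    vec![
        ("/users", {
            let shared_state = Arc::clone(&shared_state);
            Box::new(move |body: &str| create_user(body, &shared_state)) as BoxedHandler
        }),
        ("/users/:id", {
            let shared_state = Arc::clone(&shared_state);
            Box::new(move |id: &str| get_user(id, &shared_state)) as BoxedHandler
        }),
    ]
}
```

Cloning an Arc only bumps a reference count, so this is cheap; the downside compared with the State extractor is the per-route boilerplate.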

10 middleware

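Layer middleware can be added at three levels, Router, MethodRouter and Handler, with their layer() methods:

  • whole routers: Router::layer() and Router::route_layer();
  • a single method router: MethodRouter::layer() and MethodRouter::route_layer();
  • a single handler: Handler::layer().

The difference between route_layer() and layer(): route_layer() runs the middleware only when the request matches a route, while layer() runs it for every request, including those that fall through to the fallback.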
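A practical consequence, per the axum docs, is that early-returning middleware such as authorization belongs in route_layer(): with layer() it also runs for unknown paths and turns a 404 Not Found into a 401 Unauthorized. The std-only sketch below models only that decision; Placement, ROUTES and respond are hypothetical.

```rust
// Where a hypothetical "require a token" middleware is attached.
#[derive(Clone, Copy)]
enum Placement {
    Layer,      // like Router::layer(): runs for every request, fallback included
    RouteLayer, // like Router::route_layer(): runs for matched routes only
}

const ROUTES: [&str; 2] = ["/", "/users"];

// Returns the HTTP status code the client would see.
fn respond(path: &str, token: Option<&str>, placement: Placement) -> u16 {
    let matched = ROUTES.contains(&path);
    let auth_runs = match placement {
        Placement::Layer => true,
        Placement::RouteLayer => matched,
    };
    if auth_runs && token.is_none() {
        return 401; // the middleware short-circuits
    }
    if matched { 200 } else { 404 } // the handler, or the 404 fallback
}
```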

pub trait Layer<S> {
    type Service;

    // S is the inner Service being wrapped
    fn layer(&self, inner: S) -> Self::Service;
}

// Signature of Router::layer():
pub fn layer<L>(self, layer: L) -> Router<S>
where
    L: Layer<Route> + Clone + Send + 'static,
    L::Service: Service<Request> + Clone + Send + 'static,
    <L::Service as Service<Request>>::Response: IntoResponse + 'static,
    <L::Service as Service<Request>>::Error: Into<Infallible> + 'static,
    <L::Service as Service<Request>>::Future: Send + 'static

// where Route implements Service;
// Request is the struct http::request::Request
impl<B, E> Service<Request<B>> for Route<E>
where
// B implements the trait http_body::Body with Data = bytes::Bytes
    B: HttpBody<Data = Bytes> + Send + 'static,
    B::Error: Into<BoxError>
// Responses given by the service (struct http::response::Response)
type Response = Response<Body>
type Error = E
type Future = RouteFuture<E>

Predefined layers from the tower_http crate, such as tower_http::trace::TraceLayer, can be reused directly:

  • TraceLayer: tracing/logging;
  • CorsLayer: CORS handling;
  • CompressionLayer: automatic response compression;
  • SetRequestIdLayer and PropagateRequestIdLayer: setting and propagating request ids;
  • TimeoutLayer: timeouts.
// Example
use axum::{
    routing::get,
    Router,
};
use tower_http::validate_request::ValidateRequestHeaderLayer;

let app = Router::new().route(
    "/foo",
    get(|| async {})
        .route_layer(ValidateRequestHeaderLayer::bearer("password"))
);

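Layers added with Router::layer() are called in the reverse order of addition (the last one added runs first), and the handler runs last, so a layer can reject a request early, before it reaches the handler: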

use axum::{routing::get, Router};
async fn handler() {}
let app = Router::new()
    .route("/", get(handler))
    .layer(layer_one)
    .layer(layer_two)
    .layer(layer_three); // called first

// layers added with layer() are called in reverse order, and the handler is called last
        requests
           |
           v
+----- layer_three -----+
| +---- layer_two ----+ |
| | +-- layer_one --+ | |
| | |               | | |
| | |    handler    | | |
| | |               | | |
| | +-- layer_one --+ | |
| +---- layer_two ----+ |
+----- layer_three -----+
           |
           v
        responses

Middleware added with tower::ServiceBuilder::layer(), however, runs in the order it was added (top to bottom). Using ServiceBuilder to compose several layers into a single one is recommended:

use tower::ServiceBuilder;
use axum::{routing::get, Router};
async fn handler() {}
let app = Router::new()
    .route("/", get(handler))
    .layer(
        ServiceBuilder::new()
            .layer(layer_one)
            .layer(layer_two)
            .layer(layer_three),
    );
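Both orderings follow from how wrapping works, which a std-only model makes visible. In this sketch (not tower's API), a boxed closure stands in for a Service and the hypothetical `wrap` function stands in for `Layer::layer`: each wrap makes the new layer the outermost one. `router_style` wraps in call order, like repeated `Router::layer()` calls, while `service_builder_style` wraps in reverse so that the first layer listed is outermost, as with `ServiceBuilder`:

```rust
// Hypothetical model: a "service" takes the trace so far and returns the final trace.
type Service = Box<dyn Fn(Vec<String>) -> Vec<String>>;

// Model of Layer::layer: `name` sees the request before `inner` and the response after it.
fn wrap(name: &'static str, inner: Service) -> Service {
    Box::new(move |mut trace: Vec<String>| {
        trace.push(format!("{name} (request)"));
        let mut trace = inner(trace);
        trace.push(format!("{name} (response)"));
        trace
    })
}

// The innermost service.
fn handler() -> Service {
    Box::new(|mut trace: Vec<String>| {
        trace.push("handler".to_string());
        trace
    })
}

// Router::layer() style: each call wraps everything so far, so the LAST layer is outermost.
fn router_style(layers: &[&'static str]) -> Service {
    let mut svc = handler();
    for &name in layers {
        svc = wrap(name, svc);
    }
    svc
}

// ServiceBuilder style: the FIRST layer is outermost, i.e. wrapping in reverse order.
fn service_builder_style(layers: &[&'static str]) -> Service {
    let mut svc = handler();
    for &name in layers.iter().rev() {
        svc = wrap(name, svc);
    }
    svc
}

fn main() {
    let layers = ["layer_one", "layer_two", "layer_three"];
    println!("{:#?}", router_style(&layers)(Vec::new()));
    println!("{:#?}", service_builder_style(&layers)(Vec::new()));
}
```

With `["layer_one", "layer_two", "layer_three"]`, `router_style` records `layer_three (request)` first and `layer_three (response)` last, matching the diagram above, whereas `service_builder_style` starts with `layer_one (request)`.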

There are four ways to create a custom middleware (Layer):

  1. from a function, via axum::middleware::from_fn/from_fn_with_state;
  2. from an extractor, via axum::middleware::from_extractor;
  3. with tower's combinators, for example:
    1. ServiceBuilder::map_request
    2. ServiceBuilder::map_response
    3. ServiceBuilder::then
    4. ServiceBuilder::and_then
  4. by implementing tower::Service by hand, returning a Pin<Box<dyn Future>>.

axum::middleware::from_fn()/from_fn_with_state() create a Layer from a function f, which must:

  • be an async fn;
  • take zero or more FromRequestParts extractors;
  • take exactly one FromRequest extractor as its second-to-last argument;
  • take Next as its last argument;
  • return something that implements IntoResponse.
pub fn from_fn<F, T>(f: F) -> FromFnLayer<F, (), T>

use axum::{
    Router,
    http,
    routing::get,
    response::Response,
    middleware::{self, Next},
    extract::Request,
};

async fn my_middleware(request: Request, next: Next,) -> Response {
    // do something with `request`...
    let response = next.run(request).await;
    // do something with `response`...
    response
}

let app = Router::new()
    .route("/", get(|| async { /* ... */ }))
    .layer(middleware::from_fn(my_middleware));
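Conceptually, `from_fn` turns such a function into a layer in which `Next` stands for the rest of the stack. The synchronous std-only model below imitates that control flow; `Request`, `Response`, `Next` and `app` are hypothetical stand-ins, not axum's types. The middleware can change the request before `next.run()`, change the response after it, or return early without calling `next` at all:

```rust
// Hypothetical stand-ins for axum's Request, Response and Next (synchronous model).
#[derive(Debug)]
struct Request {
    path: String,
    headers: Vec<(String, String)>,
}

#[derive(Debug)]
struct Response {
    status: u16,
    headers: Vec<(String, String)>,
    body: String,
}

// `Next` represents "the rest of the stack": remaining middleware plus the handler.
struct Next<'a> {
    inner: &'a dyn Fn(Request) -> Response,
}

impl Next<'_> {
    fn run(self, request: Request) -> Response {
        (self.inner)(request)
    }
}

// The handler reports whether the header added by the middleware reached it.
fn handler(request: Request) -> Response {
    let seen = request.headers.iter().any(|(k, _)| k == "x-seen");
    Response {
        status: 200,
        headers: Vec::new(),
        body: format!("path={} seen={}", request.path, seen),
    }
}

// Shaped like a from_fn middleware: take the request and `next`, return a response.
fn my_middleware(mut request: Request, next: Next<'_>) -> Response {
    // short-circuit: /admin requests never reach the handler
    if request.path.starts_with("/admin") {
        return Response { status: 403, headers: Vec::new(), body: "forbidden".to_string() };
    }
    // do something with `request`...
    request.headers.push(("x-seen".to_string(), "1".to_string()));
    let mut response = next.run(request);
    // do something with `response`...
    response.headers.push(("x-middleware".to_string(), "my_middleware".to_string()));
    response
}

// Model of `.layer(middleware::from_fn(my_middleware))` in front of `handler`.
fn app(request: Request) -> Response {
    my_middleware(request, Next { inner: &handler })
}

fn main() {
    let response = app(Request { path: "/".to_string(), headers: Vec::new() });
    println!("{response:?}");
}
```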

If the middleware needs State, create it with axum::middleware::from_fn_with_state():

pub fn from_fn_with_state<F, S, T>(state: S, f: F) -> FromFnLayer<F, S, T>

use axum::{
    Router,
    http::StatusCode,
    routing::get,
    response::{IntoResponse, Response},
    middleware::{self, Next},
    extract::{Request, State},
};

#[derive(Clone)]
struct AppState { /* ... */ }

async fn my_middleware(
    State(state): State<AppState>,
    // you can add more extractors here but the last extractor must implement `FromRequest` which
    // `Request` does
    request: Request,
    next: Next,
) -> Response {
    // do something with `request`...
    let response = next.run(request).await;
    // do something with `response`...
    response
}

let state = AppState { /* ... */ };

let app = Router::new()
    .route("/", get(|| async { /* ... */ }))
    .route_layer(middleware::from_fn_with_state(state.clone(), my_middleware))
    .with_state(state);

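axum::middleware::from_extractor()/from_extractor_with_state() create a middleware from an extractor type:

  • if the extractor succeeds, its value is discarded and the request continues to the inner service; otherwise the rejection is returned as the response. This is typically used to validate requests by reusing an existing extractor type;
  • if the extractor consumes the body, the route's service that runs afterwards sees an empty body (in recent axum versions, 0.6 and later, from_extractor requires a FromRequestParts extractor, which cannot read the body).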
pub fn from_extractor<E>() -> FromExtractorLayer<E, ()>

use axum::{
    extract::FromRequestParts,
    middleware::from_extractor,
    routing::{get, post},
    Router,
    http::{header, StatusCode, request::Parts},
};
use async_trait::async_trait;

// An extractor that performs authorization.
struct RequireAuth;

#[async_trait]
impl<S> FromRequestParts<S> for RequireAuth
where
    S: Send + Sync,
{
    type Rejection = StatusCode;

    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
        let auth_header = parts
            .headers
            .get(header::AUTHORIZATION)
            .and_then(|value| value.to_str().ok());

        match auth_header {
            Some(auth_header) if token_is_valid(auth_header) => {
                Ok(Self)
            }
            _ => Err(StatusCode::UNAUTHORIZED),
        }
    }
}

fn token_is_valid(_token: &str) -> bool {
    todo!("validate the token")
}
async fn handler() {}
async fn other_handler() {}

let app = Router::new()
    .route("/", get(handler))
    .route("/foo", post(other_handler))
    // The extractor will run before all routes
    .route_layer(from_extractor::<RequireAuth>());
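The behaviour of `from_extractor` (run the extractor, discard the value on success, return the rejection on failure) can be modeled without axum. In this synchronous sketch, `Parts`, the `FromParts` trait and the `u16` status codes are hypothetical stand-ins for `http::request::Parts`, `FromRequestParts` and `StatusCode`, and `token_is_valid` is a placeholder that accepts a single hard-coded token:

```rust
use std::collections::HashMap;

// Hypothetical stand-in for http::request::Parts.
struct Parts {
    headers: HashMap<String, String>,
}

// Hypothetical, synchronous stand-in for axum's FromRequestParts.
trait FromParts: Sized {
    type Rejection;
    fn from_parts(parts: &Parts) -> Result<Self, Self::Rejection>;
}

// Placeholder check for illustration only.
fn token_is_valid(token: &str) -> bool {
    token == "Bearer secret"
}

// An extractor that performs authorization (mirrors RequireAuth above).
struct RequireAuth;

impl FromParts for RequireAuth {
    type Rejection = u16; // stands in for StatusCode
    fn from_parts(parts: &Parts) -> Result<Self, Self::Rejection> {
        match parts.headers.get("authorization") {
            Some(value) if token_is_valid(value) => Ok(RequireAuth),
            _ => Err(401),
        }
    }
}

// Model of from_extractor::<E>(): run E; on success drop the value and call the
// inner handler, on failure the rejection is the response.
fn from_extractor<E, H>(parts: &Parts, inner: H) -> u16
where
    E: FromParts<Rejection = u16>,
    H: Fn(&Parts) -> u16,
{
    match E::from_parts(parts) {
        Ok(_discarded) => inner(parts),
        Err(rejection) => rejection,
    }
}

fn main() {
    let mut headers = HashMap::new();
    headers.insert("authorization".to_string(), "Bearer secret".to_string());
    let authorized = Parts { headers };
    let anonymous = Parts { headers: HashMap::new() };

    println!("{}", from_extractor::<RequireAuth, _>(&authorized, |_| 200));
    println!("{}", from_extractor::<RequireAuth, _>(&anonymous, |_| 200));
}
```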

axum::middleware also provides helper functions for creating a Layer; they are typically used for simple request or response transformations:

  • map_request()
  • map_request_with_state()
  • map_response()
  • map_response_with_state()

map_request() example:

use axum::{
    Router,
    routing::get,
    middleware::map_request,
    http::Request,
};

async fn set_header<B>(mut request: Request<B>) -> Request<B> {
    request.headers_mut().insert("x-foo", "foo".parse().unwrap());
    request
}

async fn handler<B>(request: Request<B>) {
    // `request` will have an `x-foo` header
}

let app = Router::new()
    .route("/", get(handler))
    .layer(map_request(set_header));

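The function passed to map_request() may return either of the following; when it returns Err, the request is rejected early and the error is converted into the response:

  • Request<B>
  • Result<Request<B>, E> where E: IntoResponse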
use axum::{
    Router,
    http::{self, Request, StatusCode},
    routing::get,
    middleware::map_request,
};

async fn auth<B>(request: Request<B>) -> Result<Request<B>, StatusCode> {
    let auth_header = request.headers()
        .get(http::header::AUTHORIZATION)
        .and_then(|header| header.to_str().ok());

    match auth_header {
        Some(auth_header) if token_is_valid(auth_header) => Ok(request),
        _ => Err(StatusCode::UNAUTHORIZED),
    }
}

fn token_is_valid(token: &str) -> bool {
    // ...
}

let app = Router::new()
    .route("/", get(|| async { /* ... */ }))
    .route_layer(map_request(auth));
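The early-rejection path can be modeled in plain Rust. In this sketch, the `Request` struct, the `u16` status codes, the placeholder `token_is_valid` and the `call_with_map_request` driver are all hypothetical: the mapper runs first, `Ok(request)` is forwarded to the handler, and `Err(status)` becomes the response without the handler running:

```rust
// Hypothetical request type: only the Authorization header value matters here.
struct Request {
    authorization: Option<String>,
}

// Placeholder check for illustration only.
fn token_is_valid(token: &str) -> bool {
    token == "Bearer secret"
}

// Shaped like the `auth` function above: Ok(request) continues, Err(status) rejects.
fn auth(request: Request) -> Result<Request, u16> {
    let auth_header = request.authorization.as_deref();
    match auth_header {
        Some(header) if token_is_valid(header) => Ok(request),
        _ => Err(401), // stands in for StatusCode::UNAUTHORIZED
    }
}

// The handler only runs for requests the mapper let through.
fn handler(_request: Request) -> (u16, String) {
    (200, "welcome".to_string())
}

// Model of `.route_layer(map_request(auth))` in front of a handler.
fn call_with_map_request(
    request: Request,
    mapper: impl Fn(Request) -> Result<Request, u16>,
    handler: impl Fn(Request) -> (u16, String),
) -> (u16, String) {
    match mapper(request) {
        Ok(request) => handler(request),
        // the error becomes the response; the handler is never called
        Err(status) => (status, String::new()),
    }
}

fn main() {
    let ok = Request { authorization: Some("Bearer secret".to_string()) };
    let denied = Request { authorization: None };
    println!("{:?}", call_with_map_request(ok, auth, handler));
    println!("{:?}", call_with_map_request(denied, auth, handler));
}
```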

Besides the Request, the map_request() function can also take other extractors as arguments:

use axum::{
    Router,
    routing::get,
    middleware::map_request,
    extract::Path,
    http::Request,
};
use std::collections::HashMap;

async fn log_path_params<B>(
    Path(path_params): Path<HashMap<String, String>>,
    request: Request<B>,
) -> Request<B> {
    tracing::debug!(?path_params);
    request
}

let app = Router::new()
    .route("/", get(|| async { /* ... */ }))
    .layer(map_request(log_path_params));

map_request_with_state() works the same way but also passes the State to the function:
use axum::{
    Router,
    http::{Request, StatusCode},
    routing::get,
    response::IntoResponse,
    middleware::map_request_with_state,
    extract::State,
};

#[derive(Clone)]
struct AppState { /* ... */ }

async fn my_middleware<B>(
    State(state): State<AppState>,
    // you can add more extractors here but the last extractor must implement `FromRequest` which
    // `Request` does
    request: Request<B>,
) -> Request<B> {
    // do something with `state` and `request`...
    request
}

let state = AppState { /* ... */ };

let app = Router::new()
    .route("/", get(|| async { /* ... */ }))
    .route_layer(map_request_with_state(state.clone(), my_middleware))
    .with_state(state);

map_response() example:

use axum::{
    Router,
    routing::get,
    middleware::map_response,
    response::Response,
};

async fn set_header<B>(mut response: Response<B>) -> Response<B> {
    response.headers_mut().insert("x-foo", "foo".parse().unwrap());
    response
}

let app = Router::new()
    .route("/", get(|| async { /* ... */ }))
    .layer(map_response(set_header));

// As with map_request(), the async function can also take extractors
use axum::{
    Router,
    routing::get,
    middleware::map_response,
    extract::Path,
    response::Response,
};
use std::collections::HashMap;

async fn log_path_params<B>(
    Path(path_params): Path<HashMap<String, String>>,
    response: Response<B>,
) -> Response<B> {
    tracing::debug!(?path_params);
    response
}

let app = Router::new()
    .route("/", get(|| async { /* ... */ }))
    .layer(map_response(log_path_params));

// The function can return anything that implements IntoResponse
use axum::{
    Router,
    routing::get,
    middleware::map_response,
    response::{Response, IntoResponse},
};

async fn set_header(response: Response) -> impl IntoResponse {
    (
        [("x-foo", "foo")],
        response,
    )
}

let app = Router::new()
    .route("/", get(|| async { /* ... */ }))
    .layer(map_response(set_header));

Implementing Layer and Service by hand with tower::Service and Pin<Box<dyn Future>> gives the most flexibility:

use axum::{
    response::Response,
    body::Body,
    extract::Request,
};
use futures_util::future::BoxFuture;
use tower::{Service, Layer};
use std::task::{Context, Poll};

// When defining a Layer you usually also define the Service it produces
#[derive(Clone)]
struct MyLayer;

impl<S> Layer<S> for MyLayer {
    type Service = MyMiddleware<S>; // the associated Service type

    fn layer(&self, inner: S) -> Self::Service {
        MyMiddleware { inner }
    }
}

#[derive(Clone)]
struct MyMiddleware<S> {
    inner: S,
}

// The Service's request must be http::request::Request and its response
// http::response::Response for it to satisfy the bounds of Router::layer().
impl<S> Service<Request> for MyMiddleware<S>
where
    S: Service<Request, Response = Response> + Send + 'static,
    S::Future: Send + 'static,
{
    type Response = S::Response;
    type Error = S::Error;

    // `BoxFuture` is a type alias for `Pin<Box<dyn Future + Send + 'a>>`
    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    fn call(&mut self, request: Request) -> Self::Future {
        let future = self.inner.call(request);
        Box::pin(async move {
            let response: Response = future.await?;
            Ok(response)
        })
    }
}
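Without `poll_ready` and futures, the pattern above reduces to two small traits. The synchronous std-only sketch below re-defines simplified `Service` and `Layer` traits (they are not tower's) and a hypothetical `LogLayer` that records request paths; as with `MyLayer`, the layer only builds a wrapper around `inner`, and the wrapper's `call` delegates to `inner.call`:

```rust
// Simplified, synchronous re-definitions of tower's traits (no poll_ready, no futures).
trait Service<Request> {
    type Response;
    type Error;
    fn call(&mut self, req: Request) -> Result<Self::Response, Self::Error>;
}

trait Layer<S> {
    type Service;
    fn layer(&self, inner: S) -> Self::Service;
}

// Hypothetical layer that records every request path it sees.
#[derive(Clone)]
struct LogLayer;

impl<S> Layer<S> for LogLayer {
    type Service = LogMiddleware<S>;
    fn layer(&self, inner: S) -> Self::Service {
        LogMiddleware { inner, seen: Vec::new() }
    }
}

struct LogMiddleware<S> {
    inner: S,
    seen: Vec<String>,
}

impl<S> Service<String> for LogMiddleware<S>
where
    S: Service<String>,
{
    type Response = S::Response;
    type Error = S::Error;

    fn call(&mut self, req: String) -> Result<Self::Response, Self::Error> {
        self.seen.push(req.clone()); // before the inner service
        let response = self.inner.call(req)?; // delegate
        Ok(response) // after: the response could be inspected or modified here
    }
}

// The innermost service: echoes the request path and never fails.
struct Echo;

impl Service<String> for Echo {
    type Response = String;
    type Error = std::convert::Infallible;

    fn call(&mut self, req: String) -> Result<String, Self::Error> {
        Ok(format!("echo {req}"))
    }
}

fn main() {
    let mut svc = LogLayer.layer(Echo);
    let res = svc.call("/a".to_string());
    println!("{res:?} {:?}", svc.seen);
}
```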

11 error_handling

For Result<T, E>, if both T and E implement IntoResponse, the Result implements IntoResponse as well:

// Result implements IntoResponse when both the Ok and the Err type implement IntoResponse
impl<T, E> IntoResponse for Result<T, E>
where
    T: IntoResponse,
    E: IntoResponse,

// struct axum::response::ErrorResponse implements From<T> for every T: IntoResponse,
// so any IntoResponse value can be converted into it (for example with `?`)
impl<T> IntoResponse for Result<T, ErrorResponse>
where
    T: IntoResponse,

// Another implementor from the same list, for bytes::buf::Chain (unrelated to errors):
impl<T, U> IntoResponse for Chain<T, U>
where
    T: Buf + Unpin + Send + 'static,
    U: Buf + Unpin + Send + 'static,

For example, axum::http::StatusCode implements IntoResponse, so it can be used as the error type of a handler's Result:

use axum::http::StatusCode;

async fn handler() -> Result<String, StatusCode> {
    // ...
}
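The blanket impl can be mimicked with a tiny local trait to show why either branch of the `Result` becomes a response. This is a simplified model, not axum's definitions: `Response` is just a status plus a body string, and `Status` stands in for `StatusCode`:

```rust
// Simplified model of a response.
#[derive(Debug, PartialEq)]
struct Response {
    status: u16,
    body: String,
}

// Simplified re-definition of axum's IntoResponse.
trait IntoResponse {
    fn into_response(self) -> Response;
}

// A String becomes a 200 response with the string as its body.
impl IntoResponse for String {
    fn into_response(self) -> Response {
        Response { status: 200, body: self }
    }
}

// Stand-in for StatusCode: a bare status with an empty body.
struct Status(u16);

impl IntoResponse for Status {
    fn into_response(self) -> Response {
        Response { status: self.0, body: String::new() }
    }
}

// Mirrors `impl<T, E> IntoResponse for Result<T, E> where T: IntoResponse, E: IntoResponse`.
impl<T: IntoResponse, E: IntoResponse> IntoResponse for Result<T, E> {
    fn into_response(self) -> Response {
        match self {
            Ok(value) => value.into_response(),
            Err(err) => err.into_response(),
        }
    }
}

// A handler in the style of `async fn handler() -> Result<String, StatusCode>`.
fn handler(found: bool) -> Result<String, Status> {
    if found {
        Ok("hello".to_string())
    } else {
        Err(Status(404))
    }
}

fn main() {
    println!("{:?}", handler(true).into_response());
    println!("{:?}", handler(false).into_response());
}
```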

To return an error message, define your own error type and implement IntoResponse for it:

// https://github.com/tokio-rs/axum/blob/main/examples/anyhow-error-response/src/main.rs

//! Run with
//!
//! ```not_rust
//! cargo run -p example-anyhow-error-response
//! ```

use axum::{
    http::StatusCode,
    response::{IntoResponse, Response},
    routing::get,
    Router,
};

#[tokio::main]
async fn main() {
    let app = Router::new().route("/", get(handler));
    let listener = tokio::net::TcpListener::bind("127.0.0.1:3000").await.unwrap();
    println!("listening on {}", listener.local_addr().unwrap());
    axum::serve(listener, app).await.unwrap();
}


// The Result returned by the handler is converted into a Response via IntoResponse and sent to the client
async fn handler() -> Result<(), AppError> {
    try_thing()?;
    Ok(())
}

fn try_thing() -> Result<(), anyhow::Error> {
    anyhow::bail!("it failed!")
}

// Make our own error that wraps `anyhow::Error`.
struct AppError(anyhow::Error);

// Tell axum how to convert `AppError` into a response.
impl IntoResponse for AppError {
    fn into_response(self) -> Response {
        // axum implements IntoResponse for (StatusCode, String)
        (
            StatusCode::INTERNAL_SERVER_ERROR,
            format!("Something went wrong: {}", self.0),
        )
            .into_response()
    }
}

// This enables using `?` on functions that return `Result<_, anyhow::Error>` to turn them into
// `Result<_, AppError>`. That way you don't need to do that manually.
impl<E> From<E> for AppError
where
    E: Into<anyhow::Error>,
{
    fn from(err: E) -> Self {
        Self(err.into())
    }
}
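The same wrapper works without anyhow. In this std-only sketch, the error is boxed as `Box<dyn Error + Send + Sync>` instead of `anyhow::Error`, and the response is modeled as a `(u16, String)` tuple; both are assumptions for illustration. The blanket `From` impl is what allows `?` to convert a `ParseIntError` into `AppError`:

```rust
use std::error::Error;

// Wraps any standard error; plays the role of `AppError(anyhow::Error)`.
#[derive(Debug)]
struct AppError(Box<dyn Error + Send + Sync>);

// Any `E: Error` converts into AppError, so `?` works in handlers. This does not
// overlap with `impl<T> From<T> for T` because AppError itself does not implement Error.
impl<E> From<E> for AppError
where
    E: Error + Send + Sync + 'static,
{
    fn from(err: E) -> Self {
        AppError(Box::new(err))
    }
}

impl AppError {
    // Stand-in for `IntoResponse::into_response`: always a 500 with the message.
    fn into_response(self) -> (u16, String) {
        (500, format!("Something went wrong: {}", self.0))
    }
}

// A failing parse yields std::num::ParseIntError, which `?` turns into AppError.
fn parse_id(raw: &str) -> Result<u32, AppError> {
    let id: u32 = raw.parse()?;
    Ok(id)
}

// Models the framework turning Result<T, AppError> into a response.
fn respond(raw: &str) -> (u16, String) {
    match parse_id(raw) {
        Ok(id) => (200, format!("id = {id}")),
        Err(err) => err.into_response(),
    }
}

fn main() {
    println!("{:?}", respond("42"));
    println!("{:?}", respond("abc"));
}
```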

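When axum accepts a Service, it requires the Service's associated Error type to be Infallible, i.e. the service can never fail. For example, Router::route_service() uses a Service as the request handler, with the bound Service<Request, Error = Infallible> + Clone + Send + 'static; Error = Infallible means the service cannot return any actual error, so failures must already have been turned into responses.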

// For example, the bound on the service parameter of Router::route_service():
pub fn route_service<T>(self, path: &str, service: T) -> Self
where
    // Error is Infallible, i.e. the service can never return an error
    T: Service<Request, Error = Infallible> + Clone + Send + 'static,
    // Response must implement IntoResponse
    T::Response: IntoResponse,
    T::Future: Send + 'static,


// Example:
use axum::{
    Router,
    body::Body,
    routing::any_service,
    extract::Request,
    http::Response,
};
use tower_http::services::ServeFile;
use std::convert::Infallible;
use tower::service_fn;

let app = Router::new()
    .route(
        // Any request to `/` goes to a service
        "/",
        // Services whose response body is not `axum::body::BoxBody` can be wrapped in
        // `axum::routing::any_service` (or one of the other routing filters) to have the response
        // body mapped

        // service_fn() creates a Service from a closure that takes a Request and returns a Result.
        // To satisfy any_service, the Result's Error type must be Infallible.
        any_service(service_fn(|_: Request| async {
            let res = Response::new(Body::from("Hi from `GET /`"));
            Ok::<_, Infallible>(res)
        }))
    )
    .route_service(
        "/foo",
        // This service's response body is `axum::body::BoxBody` so it can be routed to directly.
        service_fn(|req: Request| async move {
            let body = Body::from(format!("Hi from `{} /foo`", req.method()));
            let res = Response::new(body);
            Ok::<_, Infallible>(res)
        })
    )
    .route_service(
        // GET `/static/Cargo.toml` goes to a service from tower-http
        "/static/Cargo.toml",
        ServeFile::new("Cargo.toml"),
    );

A Service created with tower::service_fn(f), by contrast, only requires f: FnMut(Request) -> Future<Output = Result<R, E>>, so it may have a concrete error type E, which does not match Infallible.

To use such a Service (one built with tower::service_fn(f) whose future resolves to Result<R, E>) with Router::route_service(), axum provides axum::error_handling::HandleError, which turns the errors into responses.

HandleError turns a Service that can return Err into a Service<Request, Error = Infallible>: you pass it a function that converts the error into something implementing IntoResponse, which eliminates the Err case.

Example:

use axum::{
    Router,
    body::Body,
    http::{Request, Response, StatusCode},
    error_handling::HandleError,
};

async fn thing_that_might_fail() -> Result<(), anyhow::Error> {
    // ...
}

// service_fn turns the async closure into a ServiceFn, which implements Service; ServiceFn's
// Service impl requires FnMut(Request) -> Future<Output = Result<R, E>>, so the closure may
// return a Result containing an error
let some_fallible_service = tower::service_fn(|_req| async {
    thing_that_might_fail().await?;
    Ok::<_, anyhow::Error>(Response::new(Body::empty()))
});

// route_service() requires Service<Request, Error = Infallible> + Clone + Send + 'static, but
// the service returned by tower::service_fn() has Error = anyhow::Error, which does not match,
// so HandleError::new() is needed to convert it.
let app = Router::new().route_service(
    "/",
    HandleError::new(some_fallible_service, handle_anyhow_error),
);

// Converts the error into a type that implements IntoResponse
async fn handle_anyhow_error(err: anyhow::Error) -> (StatusCode, String) {
    (
        StatusCode::INTERNAL_SERVER_ERROR,
        format!("Something went wrong: {err}"),
    )
}
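What `HandleError` does can be sketched synchronously in std-only Rust. `handle_error`, `fallible_service`, `handle_string_error` and the `(u16, String)` response type are hypothetical; the real `HandleError` is an async tower `Service`. The wrapper maps `Err(e)` through the error handler, so the resulting service's error type is `Infallible`:

```rust
use std::convert::Infallible;

type Response = (u16, String);

// Wraps a fallible service `svc` and an error handler `on_err` into a service
// that can never fail: its error type is Infallible.
fn handle_error<Req, E, S, F>(svc: S, on_err: F) -> impl Fn(Req) -> Result<Response, Infallible>
where
    S: Fn(Req) -> Result<Response, E>,
    F: Fn(E) -> Response,
{
    move |req: Req| match svc(req) {
        Ok(res) => Ok(res),
        Err(err) => Ok(on_err(err)), // the error is turned into a normal response
    }
}

// A hypothetical fallible service: fails for empty requests.
fn fallible_service(req: &str) -> Result<Response, String> {
    if req.is_empty() {
        Err("empty request".to_string())
    } else {
        Ok((200, format!("got {req}")))
    }
}

// Plays the role of `handle_anyhow_error`.
fn handle_string_error(err: String) -> Response {
    (500, format!("Something went wrong: {err}"))
}

fn main() {
    let svc = handle_error::<&str, _, _, _>(fallible_service, handle_string_error);
    println!("{:?}", svc("hi"));
    println!("{:?}", svc(""));
}
```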

Middleware layers have the same issue as Services, and axum::error_handling::HandleErrorLayer converts middleware errors. HandleErrorLayer implements Layer<S>; the service it produces is HandleError, which turns the wrapped, possibly failing Service into the Service<Request, Error = Infallible> that axum requires.

  • new(f) takes a function FnOnce($($ty),*, S::Error) -> Fut + Clone + Send + 'static: its arguments are zero or more extractors followed by the error as the last argument, and it returns (a future resolving to) an IntoResponse value.
pub struct HandleErrorLayer<F, T> { /* private fields */ }

impl<F, T> HandleErrorLayer<F, T> {
    // f takes the error (optionally preceded by extractors) and returns an IntoResponse value
    pub fn new(f: F) -> Self { /* ... */ }
}

impl<S, F, T> Layer<S> for HandleErrorLayer<F, T>
where
    F: Clone,
{
    type Service = HandleError<S, F, T>;

    fn layer(&self, inner: S) -> Self::Service {
        // self.f is passed to HandleError::new(), so F must also satisfy the bounds of
        // HandleError's Service impl
        HandleError::new(inner, self.f.clone())
    }
}

impl<S, F, B, Fut, Res> Service<Request<B>> for HandleError<S, F, ()>
where
    S: Service<Request<B>> + Clone + Send + 'static,
    S::Response: IntoResponse + Send,
    S::Error: Send,
    S::Future: Send,
    F: FnOnce(S::Error) -> Fut + Clone + Send + 'static,
    Fut: Future<Output = Res> + Send,
    Res: IntoResponse,
    B: Send + 'static,
{
    //...
}

Example:

  • axum::BoxError is defined as pub type BoxError = Box<dyn Error + Send + Sync>;
use axum::{
    Router,
    BoxError,
    routing::get,
    http::StatusCode,
    error_handling::HandleErrorLayer,
};
use std::time::Duration;
use tower::ServiceBuilder;

let app = Router::new()
    .route("/", get(|| async {}))
    .layer(
        ServiceBuilder::new()
            // `timeout` will produce an error if the handler takes too long, so we must handle
            // those errors: the function passed to new() takes the Error and returns an IntoResponse
            .layer(HandleErrorLayer::new(handle_timeout_error)) // converts Err into a response
            // a layer that may return Err; .timeout(d) is shorthand for
            // .layer(tower::timeout::TimeoutLayer::new(d))
            .timeout(Duration::from_secs(30))
    );

// The function takes the Error and returns an IntoResponse value
async fn handle_timeout_error(err: BoxError) -> (StatusCode, String) {
    if err.is::<tower::timeout::error::Elapsed>() {
        (
            StatusCode::REQUEST_TIMEOUT,
            "Request took too long".to_string(),
        )
    } else {
        (
            StatusCode::INTERNAL_SERVER_ERROR,
            format!("Unhandled internal error: {err}"),
        )
    }
}
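`err.is::<tower::timeout::error::Elapsed>()` works because `BoxError` is a boxed `dyn Error` trait object, which can be checked for its concrete type. A std-only sketch of the same classification, with a hypothetical local `Elapsed` type standing in for tower's and a `(u16, String)` tuple standing in for the response:

```rust
use std::error::Error;
use std::fmt;

// Same definition as axum::BoxError.
type BoxError = Box<dyn Error + Send + Sync>;

// Hypothetical stand-in for tower::timeout::error::Elapsed.
#[derive(Debug)]
struct Elapsed;

impl fmt::Display for Elapsed {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "request timed out")
    }
}

impl Error for Elapsed {}

// Mirrors handle_timeout_error: 408 for timeouts, 500 for anything else.
fn handle_timeout_error(err: BoxError) -> (u16, String) {
    if err.is::<Elapsed>() {
        (408, "Request took too long".to_string())
    } else {
        (500, format!("Unhandled internal error: {err}"))
    }
}

fn main() {
    println!("{:?}", handle_timeout_error(Box::new(Elapsed)));
    let other: BoxError = "db down".into();
    println!("{:?}", handle_timeout_error(other));
}
```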

The function passed to HandleErrorLayer can also take extractors:

use axum::{
    Router,
    BoxError,
    routing::get,
    http::{StatusCode, Method, Uri},
    error_handling::HandleErrorLayer,
};
use std::time::Duration;
use tower::ServiceBuilder;

let app = Router::new()
    .route("/", get(|| async {}))
    .layer(
        ServiceBuilder::new()
            // `timeout` will produce an error if the handler takes too long so we must handle those
            .layer(HandleErrorLayer::new(handle_timeout_error))
            .timeout(Duration::from_secs(30))
    );

async fn handle_timeout_error( // apart from the last argument (BoxError), the arguments are extractors
    method: Method,
    uri: Uri,
    err: BoxError,
) -> (StatusCode, String) {
    (
        StatusCode::INTERNAL_SERVER_ERROR,
        format!("`{method} {uri}` failed with {err}"),
    )
}
