
axum

rust crate - this article is part of a series.
§ 10: this article

axum is an asynchronous HTTP server library built on hyper. Basic usage:

  1. Create a Router: Router::route() defines a path and the Service associated with it. That Service is usually a MethodRouter: functions such as get()/post()/patch() return a MethodRouter and take a Handler as their argument.

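  2. Implement the Handler, usually as an async function or closure:

    1. Inputs: extractors, which pull information out of the request;
    2. Output: a value implementing the IntoResponse trait. It does not have to be a Result; a Result<T, E> is accepted only because axum implements IntoResponse for it when both T and E implement IntoResponse. axum implements the trait for many basic Rust types;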

Middleware and state can be attached at all three levels (Router, MethodRouter and Handler):

  1. layer() adds middleware that runs before the Handler is called;
  2. with_state() attaches a state object;
use axum::{routing::get, Router};

// async functions that implement the Handler trait
async fn root() {}
async fn get_foo() {}
async fn post_foo() {}
async fn foo_bar() {}

#[tokio::main]
async fn main() {
    let app = Router::new()
        .route("/", get(root))
        // get() handles GET requests and returns a MethodRouter (which implements Service); calls can be chained.
        .route("/foo", get(get_foo).post(post_foo))
        .route("/foo/bar", get(foo_bar));

    let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
    axum::serve(listener, app).await.unwrap();
}
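The Router → MethodRouter → Handler hierarchy described above can be sketched with plain std types. This is a hypothetical, synchronous toy model (Handler, MethodRouter, Router and call here are illustrative names, not axum's API): a Router maps paths to MethodRouters, a MethodRouter maps HTTP methods to handlers, and a miss at the path level yields 404 while a miss at the method level yields 405.

```rust
use std::collections::HashMap;

// Hypothetical stand-in for a Handler: a plain function producing a body.
type Handler = fn() -> String;

// Toy MethodRouter: HTTP method -> handler.
#[derive(Default)]
struct MethodRouter {
    handlers: HashMap<&'static str, Handler>,
}

impl MethodRouter {
    // Takes and returns `self`, so calls chain like `get(a).post(b)`.
    fn on(mut self, method: &'static str, handler: Handler) -> Self {
        self.handlers.insert(method, handler);
        self
    }
}

// Toy Router: path -> MethodRouter.
#[derive(Default)]
struct Router {
    routes: HashMap<&'static str, MethodRouter>,
}

impl Router {
    fn route(mut self, path: &'static str, method_router: MethodRouter) -> Self {
        self.routes.insert(path, method_router);
        self
    }

    // Dispatch: path first (404 on a miss), then method (405 on a miss).
    fn call(&self, method: &str, path: &str) -> (u16, String) {
        match self.routes.get(path) {
            None => (404, String::new()),
            Some(method_router) => match method_router.handlers.get(method) {
                Some(handler) => (200, handler()),
                None => (405, String::new()),
            },
        }
    }
}

fn get_foo() -> String {
    "GET /foo".to_string()
}

fn post_foo() -> String {
    "POST /foo".to_string()
}
```

Used as `Router::default().route("/foo", MethodRouter::default().on("GET", get_foo).on("POST", post_foo))`, a `DELETE /foo` gets 405 and `GET /bar` gets 404.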

axum::serve

axum::serve() is the entry point of axum. Its second argument, make_service, is a Service factory (a Service of Services).

The inner Service takes an http::Request<axum::body::Body> (aliased as axum::extract::Request) and returns an http::Response<axum::body::Body>; that is, it produces an HTTP Response from an HTTP Request.

pub fn serve<M, S>(tcp_listener: TcpListener, make_service: M) -> Serve<M, S>
where
    // Outer Service:
    // its Request is IncomingStream (see below) and its Response is another Service
    M: for<'a> Service<IncomingStream<'a>, Error = Infallible, Response = S>,
    // Inner Service:
    //   Request is axum::extract::Request, i.e. http::Request<axum::body::Body>
    //   Response is axum::response::Response, i.e. http::Response<axum::body::Body>
    S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
    S::Future: Send,

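axum's Router and MethodRouter, as well as a Handler once converted with into_make_service() or with_state(), satisfy the bound on M, i.e. they act as Service factories, so their values can be passed to serve().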

Router 为例,它实现了两层 Service trait,满足 M 的限界要求:

// Router implements Service<IncomingStream<'_>>; its Response is again a Router
impl Service<IncomingStream<'_>> for Router<()>
  type Response = Router // the response is a Router
  type Error = Infallible
  type Future = Ready<Result<<Router as Service<IncomingStream<'_>>>::Response, <Router as Service<IncomingStream<'_>>>::Error>>

// Router also implements Service<Request<B>>, which satisfies the S in M's Response = S bound:
// Request is http::Request<B> with B: http_body::Body<Data = bytes::Bytes>
// Response is http::Response<axum::body::Body>
impl<B> Service<Request<B>> for Router<()>
where
    B: HttpBody<Data = Bytes> + Send + 'static,
    B::Error: Into<BoxError>
  type Response = Response<Body> // Body is axum::body::Body
  type Error = Infallible
  type Future = RouteFuture<Infallible>
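The two Service impls above form a "service of services". The sketch below is a hypothetical, synchronous simplification (the Service trait, IncomingConn, App and serve here are illustrative stand-ins, not tower's or axum's definitions): the outer impl is called once per accepted connection and hands back a clone of the app, and the inner impl is then called once per request on that connection.

```rust
// Hypothetical synchronous stand-in for tower's Service (no poll_ready, no futures).
trait Service<Req> {
    type Response;
    fn call(&mut self, req: Req) -> Self::Response;
}

// Stand-in for IncomingStream: information about one accepted connection.
struct IncomingConn {
    remote_addr: String,
}

#[derive(Clone)]
struct App {
    greeting: String,
}

// Inner service: one call per request; here a request is just a path.
impl Service<&str> for App {
    type Response = String;
    fn call(&mut self, path: &str) -> String {
        format!("{} {}", self.greeting, path)
    }
}

// Outer service: one call per connection; like Router, it hands back a clone of itself.
impl Service<&IncomingConn> for App {
    type Response = App;
    fn call(&mut self, _conn: &IncomingConn) -> App {
        self.clone()
    }
}

// Hypothetical driver showing the call pattern of axum::serve.
fn serve<M>(make_service: &mut M, conns: &[(IncomingConn, Vec<&str>)]) -> Vec<String>
where
    M: for<'a> Service<&'a IncomingConn, Response = App>,
{
    let mut responses = Vec::new();
    for (conn, requests) in conns {
        // Outer call: build the per-connection service.
        let mut svc = make_service.call(conn);
        for &path in requests {
            // Inner call: handle one request.
            responses.push(<App as Service<&str>>::call(&mut svc, path));
        }
    }
    responses
}
```

Because the per-connection value is a clone, the app type must be Clone, which matches the Clone bound on S in serve()'s signature.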

Examples:

// Serving a Router:
use axum::{Router, routing::get};
let router = Router::new().route("/", get(|| async { "Hello, World!" }));
let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
axum::serve(listener, router).await.unwrap();

// Serving a MethodRouter:
use axum::routing::get;
let router = get(|| async { "Hello, World!" });
let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
axum::serve(listener, router).await.unwrap();

// Serving a Handler: call the handler's into_make_service() method
use axum::handler::HandlerWithoutStateExt;
async fn handler() -> &'static str { "Hello, World!"}
let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
axum::serve(listener, handler.into_make_service()).await.unwrap();

Router

Router<S = ()> defines paths and their handling logic. S is the State type, which can be any custom type implementing Clone (commonly an Arc).

// S is the Router's State type, () by default
pub struct Router<S = ()> { /* private fields */ }

Ways to provide the handling logic:

  1. route(): handle requests with a MethodRouter<S>, usually built with the functions named after HTTP methods in the axum::routing::method_routing module, for example:

    • axum::routing::method_routing::get(handler: Handler<T, S>): the handler is usually an async function or closure;
    • axum::routing::method_routing::get_service(svc: Service<Request>): svc is usually built from a closure with tower::service_fn();
    • axum::routing::method_routing::any_service(svc: Service<Request>): svc handles requests of any HTTP method and is usually built with tower::service_fn();
    • axum::routing::method_routing::on_service<T, S>(filter: MethodFilter, svc: T,): handles the methods selected by filter with svc, usually built with tower::service_fn();
  2. route_service(): handle requests with a tower::Service, usually built from a closure with tower::service_fn(), or reuse a Service from the tower_http crate such as tower_http::services::ServeFile;

  3. layer()/route_layer(): process requests with a Layer<Route>, usually built from a closure with axum::middleware::from_fn()/from_fn_with_state(), or reuse a Layer from the tower_http crate such as tower_http::trace::TraceLayer;

  • tower::layer_fn() is not used directly, because its Request/Response types are arbitrary and may not match the http::request::Request/http::response::Response types axum requires.

In summary: each kind of handling logic above (Handler, Service, Layer) can be implemented quickly with a closure.
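As a rough illustration of this closure-based style, and of the order in which layers run, here is a hypothetical std-only model in which a service is a boxed Fn(String) -> String and a layer is a boxed function from service to service (BoxService, Layer, tag_layer and apply_layers are made-up names; real tower services are async). Each newly applied layer wraps everything applied before it, so the layer added last sees the request first:

```rust
// Hypothetical model: a "service" maps a request string to a response string.
type BoxService = Box<dyn Fn(String) -> String>;

// A "layer" wraps an inner service, like tower's `Layer::layer(&self, inner)`.
type Layer = Box<dyn Fn(BoxService) -> BoxService>;

// A layer that records its name on the way in and on the way out.
fn tag_layer(name: &'static str) -> Layer {
    Box::new(move |inner: BoxService| -> BoxService {
        Box::new(move |req: String| {
            let resp = inner(format!("{req} > {name}"));
            format!("{resp} < {name}")
        })
    })
}

// Like chained `.layer()` calls: each new layer wraps the result so far.
fn apply_layers(handler: BoxService, layers: Vec<Layer>) -> BoxService {
    layers.into_iter().fold(handler, |svc, layer| layer(svc))
}
```

Applying `tag_layer("a")` and then `tag_layer("b")` to a handler produces `req > b > a > handler < a < b`: b, added last, is the outermost wrapper.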

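// S is the type of the state object later passed to with_state(state: S). It must implement Clone, so an Arc is commonly used.
impl<S> Router<S> where S: Clone + Send + Sync + 'static

// The following methods are defined in this impl block, i.e. for any state type S:

// Creates an empty Router; S is not fixed to () here but is inferred from how the router is used later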
pub fn new() -> Self

// Adds a MethodRouter for path; any_service() can create one from a Service such as a tower::service_fn() closure
pub fn route(self, path: &str, method_router: MethodRouter<S>) -> Self

// Adds a Service for path. It can be built from a closure with tower::service_fn(), or be one of the tower_http services.
//
// Note: the Service must have Error = Infallible. If the Err type of the Result returned by the tower::service_fn() closure is not Infallible, the bound is not satisfied; convert the error with axum::error_handling::HandleError.
pub fn route_service<T>(self, path: &str, service: T) -> Self
where
    T: Service<Request, Error = Infallible> + Clone + Send + 'static,
    T::Response: IntoResponse,
    T::Future: Send + 'static

// Nests another Router under path
pub fn nest(self, path: &str, router: Router<S>) -> Self

pub fn nest_service<T>(self, path: &str, service: T) -> Self
where
    T: Service<Request, Error = Infallible> + Clone + Send + 'static,
    T::Response: IntoResponse,
    T::Future: Send + 'static

// Merges another Router into this one
pub fn merge<R>(self, other: R) -> Self where R: Into<Router<S>>

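// Adds layer middleware to all routes added so far; unlike route_layer(), it also runs for requests that match no route.
// Layers run in the reverse order of how they were added (the last one added runs first), and the handler is called last.
// layer() takes ownership and returns a new Router<S>, so calls can be chained.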
pub fn layer<L>(self, layer: L) -> Router<S>
where
    L: Layer<Route> + Clone + Send + 'static,
    L::Service: Service<Request> + Clone + Send + 'static,
    <L::Service as Service<Request>>::Response: IntoResponse + 'static,
    <L::Service as Service<Request>>::Error: Into<Infallible> + 'static,
    <L::Service as Service<Request>>::Future: Send + 'static

// Adds the layer only for requests that match a route.
pub fn route_layer<L>(self, layer: L) -> Self
where
    L: Layer<Route> + Clone + Send + 'static,
    L::Service: Service<Request> + Clone + Send + 'static,
    <L::Service as Service<Request>>::Response: IntoResponse + 'static,
    <L::Service as Service<Request>>::Error: Into<Infallible> + 'static,
    <L::Service as Service<Request>>::Future: Send + 'static

// Handler called when no route matches
pub fn fallback<H, T>(self, handler: H) -> Self where H: Handler<T, S>, T: 'static
pub fn fallback_service<T>(self, service: T) -> Self
where
    T: Service<Request, Error = Infallible> + Clone + Send + 'static,
    T::Response: IntoResponse,
    T::Future: Send + 'static

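// Provides the Router's state, which handlers later obtain through the State extractor; it can be any custom type implementing Clone.
// It provides global data to all requests of the Router and is not meant for data specific to a single request (use Extension for that).
// The type S2 is inferred automatically from how the returned Router is used later.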
pub fn with_state<S2>(self, state: S) -> Router<S2>

// Calling tower::ServiceExt methods directly on a Router fails to compile (type inference fails); call this method first to fix that.
pub fn as_service<B>(&mut self) -> RouterAsService<'_, B, S>

// !!! The following two methods are defined on Router<()> only: they turn the Router into a MakeService, a Service that creates other Services, mainly for use as the argument of axum::serve.
pub fn into_make_service(self) -> IntoMakeService<Self>
pub fn into_make_service_with_connect_info<C>( self,) -> IntoMakeServiceWithConnectInfo<Self, C>

Although Router implements Service<Request<B>>, i.e. Service<http::request::Request<B>> with B: http_body::Body<Data = bytes::Bytes>, calling tower::ServiceExt methods on it directly fails to compile (e.g. router.ready().await?.call(request)).

Solution: use the Service returned by Router::as_service():

use axum::{
    Router,
    routing::get,
    http::Request,
    body::Body,
};
use tower::{Service, ServiceExt};

let mut router = Router::new().route("/", get(|| async {}));
let request = Request::new(Body::empty());

// let response = router.ready().await?.call(request).await?;
//                       ^^^^^ cannot infer type for type parameter `B`

// OK
let response = router.as_service().ready().await?.call(request).await?;

Router<()> provides into_make_service()/into_make_service_with_connect_info() to create a type implementing the MakeService trait:

// Returns a Service factory; the generic parameter Self is the Router type
pub fn into_make_service(self) -> IntoMakeService<Self>

// When S is Router (as with Router::into_make_service()), the Response is also a Router,
// and Router implements Service<Request>; so IntoMakeService<Router> is a Service that returns Services, i.e. it satisfies tower's MakeService trait.
impl<S, T> Service<T> for IntoMakeService<S> where S: Clone
  type Response = S
  type Error = Infallible
  type Future = IntoMakeServiceFuture<S>


// Example
use axum::{
    routing::get,
    Router,
};
let app = Router::new().route("/", get(|| async { "Hi!" }));
let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
// Passing app directly, i.e. axum::serve(listener, app), works as well.
axum::serve(listener, app.into_make_service()).await.unwrap();

More examples:

use axum::{
    Router,
    body::Body,
    routing::{get, delete, any_service, get_service, MethodFilter, on_service},
    extract::{Request, Path, State},
    http::{StatusCode, Uri},
    error_handling::HandleErrorLayer,
};

use tower::{Service, ServiceExt};
use tower::service_fn;

use tower_http::services::ServeFile;
use tower_http::trace::TraceLayer;
use tower_http::validate_request::ValidateRequestHeaderLayer;

use http::Response;
use std::{convert::Infallible, io};

let app = Router::new()
    .route("/", get(root))
    // get() returns a MethodRouter, so calls can be chained
    .route("/users", get(list_users).post(create_user))
    // path syntax of axum 0.7; axum 0.8 writes "/users/{id}" and "/assets/{*path}"
    .route("/users/:id", get(show_user))
    .route("/api/:version/users/:id/action", delete(do_users_action))
    // "/assets/*path" does not match "/assets/"
    .route("/assets/*path", get(serve_asset));

async fn root() {}
async fn list_users() {}
async fn create_user() {}
async fn show_user(Path(id): Path<u64>) {}
// Multiple Path parameters are extracted as a tuple
async fn do_users_action(Path((version, id)): Path<(String, u64)>) {}
async fn serve_asset(Path(path): Path<String>) {}

let app = Router::new()
    .route( "/", any_service( // any_service() creates a MethodRouter from a Service
        // Creates a Service from a closure. The Err of the Result returned by the closure must be Infallible to satisfy the bound;
        // otherwise, use axum::error_handling::HandleError to convert the Err into Infallible.
        service_fn(|_: Request| async {
            // First create an axum::body::Body, then use it to build an http::Response.
            let res = Response::new(Body::from("Hi from `GET /`"));
            Ok::<_, Infallible>(res)
        })))
    .route_service( "/foo", service_fn(|req: Request| async move {
        // axum::body::Body implements the http_body::Body trait
        let body = Body::from(format!("Hi from `{} /foo`", req.method()));
        let res = Response::new(body);
        Ok::<_, Infallible>(res)
    }))
    .route_service( "/static/Cargo.toml", ServeFile::new("Cargo.toml"), );

// on_service() is the general form and needs an explicit HTTP method filter;
// any_service() is even more general and matches every HTTP method.
let service = tower::service_fn(|request: Request| async { Ok::<_, Infallible>(Response::new(Body::empty()))});
let app = Router::new().route("/", on_service(MethodFilter::DELETE, service));

// Merge two Routers
let user_routes = Router::new().route("/users", get(users_list)).route("/users/:id", get(users_show));
let team_routes = Router::new().route("/teams", get(teams_list));
let app = Router::new().merge(user_routes).merge(team_routes);
async fn users_list() {}
async fn users_show() {}
async fn teams_list() {}

let app = Router::new()
    .route("/foo", get(|| async {}))
    .route("/bar", get(|| async {}))
    .layer(TraceLayer::new_for_http()); // applies to both routes above

let app = Router::new()
    .route("/foo", get(|| async {}))
    .route_layer(ValidateRequestHeaderLayer::bearer("password")); // only applies to /foo
// `GET /foo` with a valid token will receive `200 OK`
// `GET /foo` with an invalid token will receive `401 Unauthorized`
// `GET /not-found` with an invalid token will receive `404 Not Found`, because no route matches /not-found

let app = Router::new()
    .route("/foo", get(|| async { /* ... */ }))
    .fallback(fallback); // default handler for unmatched paths
async fn fallback(uri: Uri) -> (StatusCode, String) {
    (StatusCode::NOT_FOUND, format!("No route for {uri}"))
}

// The state type must implement Clone
#[derive(Clone)]
struct AppState {}

let routes = Router::new()
    // Use axum::extract::State to obtain the global state in a handler
    .route("/", get(|State(state): State<AppState>| async {
        // use state
    }))
    .with_state(AppState {});
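The difference between layer() and route_layer() in the examples above can be modeled as follows. This is a hypothetical toy router (Router, route_auth, global_auth and call are illustrative names, not axum's): an auth check applied like route_layer only runs for matched routes, so an unknown path still falls through to the 404 fallback, whereas a check applied like layer also covers unmatched requests and turns that 404 into a 401.

```rust
use std::collections::HashMap;

// Hypothetical toy handler: returns (status, body).
type Handler = fn() -> (u16, &'static str);

fn not_found() -> (u16, &'static str) {
    (404, "not found")
}

fn foo() -> (u16, &'static str) {
    (200, "foo")
}

struct Router {
    routes: HashMap<&'static str, Handler>,
    fallback: Handler,
    // Token checked only when a route matched (like `route_layer`).
    route_auth: Option<&'static str>,
    // Token checked for every request, matched or not (like `layer`).
    global_auth: Option<&'static str>,
}

impl Router {
    fn new() -> Self {
        Router { routes: HashMap::new(), fallback: not_found, route_auth: None, global_auth: None }
    }

    fn route(mut self, path: &'static str, handler: Handler) -> Self {
        self.routes.insert(path, handler);
        self
    }

    fn call(&self, path: &str, token: &str) -> (u16, &'static str) {
        // `None` means no check is configured at that level.
        let authorized = |expected: Option<&str>| expected.map_or(true, |t| t == token);
        if !authorized(self.global_auth) {
            return (401, "unauthorized");
        }
        match self.routes.get(path) {
            Some(_) if !authorized(self.route_auth) => (401, "unauthorized"),
            Some(handler) => handler(),
            None => (self.fallback)(),
        }
    }
}
```

This is why authorization middleware is usually attached with route_layer(): it should not mask a genuine 404.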

IncomingStream

The struct axum::serve::IncomingStream is the Request type that Router, MethodRouter, HandlerService and similar types use when implementing the outer Service, e.g. impl Service<IncomingStream<'_>> for MethodRouter<()>.

IncomingStream wraps a connection accepted from a Listener (a TCP or Unix socket listener); the Listener trait itself provides the accept() and local_addr() methods:

pub struct IncomingStream<'a, L> where L: Listener, { /* private fields */ }

// IncomingStream provides the following methods:
// the underlying TCP/Unix stream, e.g. to query its local address
pub fn io(&self) -> &L::Io
// the address of the connected client
pub fn remote_addr(&self) -> &L::Addr

// https://docs.rs/axum/latest/axum/serve/trait.Listener.html
pub trait Listener: Send + 'static {
    type Io: AsyncRead + AsyncWrite + Unpin + Send + 'static;
    type Addr: Send;

    // Required methods
    fn accept(&mut self) -> impl Future<Output = (Self::Io, Self::Addr)> + Send;
    fn local_addr(&self) -> Result<Self::Addr>;
}

// tokio::net::TcpListener and UnixListener implement the Listener trait:
impl Listener for TcpListener
    type Io = TcpStream
    type Addr = SocketAddr
    async fn accept(&mut self) -> (Self::Io, Self::Addr)
    fn local_addr(&self) -> Result<Self::Addr>

impl Listener for UnixListener
    type Io = UnixStream
    type Addr = tokio::net::unix::SocketAddr
    async fn accept(&mut self) -> (Self::Io, Self::Addr)
    fn local_addr(&self) -> Result<Self::Addr>
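To make the roles of Listener and IncomingStream concrete, here is a hypothetical synchronous analogue (the real accept() is async, returns the pair directly rather than an Option, and Io must implement AsyncRead + AsyncWrite; MockListener and accept_all are invented for illustration). The serve loop repeatedly accepts a connection and exposes it through a borrowed IncomingStream:

```rust
use std::collections::VecDeque;

// Hypothetical synchronous analogue of axum's `serve::Listener`.
trait Listener {
    type Io;
    type Addr;
    // `None` ends the loop in this model; the real accept() keeps waiting.
    fn accept(&mut self) -> Option<(Self::Io, Self::Addr)>;
}

// A fake listener whose "connections" are pre-queued (io payload, address) pairs.
struct MockListener {
    pending: VecDeque<(String, String)>,
}

impl Listener for MockListener {
    type Io = String;
    type Addr = String;
    fn accept(&mut self) -> Option<(String, String)> {
        self.pending.pop_front()
    }
}

// Like axum's IncomingStream: borrows the accepted io and its remote address.
struct IncomingStream<'a, L: Listener> {
    io: &'a L::Io,
    remote_addr: &'a L::Addr,
}

// Accept every queued connection and describe it through an IncomingStream.
fn accept_all<L>(listener: &mut L) -> Vec<String>
where
    L: Listener,
    L::Io: std::fmt::Debug,
    L::Addr: std::fmt::Debug,
{
    let mut seen = Vec::new();
    while let Some((io, addr)) = listener.accept() {
        let stream: IncomingStream<'_, L> = IncomingStream { io: &io, remote_addr: &addr };
        seen.push(format!("{:?} from {:?}", stream.io, stream.remote_addr));
    }
    seen
}
```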

ConnectInfo

Router<()>::into_make_service_with_connect_info<C>() is used to provide client connection information, such as the socket address, to Handlers through the ConnectInfo extractor (the value is passed via the HTTP request extensions):

// Method defined on Router<()>:
pub fn into_make_service_with_connect_info<C>(self) -> IntoMakeServiceWithConnectInfo<Self, C>

// Example
use axum::{
    extract::ConnectInfo,
    routing::get,
    Router,
};
use std::net::SocketAddr; // supported out of the box

// Router::new() creates a Router without fixing the state type, so it is inferred from how app is used later. Since
// app.into_make_service_with_connect_info() is defined on Router<()>, app is inferred to be Router<()>.
let app = Router::new().route("/", get(handler));

async fn handler(ConnectInfo(addr): ConnectInfo<SocketAddr>) -> String {
    format!("Hello {addr}")
}

let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
axum::serve(listener, app.into_make_service_with_connect_info::<SocketAddr>()).await.unwrap();

A custom ConnectInfo type can be defined by implementing the Connected<IncomingStream<'_>> trait.

axum implements Connected<IncomingStream<'_>> for std::net::SocketAddr, so SocketAddr can be used directly.

// Example:
use axum::{
    extract::connect_info::{ConnectInfo, Connected},
    routing::get,
    serve::IncomingStream,
    Router,
};

let app = Router::new().route("/", get(handler));

async fn handler(ConnectInfo(my_connect_info): ConnectInfo<MyConnectInfo>, ) -> String {
    format!("Hello {my_connect_info:?}")
}

#[derive(Clone, Debug)]
struct MyConnectInfo {
    // ...
}

impl Connected<IncomingStream<'_>> for MyConnectInfo {
    fn connect_info(target: IncomingStream<'_>) -> Self {
        // target wraps the TCP/Unix stream and the local_addr/remote_addr information
        MyConnectInfo {
            // ...
        }
    }
}

let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
axum::serve(listener, app.into_make_service_with_connect_info::<MyConnectInfo>()).await.unwrap();
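The mechanism behind ConnectInfo can be sketched with a small type-map. In this hypothetical std-only model (Extensions, Connected, ConnectInfo, handle_connection and handler are simplified stand-ins, not axum's code), the connection info is computed once per connection via Connected, and a ConnectInfo wrapper is inserted into every request's extensions, where the handler looks it up by type:

```rust
use std::any::{Any, TypeId};
use std::collections::HashMap;

// Minimal type-map standing in for `http::Extensions`.
#[derive(Default)]
struct Extensions {
    map: HashMap<TypeId, Box<dyn Any>>,
}

impl Extensions {
    fn insert<T: 'static>(&mut self, value: T) {
        self.map.insert(TypeId::of::<T>(), Box::new(value));
    }

    fn get<T: 'static>(&self) -> Option<&T> {
        self.map.get(&TypeId::of::<T>()).and_then(|boxed| boxed.downcast_ref::<T>())
    }
}

struct Request {
    path: String,
    extensions: Extensions,
}

// Analogue of axum's `Connected` trait.
trait Connected<Target> {
    fn connect_info(target: Target) -> Self;
}

// Wrapper stored in the extensions, like axum's `ConnectInfo<T>`.
#[derive(Clone)]
struct ConnectInfo<T>(T);

#[derive(Clone)]
struct MyConnectInfo {
    remote: String,
}

// Here the "accepted connection" is just its remote address string.
impl Connected<&str> for MyConnectInfo {
    fn connect_info(remote_addr: &str) -> Self {
        MyConnectInfo { remote: remote_addr.to_string() }
    }
}

// Per connection: compute C once. Per request: attach a clone to the extensions.
fn handle_connection<C>(remote_addr: &str, paths: &[&str]) -> Vec<Request>
where
    C: for<'a> Connected<&'a str> + Clone + 'static,
{
    let info = C::connect_info(remote_addr);
    paths
        .iter()
        .map(|path| {
            let mut req = Request { path: path.to_string(), extensions: Extensions::default() };
            req.extensions.insert(ConnectInfo(info.clone()));
            req
        })
        .collect()
}

// Like a handler taking `ConnectInfo(info): ConnectInfo<MyConnectInfo>`.
fn handler(req: &Request) -> String {
    let ConnectInfo(info) = req
        .extensions
        .get::<ConnectInfo<MyConnectInfo>>()
        .expect("ConnectInfo missing");
    format!("Hello {} on {}", info.remote, req.path)
}
```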

body

A Handler must return a value implementing the IntoResponse trait. http::Response<B> is generic over its body type B: http_body::Body, so a concrete body type has to be chosen.

pub trait IntoResponse {
      // Required method
    fn into_response(self) -> http::Response<axum::body::Body>;
}

axum defines its own body type, the struct axum::body::Body, which implements http_body::Body with Data = bytes::Bytes, and uses it as the body of Handler responses.

  • axum::body::Body implements the FromRequest trait, so it can be used as an extractor; it also implements IntoResponse, so it can be returned from a Handler.
// Struct axum::body::Body
pub struct Body(/* private fields */);

// Creates an axum::body::Body from any value implementing http_body::Body. new() is generic: B must implement http_body::Body
// with Data = bytes::Bytes; typically B is a type such as http_body_util::Full or http_body_util::Empty.
pub fn new<B>(body: B) -> Body where B: Body<Data = Bytes> + Send + 'static, <B as Body>::Error: Into<Box<dyn Error + Sync + Send>>

// Converts the body into a Stream of its data chunks, discarding any trailers. To keep the trailers, use
// http_body_util::BodyStream, which yields https://docs.rs/http-body/1.0.1/http_body/struct.Frame.html items.
pub fn into_data_stream(self) -> BodyDataStream

// Other quick ways to create an axum::body::Body
impl From<&'static [u8]> for Body
impl From<&'static str> for Body
impl From<()> for Body
impl From<Bytes> for Body
impl From<Cow<'static, [u8]>> for Body
impl From<Cow<'static, str>> for Body
impl From<String> for Body
impl From<Vec<u8>> for Body

// axum::body::Body implements FromRequest, so it can be used as an extractor
impl<S> FromRequest<S> for Body where S: Send + Sync
  type Rejection = Infallible
  fn from_request<'life0, 'async_trait>(req: Request<Body>, _: &'life0 S)
    -> Pin<Box<dyn Future<Output = Result<Body, <Body as FromRequest<S>>::Rejection>> + Send + 'async_trait>>
    where 'life0: 'async_trait, Body: 'async_trait

// axum::body::Body implements IntoResponse, so it can be returned from a Handler
impl IntoResponse for Body
    fn into_response(self) -> Response<Body>

// Example
let app = Router::new()
    .route( "/", any_service(
        // Creates a Service from a closure. The Err of the Result returned by the closure must be Infallible to satisfy the bound;
        // otherwise, use axum::error_handling::HandleError to convert the Err into Infallible.
        service_fn(|_: Request| async {
            // First create an axum::body::Body, then use it to build an http::Response<T>
            let res = Response::new(Body::from("Hi from `GET /`"));
            Ok::<_, Infallible>(res)
        })));
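A hypothetical in-memory stand-in shows why the From impls listed above make Body convenient: any of those source types converts with `.into()` or `Body::from(...)`. This Body is a plain Vec<u8> wrapper covering only a few of the listed impls, not axum's streaming, type-erased body:

```rust
// Hypothetical stand-in for axum::body::Body, holding the whole payload in memory.
#[derive(Debug, PartialEq)]
struct Body(Vec<u8>);

impl From<&'static str> for Body {
    fn from(s: &'static str) -> Self {
        Body(s.as_bytes().to_vec())
    }
}

impl From<String> for Body {
    fn from(s: String) -> Self {
        Body(s.into_bytes())
    }
}

impl From<Vec<u8>> for Body {
    fn from(v: Vec<u8>) -> Self {
        Body(v)
    }
}

// `()` becomes an empty body.
impl From<()> for Body {
    fn from(_: ()) -> Self {
        Body(Vec::new())
    }
}

// Any type with a From impl above can be passed where `impl Into<Body>` is expected.
fn body_len(source: impl Into<Body>) -> usize {
    let body: Body = source.into();
    body.0.len()
}
```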

MethodRouter

MethodRouter is the type of the method_router argument of Router::route(path, method_router); it provides the handling logic for a path.

MethodRouter<S, Infallible> maps request methods to Handlers. Calls can be chained, and each request is dispatched to a different Handler depending on its Method.

// S is the State type, normally passed down from Router<S>.
impl<S> MethodRouter<S, Infallible> where S: Clone

The axum::routing module provides shortcut functions such as any()/any_service()/get()/get_service()/delete()/delete_service()/put()/post(), which quickly create a MethodRouter for the corresponding HTTP method:

// Re-exports
pub use self::method_routing::any;
pub use self::method_routing::any_service;
pub use self::method_routing::delete;
pub use self::method_routing::delete_service;
pub use self::method_routing::get;
pub use self::method_routing::get_service;
pub use self::method_routing::head;
pub use self::method_routing::head_service;
pub use self::method_routing::on;
pub use self::method_routing::on_service;
pub use self::method_routing::options;
pub use self::method_routing::options_service;
pub use self::method_routing::patch;
pub use self::method_routing::patch_service;
pub use self::method_routing::post;
pub use self::method_routing::post_service;
pub use self::method_routing::put;
pub use self::method_routing::put_service;
pub use self::method_routing::trace;
pub use self::method_routing::trace_service;

pub use self::method_routing::MethodFilter; // the method type used by on()/on_service(); methods are associated constants such as MethodFilter::GET

// Methods of MethodRouter:
//
// on()/on_service() are the general methods; MethodFilter selects standard HTTP methods.
// Request is http::Request<axum::body::Body>.
pub fn on<H, T>(self, filter: MethodFilter, handler: H) -> Self
    where H: Handler<T, S>, T: 'static, S: Send + Sync + 'static
pub fn on_service<T>(self, filter: MethodFilter, svc: T) -> Self
    where
        T: Service<Request, Error = E> + Clone + Send + 'static,
        T::Response: IntoResponse + 'static,
        T::Future: Send + 'static

// any()/any_service() are free functions that create a MethodRouter accepting every HTTP method.
pub fn any<H, T, S>(handler: H) -> MethodRouter<S, Infallible>
where
    H: Handler<T, S>,
    T: 'static,
    S: Clone + Send + Sync + 'static
pub fn any_service<T, S>(svc: T) -> MethodRouter<S, T::Error>
where
    T: Service<Request> + Clone + Send + Sync + 'static,
    T::Response: IntoResponse + 'static,
    T::Future: Send + 'static,
    S: Clone,

// Other methods returning a MethodRouter; they are wrappers around on() and can be chained
pub fn delete<H, T>(self, handler: H) -> Self where H: Handler<T, S>, T: 'static, S: Send + Sync + 'static
pub fn get<H, T>(self, handler: H) -> Self where H: Handler<T, S>, T: 'static, S: Send + Sync + 'static
pub fn head<H, T>(self, handler: H) -> Self where H: Handler<T, S>, T: 'static, S: Send + Sync + 'static
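The chaining and method-filtering behavior can be sketched with a hypothetical bitmask filter. MethodFilter, MethodRouter, on, get and call below are simplified stand-ins whose representation may differ from axum's; when no filter matches, the model answers 405 and reports the allowed methods, the information an HTTP Allow header carries:

```rust
// Hypothetical bitmask filter in the spirit of axum's MethodFilter.
#[derive(Clone, Copy, Debug, PartialEq)]
struct MethodFilter(u8);

impl MethodFilter {
    const GET: Self = MethodFilter(0b001);
    const POST: Self = MethodFilter(0b010);
    const DELETE: Self = MethodFilter(0b100);
    const ALL: [Self; 3] = [Self::GET, Self::POST, Self::DELETE];

    fn or(self, other: Self) -> Self {
        MethodFilter(self.0 | other.0)
    }

    fn contains(self, other: Self) -> bool {
        (self.0 & other.0) == other.0
    }

    fn name(self) -> &'static str {
        match self.0 {
            0b001 => "GET",
            0b010 => "POST",
            0b100 => "DELETE",
            _ => "?",
        }
    }
}

type Handler = fn() -> &'static str;

#[derive(Default)]
struct MethodRouter {
    routes: Vec<(MethodFilter, Handler)>,
}

impl MethodRouter {
    // Takes and returns self, so calls can be chained.
    fn on(mut self, filter: MethodFilter, handler: Handler) -> Self {
        self.routes.push((filter, handler));
        self
    }

    // Sugar for on(MethodFilter::GET, ..), like axum's get().
    fn get(self, handler: Handler) -> Self {
        self.on(MethodFilter::GET, handler)
    }

    // Ok(body) when a filter matches; otherwise 405 plus the allowed methods.
    fn call(&self, method: MethodFilter) -> Result<&'static str, (u16, String)> {
        if let Some((_, handler)) = self.routes.iter().find(|(f, _)| f.contains(method)) {
            return Ok(handler());
        }
        let allowed: Vec<&str> = MethodFilter::ALL
            .into_iter()
            .filter(|m| self.routes.iter().any(|(f, _)| f.contains(*m)))
            .map(|m| m.name())
            .collect();
        Err((405, allowed.join(", ")))
    }
}

fn show_item() -> &'static str {
    "show"
}

fn update_item() -> &'static str {
    "update"
}
```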

The argument f of tower::service_fn(f: T) is a closure implementing FnMut(Request) -> impl Future<Output = Result<R, E>>:

  • In axum, Request must be http::Request<axum::body::Body>, or its alias axum::extract::Request (whose body type defaults to axum::body::Body).
// Definition of tower::service_fn()
pub fn service_fn<T>(f: T) -> ServiceFn<T> // T is a function/closure returning a Future.

// ServiceFn<T> implements the Service trait; both Request and the response type R are generic.
// In axum, Request must be http::Request<axum::body::Body>, or its alias axum::extract::Request.
impl<T, F, Request, R, E> Service<Request> for ServiceFn<T> where  T: FnMut(Request) -> F, F: Future<Output = Result<R, E>>,
{
    type Response = R;
    type Error = E;
    type Future = F;

    fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), E>> {
        Ok(()).into()
    }

    fn call(&mut self, req: Request) -> Self::Future {
        (self.f)(req)
    }
}
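A synchronous simplification shows how service_fn turns any FnMut(Req) -> Result<R, E> into a Service whose Response and Error come from the closure's Result. The Service trait, ServiceFn and require_infallible below are hypothetical stand-ins (tower's versions are async and include poll_ready); require_infallible mirrors axum's Error = Infallible bound, so a service with any other error type would be rejected at compile time:

```rust
// Hypothetical synchronous simplification of tower's Service trait.
trait Service<Request> {
    type Response;
    type Error;
    fn call(&mut self, req: Request) -> Result<Self::Response, Self::Error>;
}

// Wraps a closure, like `tower::ServiceFn<T>`.
struct ServiceFn<T> {
    f: T,
}

fn service_fn<T>(f: T) -> ServiceFn<T> {
    ServiceFn { f }
}

// Any `FnMut(Req) -> Result<R, E>` becomes a Service with Response = R, Error = E.
impl<T, Req, R, E> Service<Req> for ServiceFn<T>
where
    T: FnMut(Req) -> Result<R, E>,
{
    type Response = R;
    type Error = E;
    fn call(&mut self, req: Req) -> Result<R, E> {
        (self.f)(req)
    }
}

// Accepts only services that cannot fail, like axum's `Error = Infallible` bounds.
fn require_infallible<S, Req>(svc: &mut S, req: Req) -> S::Response
where
    S: Service<Req, Error = std::convert::Infallible>,
{
    match svc.call(req) {
        Ok(resp) => resp,
        // Infallible has no values, so this arm can never run.
        Err(never) => match never {},
    }
}
```

A closure returning `Err("zero")` still forms a valid Service here, but passing it to require_infallible would not compile; in axum that is the point where HandleError is needed.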

get()/get_service() 为例:

// `get()`: handles requests with a `Handler`;
pub fn get<H, T, S>(handler: H) -> MethodRouter<S, Infallible>
where
    H: Handler<T, S>,
    T: 'static,
    S: Clone + Send + Sync + 'static

// `get_service()`: handles requests with a `Service`, usually built from a closure with `tower::service_fn()`;
// Request is axum::extract::Request, an alias for http::Request<axum::body::Body>.
pub fn get_service<T, S>(svc: T) -> MethodRouter<S, T::Error>
where
    T: Service<Request> + Clone + Send + 'static,
    T::Response: IntoResponse + 'static,
    T::Future: Send + 'static,
    S: Clone,

// Example:
use axum::{
    body::Body,
    extract::Request,
    response::Response,
    routing::{get, get_service, MethodFilter},
    Router,
};
use std::convert::Infallible;

async fn handler() {}
async fn other_handler() {}

let service = tower::service_fn(|request: Request| async {
    // the Err type of the returned Result must be Infallible
    Ok::<_, Infallible>(Response::new(Body::empty()))
});

let app = Router::new()
    .route("/", get(handler))
    .route("/svc", get_service(service).on(MethodFilter::DELETE, other_handler));

MethodRouter can have its own state and layers, which apply only to that MethodRouter's Handlers:

// Attach a state
pub fn with_state<S2>(self, state: S) -> MethodRouter<S2, E>

// Attach a layer, which runs before the Handler (only for requests whose method matches)
pub fn route_layer<L>(self, layer: L) -> MethodRouter<S, E>
where
    L: Layer<Route<E>> + Clone + Send + 'static,
    L::Service: Service<Request, Error = E> + Clone + Send + 'static,
    <L::Service as Service<Request>>::Response: IntoResponse + 'static,
    <L::Service as Service<Request>>::Future: Send + 'static,
    E: 'static,
    S: 'static,

// Example
use axum::{ routing::get, Router, };
use tower_http::validate_request::ValidateRequestHeaderLayer;
let app = Router::new().route(
    "/foo",
    get(|| async {}).route_layer(ValidateRequestHeaderLayer::bearer("password"))
);
// `GET /foo` with a valid token will receive `200 OK`
// `GET /foo` with an invalid token will receive `401 Unauthorized`
// `POST /foo` with an invalid token will receive `405 Method Not Allowed`

MethodRouter implements the Service trait only when its state type is (). When the handlers need some other state type, the MethodRouter implements Service only after with_state() has supplied a value of that type.

impl<S, E> MethodRouter<S, E> where  S: Clone,
    pub fn with_state<S2>(self, state: S) -> MethodRouter<S2, E>

// MethodRouter implements Service only when its state is ()
impl<B, E> Service<Request<B>> for MethodRouter<(), E>
where
    B: HttpBody<Data = Bytes> + Send + 'static,
    B::Error: Into<BoxError>
impl<L> Service<IncomingStream<'_, L>> for MethodRouter<()>
where
    L: Listener,

// Example:
use tower::Service;
use axum::{routing::get, extract::{State, Request}, body::Body};

// this `MethodRouter` doesn't require any state, i.e. the state is `()`,
let method_router = get(|| async {});
// and thus it implements `Service`
assert_service(method_router);

// this requires a `String` and doesn't implement `Service`
let method_router = get(|_: State<String>| async {});
// until you provide the `String` with `.with_state(...)`
let method_router_with_state = method_router.with_state(String::new());
// and then it implements `Service`
assert_service(method_router_with_state);

// helper to check that a value implements `Service`
fn assert_service<S>(service: S) where S: Service<Request>, {}
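The same idea can be reduced to a hypothetical toy (MethodRouter, Service and serve_one here are simplified stand-ins, not axum's types): the type parameter records the state the handlers still need, with_state changes it to a caller-chosen S2, and the Service impl exists only for the () case.

```rust
use std::fmt::Debug;
use std::marker::PhantomData;

// Toy MethodRouter: `S` is the state type its handlers still need.
struct MethodRouter<S> {
    // Debug rendering of the state, once supplied.
    state: Option<String>,
    _needs: PhantomData<S>,
}

impl<S: Debug> MethodRouter<S> {
    fn new() -> Self {
        MethodRouter { state: None, _needs: PhantomData }
    }

    // Same shape as axum's `with_state<S2>(self, state: S) -> MethodRouter<S2, E>`:
    // the needed state is supplied and the new parameter S2 is picked by the caller.
    fn with_state<S2>(self, state: S) -> MethodRouter<S2> {
        MethodRouter { state: Some(format!("{state:?}")), _needs: PhantomData }
    }
}

// Stand-in for tower's Service, implemented only when nothing more is needed.
trait Service {
    fn call(&self, path: &str) -> String;
}

impl Service for MethodRouter<()> {
    fn call(&self, path: &str) -> String {
        format!("{} -> {}", path, self.state.as_deref().unwrap_or("no state"))
    }
}

// Compiles only for types that implement `Service`.
fn serve_one<T: Service>(svc: &T, path: &str) -> String {
    svc.call(path)
}
```

Passing a `MethodRouter<String>` that has not gone through with_state to serve_one fails to compile, just like the assert_service check above.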

Handler

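The Handler trait is what the handler argument of MethodRouter's on/get/delete/put/post(handler) methods must implement; it provides the request-handling logic and is usually implemented by async functions or closures.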

  • Request is a type alias: pub type Request<T = axum::body::Body> = http::Request<T>;
  • Response is a type alias: pub type Response<T = axum::body::Body> = http::Response<T>;
pub trait Handler<T, S>: Clone + Send + Sized + 'static {
    type Future: Future<Output = Response> + Send + 'static;

    // Required method
    fn call(self, req: Request, state: S) -> Self::Future;

    // Provided methods
    // The returned Layered also implements Handler; it runs the given layer first, then this Handler
    fn layer<L>(self, layer: L) -> Layered<L, Self, T, S> where L: Layer<HandlerService<Self, T, S>> + Clone, L::Service: Service<Request>

    fn with_state(self, state: S) -> HandlerService<Self, T, S> { ... }
}

// MethodRouter implements Handler
impl<S> Handler<(), S> for MethodRouter<S> where S: Clone + 'static

Handler provides layer() and with_state() to add middleware (run before the handler function) and state to that Handler:

  • The state must be Clone + Send + Sync + 'static, so it has to be an owned value rather than a borrow; shared data is commonly wrapped in an Arc.
// Example:
use axum::{routing::get, handler::Handler, Router, };
use tower::limit::{ConcurrencyLimitLayer, ConcurrencyLimit};

async fn handler() { /* ... */ }

// The returned layered_handler implements Handler; the given layer runs first, then handler.
let layered_handler = handler.layer(ConcurrencyLimitLayer::new(64));
let app = Router::new().route("/", get(layered_handler));

with_state() returns a HandlerService<Self, T, S>, which provides into_make_service()/into_make_service_with_connect_info<C>().

HandlerService turns a Handler into a Service (and a Service of Services), so it can be passed to axum::serve():

// An adapter that makes a Handler into a Service.
// Created with Handler::with_state or HandlerWithoutStateExt::into_service.
pub struct HandlerService<H, T, S> { /* private fields */ }

impl<H, T, S> HandlerService<H, T, S>
    pub fn state(&self) -> &S
    pub fn into_make_service(self) -> IntoMakeService<HandlerService<H, T, S>>
    pub fn into_make_service_with_connect_info<C>(self)->IntoMakeServiceWithConnectInfo<HandlerService<H,T,S>, C>

// HandlerService is also a Service-of-Service factory and can be passed to axum::serve() directly.
impl<H, T, S> Service<IncomingStream<'_>> for HandlerService<H, T, S>
where
    H: Clone,
    S: Clone
type Response = HandlerService<H, T, S> // returns its own type
type Error = Infallible

impl<H, T, S, B> Service<Request<B>> for HandlerService<H, T, S>
where
    H: Handler<T, S> + Clone + Send + 'static,
    B: HttpBody<Data = Bytes> + Send + 'static,
    B::Error: Into<BoxError>,
    S: Clone + Send + Sync
type Response = Response<Body>
type Error = Infallible

// with_state() example
use axum::{
    handler::Handler,
    response::IntoResponse,
    extract::{ConnectInfo, State},
};
use std::net::SocketAddr;

// The state must implement Clone
#[derive(Clone)]
struct AppState {}

async fn handler( ConnectInfo(addr): ConnectInfo<SocketAddr>, State(state): State<AppState>, ) -> String {
    format!("Hello {addr}")
}

let app = handler.with_state(AppState {}); // attach the State to handler

let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
// app is a HandlerService, so its inherent into_make_service_with_connect_info() is used here; for handlers without state, the HandlerWithoutStateExt trait provides into_service()/into_make_service()/into_make_service_with_connect_info().
axum::serve(listener, app.into_make_service_with_connect_info::<SocketAddr>(),).await.unwrap();

The axum::handler::HandlerWithoutStateExt trait provides into_service()/into_make_service() for stateless Handler<T, ()> values; they produce a HandlerService (wrapped in IntoMakeService for the make-service variants).

pub trait HandlerWithoutStateExt<T>: Handler<T, ()> {
    // Required methods
    fn into_service(self) -> HandlerService<Self, T, ()>;
    fn into_make_service(self) -> IntoMakeService<HandlerService<Self, T, ()>>;
    fn into_make_service_with_connect_info<C>(self, ) -> IntoMakeServiceWithConnectInfo<HandlerService<Self, T, ()>, C>;
}

// Example:
// Serving a Handler: call the handler's into_make_service() method
use axum::handler::HandlerWithoutStateExt;
async fn handler() -> &'static str { "Hello, World!"}
let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
axum::serve(listener, handler.into_make_service()).await.unwrap();

axum implements the Handler trait for async FnOnce functions and closures taking up to 16 arguments:

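  • The arguments are extractors; there can be several, but all arguments except the last must implement FromRequestParts, and the last one must implement FromRequest (every FromRequestParts type also implements FromRequest);
  • The function or closure itself must be Clone + Send + 'static, and the future it returns must be Send with an Output implementing IntoResponse; therefore closures usually do not hold borrows but use move to take ownership of what they capture;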

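These async functions and closures return Future<Output = impl IntoResponse> rather than a Result. To report errors, return a Result<T, E> where both T and E implement IntoResponse, typically by defining a custom error type and implementing IntoResponse for it.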

impl<F, Fut, Res, S> Handler<((),), S> for F
where
    F: FnOnce() -> Fut + Clone + Send + 'static, // the function or closure
    Fut: Future<Output = Res> + Send,
    Res: IntoResponse

// With a single argument, it must implement FromRequest
impl<F, Fut, S, Res, M, T1> Handler<(M, T1), S> for F
where
    F: FnOnce(T1) -> Fut + Clone + Send + 'static,
    Fut: Future<Output = Res> + Send,
    S: Send + Sync + 'static,
    Res: IntoResponse,
    T1: FromRequest<S, M> + Send

// With several arguments, all but the last implement FromRequestParts, and the last implements FromRequest
impl<F, Fut, S, Res, M, T1, T2> Handler<(M, T1, T2), S> for F
where
    F: FnOnce(T1, T2) -> Fut + Clone + Send + 'static,
    Fut: Future<Output = Res> + Send,
    S: Send + Sync + 'static,
    Res: IntoResponse,
    T1: FromRequestParts<S> + Send,
    T2: FromRequest<S, M> + Send

// ...and so on, up to closures taking 16 arguments
impl<F, Fut, S, Res, M, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16> Handler<(M, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16), S> for F
where
    F: FnOnce(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16) -> Fut + Clone + Send + 'static,
    Fut: Future<Output = Res> + Send,
    S: Send + Sync + 'static,
    Res: IntoResponse,
    T1: FromRequestParts<S> + Send,
    T2: FromRequestParts<S> + Send,
    T3: FromRequestParts<S> + Send,
    T4: FromRequestParts<S> + Send,
    T5: FromRequestParts<S> + Send,
    T6: FromRequestParts<S> + Send,
    T7: FromRequestParts<S> + Send,
    T8: FromRequestParts<S> + Send,
    T9: FromRequestParts<S> + Send,
    T10: FromRequestParts<S> + Send,
    T11: FromRequestParts<S> + Send,
    T12: FromRequestParts<S> + Send,
    T13: FromRequestParts<S> + Send,
    T14: FromRequestParts<S> + Send,
    T15: FromRequestParts<S> + Send,
    T16: FromRequest<S, M> + Send
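The split between FromRequestParts and FromRequest can be modeled in a few lines. In this hypothetical synchronous sketch (Parts, Request, Method, BodyString and call_handler are invented for illustration, not axum's types), the first extractor only borrows the request head, while the last one takes ownership of the whole request and may consume the body; that is why only the final argument may rely on FromRequest alone. A failed extraction short-circuits with its rejection:

```rust
// Hypothetical request head (axum uses http::request::Parts).
struct Parts {
    method: String,
}

struct Request {
    parts: Parts,
    body: Vec<u8>,
}

// Reads only the head, by reference: usable for any argument.
trait FromRequestParts: Sized {
    fn from_request_parts(parts: &Parts) -> Result<Self, String>;
}

// Takes the whole request by value (may consume the body): last argument only.
trait FromRequest: Sized {
    fn from_request(req: Request) -> Result<Self, String>;
}

struct Method(String);
struct BodyString(String);

impl FromRequestParts for Method {
    fn from_request_parts(parts: &Parts) -> Result<Self, String> {
        Ok(Method(parts.method.clone()))
    }
}

impl FromRequest for BodyString {
    fn from_request(req: Request) -> Result<Self, String> {
        String::from_utf8(req.body)
            .map(BodyString)
            .map_err(|_| "400 Bad Request".to_string())
    }
}

// Mirrors the two-argument `impl Handler<(M, T1, T2), S> for F`:
// run T1 on the borrowed head, then hand the whole request to T2.
fn call_handler<F, T1, T2>(f: F, req: Request) -> String
where
    F: FnOnce(T1, T2) -> String,
    T1: FromRequestParts,
    T2: FromRequest,
{
    let t1 = match T1::from_request_parts(&req.parts) {
        Ok(value) => value,
        Err(rejection) => return rejection,
    };
    let t2 = match T2::from_request(req) {
        Ok(value) => value,
        Err(rejection) => return rejection,
    };
    f(t1, t2)
}
```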

Examples of Handlers implemented as async functions:

use axum::{body::Bytes, http::StatusCode};

// An empty return value means 200 OK with an empty body
async fn unit_handler() {}

// Returning a String means 200 OK with a plain-text body
async fn string_handler() -> String {
    "Hello, World!".to_string()
}

// The echo handler has a single argument, which must implement FromRequest (Bytes does).
// String and StatusCode both implement IntoResponse, so Result<String, StatusCode> does too.
async fn echo(body: Bytes) -> Result<String, StatusCode> {
    if let Ok(string) = String::from_utf8(body.to_vec()) {
        Ok(string)
    } else {
        Err(StatusCode::BAD_REQUEST)
    }
}

IntoResponse

The return value of a Handler function or closure must implement the IntoResponse trait:

pub trait IntoResponse {
    // Required method
    fn into_response(self) -> Response<Body>; // Body is struct axum::body::Body
}

For example, for a two-argument closure:

impl<F, Fut, S, Res, M, T1, T2> Handler<(M, T1, T2), S> for F
where
    F: FnOnce(T1, T2) -> Fut + Clone + Send + 'static,
    Fut: Future<Output = Res> + Send,
    S: Send + Sync + 'static,
    Res: IntoResponse, // must implement IntoResponse
    T1: FromRequestParts<S> + Send,
    T2: FromRequest<S, M> + Send

axum implements IntoResponse for many basic Rust types:

impl<const N: usize> IntoResponse for &'static [u8; N]
impl IntoResponse for &'static [u8]
impl<const N: usize> IntoResponse for [u8; N]
impl IntoResponse for &'static str
impl IntoResponse for Cow<'static, str>
impl IntoResponse for Cow<'static, [u8]>
impl IntoResponse for Infallible
impl IntoResponse for ()
impl IntoResponse for Box<str>
impl IntoResponse for Box<[u8]>
impl IntoResponse for String
impl IntoResponse for Vec<u8>
impl IntoResponse for Bytes
impl IntoResponse for BytesMut
impl IntoResponse for Extensions
impl IntoResponse for HeaderMap
impl IntoResponse for Parts
impl IntoResponse for StatusCode

// Result also implements IntoResponse, provided both the Ok and the Err types do
impl<T, E> IntoResponse for Result<T, E> where T: IntoResponse, E: IntoResponse
// the struct axum::response::ErrorResponse implements From<T> for any T: IntoResponse
impl<T> IntoResponse for Result<T, ErrorResponse> where T: IntoResponse

impl<T, U> IntoResponse for Chain<T, U> where
    T: Buf + Unpin + Send + 'static,
    U: Buf + Unpin + Send + 'static,

// http::response::Response<B> implements IntoResponse for any body type B whose Data is Bytes
impl<B> IntoResponse for Response<B>
where
    B: Body<Data = Bytes> + Send + 'static,
    <B as Body>::Error: Into<Box<dyn Error + Send + Sync>>

// Arrays of (K, V) pairs also implement IntoResponse (as headers): K must convert to HeaderName, V to HeaderValue
impl<K, V, const N: usize> IntoResponse for [(K, V); N]
where
    K: TryInto<HeaderName>,
    <K as TryInto<HeaderName>>::Error: Display,
    V: TryInto<HeaderValue>,
    <V as TryInto<HeaderValue>>::Error: Display

// Special tuple types implement IntoResponse; the T1 position may hold up to 16 IntoResponseParts values (T1..T16).
// Parts is http::response::Parts, which holds the response status code and headers.
// ResponseParts is the struct axum::response::ResponseParts, used to set headers and extensions.
impl<R> IntoResponse for (Parts, R) where R: IntoResponse
impl<R, T1> IntoResponse for (Parts, T1, R) where T1: IntoResponseParts, R: IntoResponse

// Response<()> is http::response::Response<()>
impl<R> IntoResponse for (Response<()>, R) where R: IntoResponse
impl<R, T1> IntoResponse for (Response<()>, T1, R) where T1: IntoResponseParts, R: IntoResponse

// Overriding the response status code
impl<R> IntoResponse for (StatusCode, R) where R: IntoResponse
impl<R, T1> IntoResponse for (StatusCode, T1, R) where T1: IntoResponseParts, R: IntoResponse

// Adding response headers
impl<R, T1> IntoResponse for (T1, R) where T1: IntoResponseParts, R: IntoResponse
impl<R> IntoResponse for (R,) where R: IntoResponse


// The axum::response::ResponseParts type
pub struct ResponseParts { /* private fields */ }
impl ResponseParts
    pub fn headers(&self) -> &HeaderMap
    pub fn headers_mut(&mut self) -> &mut HeaderMap
    pub fn extensions(&self) -> &Extensions
    pub fn extensions_mut(&mut self) -> &mut Extensions

// IntoResponseParts adds headers and extensions to the ResponseParts it is given
pub trait IntoResponseParts {
    type Error: IntoResponse;

    // Required method
    fn into_response_parts( self, res: ResponseParts, ) -> Result<ResponseParts, Self::Error>;
}

impl IntoResponseParts for ()
impl IntoResponseParts for Extensions
impl IntoResponseParts for HeaderMap
impl<K, V, const N: usize> IntoResponseParts for [(K, V); N]
// implemented for tuples of up to 16 elements (T1..T16)
impl<T1> IntoResponseParts for (T1,) where T1: IntoResponseParts,
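Two composition rules in the list above matter most. A `Result` is a response when both of its sides are, and `(StatusCode, R)` builds `R`'s response and then overrides its status. The std-only model below covers just those rules and keeps only a status and a text body. `Response`, `StatusCode` and `IntoResponse` here are simplified stand-ins for the http/axum types, not the real ones.

```rust
// Simplified model of IntoResponse composition; not axum's real types.

/// Hypothetical response: a status code and a text body only.
#[derive(Debug, PartialEq)]
pub struct Response {
    pub status: u16,
    pub body: String,
}

/// Stand-in for http::StatusCode.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct StatusCode(pub u16);

pub trait IntoResponse {
    fn into_response(self) -> Response;
}

// `()` gives 200 with an empty body.
impl IntoResponse for () {
    fn into_response(self) -> Response {
        Response { status: 200, body: String::new() }
    }
}

// Text gives 200 with that text as the body.
impl IntoResponse for &'static str {
    fn into_response(self) -> Response {
        Response { status: 200, body: self.to_string() }
    }
}

impl IntoResponse for String {
    fn into_response(self) -> Response {
        Response { status: 200, body: self }
    }
}

// A bare status code gives that status with an empty body.
impl IntoResponse for StatusCode {
    fn into_response(self) -> Response {
        Response { status: self.0, body: String::new() }
    }
}

// Result: whichever side is present becomes the response.
impl<T: IntoResponse, E: IntoResponse> IntoResponse for Result<T, E> {
    fn into_response(self) -> Response {
        match self {
            Ok(value) => value.into_response(),
            Err(err) => err.into_response(),
        }
    }
}

// (StatusCode, R): build R's response, then override its status.
impl<R: IntoResponse> IntoResponse for (StatusCode, R) {
    fn into_response(self) -> Response {
        let (status, inner) = self;
        let mut res = inner.into_response();
        res.status = status.0;
        res
    }
}
```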

Other types that implement IntoResponse include Json, Form, Html and Body:

// the Rejection associated types of the extractor traits
impl IntoResponse for MultipartRejection
impl IntoResponse for BytesRejection
impl IntoResponse for ExtensionRejection
impl IntoResponse for FailedToBufferBody
// ...

// other types defined by axum
impl IntoResponse for Body // the struct axum::body::Body
impl IntoResponse for Redirect
impl<T> IntoResponse for Extension<T> where T: Clone + Send + Sync + 'static,
impl<T> IntoResponse for Form<T> where T: Serialize
impl<T> IntoResponse for Json<T> where T: Serialize // axum's own Json type, the struct axum::Json
impl<T> IntoResponse for Html<T> where T: Into<Body>

Example:

//https://docs.rs/axum/latest/axum/response/index.html
use axum::{
    Extension, Json,
    response::{Html, IntoResponse},
    http::{StatusCode, Uri, header::{self, HeaderMap, HeaderName}},
};

// `()` gives an empty response
async fn empty() {}

// String will get a `text/plain; charset=utf-8` content-type
async fn plain_text(uri: Uri) -> String {
    format!("Hi from {}", uri.path())
}

// Bytes will get a `application/octet-stream` content-type
async fn bytes() -> Vec<u8> {
    vec![1, 2, 3, 4]
}

// `Json` will get a `application/json` content-type and work with anything that implements `serde::Serialize`
async fn json() -> Json<Vec<String>> {
    Json(vec!["foo".to_owned(), "bar".to_owned()])
}

// `Html` will get a `text/html` content-type
async fn html() -> Html<&'static str> {
    Html("<p>Hello, World!</p>")
}

// `StatusCode` gives an empty response with that status code
async fn status() -> StatusCode {
    StatusCode::NOT_FOUND
}

// `HeaderMap` gives an empty response with some headers
async fn headers() -> HeaderMap {
    let mut headers = HeaderMap::new();
    headers.insert(header::SERVER, "axum".parse().unwrap());
    headers
}

// An array of tuples also gives headers
async fn array_headers() -> [(HeaderName, &'static str); 2] {
    [
        (header::SERVER, "axum"),
        (header::CONTENT_TYPE, "text/plain")
    ]
}

// Use `impl IntoResponse` to avoid writing the whole type
async fn impl_trait() -> impl IntoResponse {
    [
        (header::SERVER, "axum"),
        (header::CONTENT_TYPE, "text/plain")
    ]
}

// `(StatusCode, impl IntoResponse)` will override the status code of the response
async fn with_status(uri: Uri) -> (StatusCode, String) {
    (StatusCode::NOT_FOUND, format!("Not Found: {}", uri.path()))
}

// Use `impl IntoResponse` to avoid having to type the whole type
async fn impl_trait_with_status(uri: Uri) -> impl IntoResponse {
    (StatusCode::NOT_FOUND, format!("Not Found: {}", uri.path()))
}

// `(HeaderMap, impl IntoResponse)` to add additional headers
async fn with_headers() -> impl IntoResponse {
    let mut headers = HeaderMap::new();
    headers.insert(header::CONTENT_TYPE, "text/plain".parse().unwrap());
    (headers, "foo")
}

// Or an array of tuples to more easily build the headers
async fn with_array_headers() -> impl IntoResponse {
    ([(header::CONTENT_TYPE, "text/plain")], "foo")
}

// Use string keys for custom headers
async fn with_array_headers_custom() -> impl IntoResponse {
    ([("x-custom", "custom")], "foo")
}

// `(StatusCode, headers, impl IntoResponse)` to set status and add headers
// `headers` can be either a `HeaderMap` or an array of tuples
async fn with_status_and_array_headers() -> impl IntoResponse {
    (
        StatusCode::NOT_FOUND,
        [(header::CONTENT_TYPE, "text/plain")],
        "foo",
    )
}

// `(Extension<_>, impl IntoResponse)` to set response extensions
async fn with_status_extensions() -> impl IntoResponse {
    (
        Extension(Foo("foo")),
        "foo",
    )
}

#[derive(Clone)]
struct Foo(&'static str);

// Or mix and match all the things
async fn all_the_things(uri: Uri) -> impl IntoResponse {
    let mut header_map = HeaderMap::new();
    if uri.path() == "/" {
        header_map.insert(header::SERVER, "axum".parse().unwrap());
    }

    (
        // set status code
        StatusCode::NOT_FOUND,
        // headers with an array
        [("x-custom", "custom")],
        // some extensions
        Extension(Foo("foo")),
        Extension(Foo("bar")),
        // more headers, built dynamically
        header_map,
        // and finally the body
        "foo",
    )
}

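Although the documentation does not call it out, every type that implements IntoResponse (and is Clone + Send + 'static) also implements Handler. A plain value can therefore be passed directly to the MethodRouter methods: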

mod private {
    // Marker type for `impl<T: IntoResponse> Handler for T`
    #[allow(missing_debug_implementations)]
    pub enum IntoResponseHandler {}
}

impl<T, S> Handler<private::IntoResponseHandler, S> for T
where
    T: IntoResponse + Clone + Send + 'static,
{
    type Future = std::future::Ready<Response>;

    fn call(self, _req: Request, _state: S) -> Self::Future {
        std::future::ready(self.into_response())
    }
}

// Example: using a tuple as a Handler
use axum::{
    Router,
    routing::{get, post},
    Json,
    http::StatusCode,
};
use serde_json::json;

let app = Router::new()
    // &str implements IntoResponse
    .route("/", get("Hello, World!"))
    .route("/users", post(
        // every element of the tuple implements IntoResponse, so the tuple does too
        (StatusCode::CREATED, Json(json!({ "id": 1, "username": "alice" })),)));

extractor
#

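An extractor is a value whose type implements FromRequest or FromRequestParts. Extractors are the arguments of a handler function and pull information out of the request.

For a handler with several extractor arguments, such as FnOnce(T1, T2, T3), every argument except the last must implement FromRequestParts. The last one may implement FromRequest:

  • FromRequestParts only reads the request head (method, URI, headers, extensions) and does not consume the body. FromRequest takes the whole request, including the body, and the body can be consumed only once. Every FromRequestParts type also implements FromRequest, so the last argument may be of either kind.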
use axum::{
    // Request, Json, Path, Extension and Query are all extractors
    extract::{Request, Json, Path, Extension, Query},
    routing::post,
    http::header::HeaderMap,
    body::{Bytes, Body},
    Router,
};
use serde_json::Value;
use serde::Deserialize;
use std::collections::HashMap;

#[derive(Deserialize)]
struct CreateUser {
    email: String,
    password: String,
}

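// Function parameters are patterns, so in Path((user_id, user_name)) the names user_id and user_name bind the destructured values.

// Path extracts path parameters (several parameters are written as a tuple)
async fn single_path(Path(user_id): Path<u32>) {} // for a route with one parameter, e.g. /path/:user_id
async fn path(Path((user_id, user_name)): Path<(u32, String)>) {}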

// Query deserializes the query string into the given type
async fn query(Query(params): Query<HashMap<String, String>>) {}

// HeaderMap contains all request headers
async fn headers(headers: HeaderMap) {}

// String contains the request body, which must be valid UTF-8
async fn string(body: String) {}

// Bytes contains the raw request body
async fn bytes(body: Bytes) {}

// Json deserializes the request body into the given type (serde_json::Value for arbitrary JSON)
async fn json(Json(payload): Json<Value>) {}

// Json implements both FromRequest (as an extractor) and IntoResponse, so it can also be a handler's return type.
async fn create_user(Json(payload): Json<CreateUser>) {}

// Request gives the whole request
async fn request(request: Request) {}

// Extension extracts shared state from the request extensions
async fn extension(Extension(state): Extension<State>) {}

#[derive(Clone)]
struct State { /* ... */ }

let app = Router::new()
    .route("/path/:user_id/:user_name", post(path))
    .route("/query", post(query))
    .route("/string", post(string))
    .route("/bytes", post(bytes))
    .route("/json", post(json))
    .route("/request", post(request))
    .route("/extension", post(extension))
    .route("/users", post(create_user));
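The rule that only the last extractor may consume the body follows from the trait signatures. `FromRequestParts` borrows the request head, while `FromRequest` takes the whole request by value. The synchronous, std-only sketch below models that split. `Parts`, `Request`, `UserAgent`, `Body` and `call2` are hypothetical; axum's real traits are async and also receive the state.

```rust
// Simplified, synchronous model of the two extractor traits; not axum's API.

/// Hypothetical request head.
pub struct Parts {
    pub headers: Vec<(String, String)>,
}

/// Hypothetical request: head plus body.
pub struct Request {
    pub parts: Parts,
    pub body: String,
}

/// May run several times: it only borrows the head.
pub trait FromRequestParts: Sized {
    fn from_request_parts(parts: &mut Parts) -> Result<Self, String>;
}

/// Runs at most once: it takes ownership of the whole request (and its body).
pub trait FromRequest: Sized {
    fn from_request(req: Request) -> Result<Self, String>;
}

/// Extracts the User-Agent header (names compared case-insensitively).
pub struct UserAgent(pub String);

impl FromRequestParts for UserAgent {
    fn from_request_parts(parts: &mut Parts) -> Result<Self, String> {
        parts
            .headers
            .iter()
            .find(|(name, _)| name.eq_ignore_ascii_case("user-agent"))
            .map(|(_, value)| UserAgent(value.clone()))
            .ok_or_else(|| "missing User-Agent".to_string())
    }
}

/// Extracts the body as text.
pub struct Body(pub String);

impl FromRequest for Body {
    fn from_request(req: Request) -> Result<Self, String> {
        Ok(Body(req.body))
    }
}

/// Two-argument dispatch: the first extractor borrows the head, the last one
/// consumes the request. After `B::from_request(req)` the request has been
/// moved, so no further extractor could run.
pub fn call2<A, B, R>(f: impl Fn(A, B) -> R, mut req: Request) -> Result<R, String>
where
    A: FromRequestParts,
    B: FromRequest,
{
    let a = A::from_request_parts(&mut req.parts)?;
    let b = B::from_request(req)?;
    Ok(f(a, b))
}
```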

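Every extractor argument of a handler is required by default: if extraction fails, the request is rejected and the handler is not called. Wrap an extractor in Option to make it optional: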

use axum::{
    extract::{Path, Query},
    routing::get,
    Router,
};
use uuid::Uuid;
use serde::Deserialize;

let app = Router::new().route("/users/:id/things", get(get_user_things));

#[derive(Deserialize)]
struct Pagination {
    page: usize,
    per_page: usize,
}

impl Default for Pagination {
    fn default() -> Self {
        Self { page: 1, per_page: 30 }
    }
}

// A handler can use several extractors at once
async fn get_user_things(
    // required
    Path(user_id): Path<Uuid>,
    // optional: deserialized from the query string into a Pagination value
    pagination: Option<Query<Pagination>>,
) {
    let Query(pagination) = pagination.unwrap_or_default();
    // ...
}
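An `Option<T>` extractor can be thought of as an impl that turns a failed extraction into `None` instead of a rejection. The std-only sketch below shows this with a hypothetical `FromQuery` trait that works on already-parsed query pairs. It is not axum's real trait.

```rust
use std::collections::HashMap;

/// Hypothetical extractor trait over parsed query pairs.
pub trait FromQuery: Sized {
    fn from_query(q: &HashMap<String, String>) -> Result<Self, String>;
}

#[derive(Debug, PartialEq)]
pub struct Pagination {
    pub page: usize,
    pub per_page: usize,
}

impl Default for Pagination {
    fn default() -> Self {
        Pagination { page: 1, per_page: 30 }
    }
}

impl FromQuery for Pagination {
    // Both fields are required; a missing or non-numeric value is a rejection.
    fn from_query(q: &HashMap<String, String>) -> Result<Self, String> {
        let field = |name: &str| -> Result<usize, String> {
            q.get(name)
                .ok_or_else(|| format!("missing `{name}`"))?
                .parse()
                .map_err(|_| format!("invalid `{name}`"))
        };
        Ok(Pagination { page: field("page")?, per_page: field("per_page")? })
    }
}

// Option<T> never rejects: a failed extraction simply becomes None.
impl<T: FromQuery> FromQuery for Option<T> {
    fn from_query(q: &HashMap<String, String>) -> Result<Self, String> {
        Ok(T::from_query(q).ok())
    }
}

/// Builds the map the extractor reads from `[("page", "2")]`-style pairs.
pub fn query(pairs: &[(&str, &str)]) -> HashMap<String, String> {
    pairs.iter().map(|(k, v)| (k.to_string(), v.to_string())).collect()
}
```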

Getting the reason an extractor failed
#

Wrap an extractor in Result to find out why it failed. The error type is the Rejection associated type that the extractor declares in its FromRequestParts/FromRequest impl:

use axum::{
    extract::{Json, rejection::JsonRejection},
    routing::post,
    Router,
};
use serde_json::Value; // generic type for deserialized JSON

// Every extractor defines its own Rejection type; the Err variant holds a value of that type.
async fn create_user(payload: Result<Json<Value>, JsonRejection>) {
    match payload {
        Ok(payload) => {
            // We got a valid JSON payload
        }
        Err(JsonRejection::MissingJsonContentType(_)) => {
            // Request didn't have `Content-Type: application/json` header
        }
        Err(JsonRejection::JsonDataError(_)) => {
            // Couldn't deserialize the body into the target type
        }
        Err(JsonRejection::JsonSyntaxError(_)) => {
            // Syntax error in the body
        }
        Err(JsonRejection::BytesRejection(_)) => {
            // Failed to extract the request body
        }
        Err(_) => {
            // `JsonRejection` is marked `#[non_exhaustive]` so match must include a catch-all case.
        }
    }
}

let app = Router::new().route("/users", post(create_user));

A more elaborate example of inspecting extractor errors:

use std::error::Error;
use axum::{
    extract::{Json, rejection::JsonRejection},
    response::IntoResponse,
    http::StatusCode,
};
use serde_json::{json, Value};

async fn handler(result: Result<Json<Value>, JsonRejection>,) -> Result<Json<Value>, (StatusCode, String)> {
    match result {
        // if the client sent valid JSON then we're good
        Ok(Json(payload)) => Ok(Json(json!({ "payload": payload }))),

        Err(err) => match err {
            JsonRejection::JsonDataError(err) => {
                Err(serde_json_error_response(err))
            }
            JsonRejection::JsonSyntaxError(err) => {
                Err(serde_json_error_response(err))
            }
            // handle other rejections from the `Json` extractor
            JsonRejection::MissingJsonContentType(_) => Err((
                StatusCode::BAD_REQUEST,
                "Missing `Content-Type: application/json` header".to_string(),
            )),
            JsonRejection::BytesRejection(_) => Err((
                StatusCode::INTERNAL_SERVER_ERROR,
                "Failed to buffer request body".to_string(),
            )),
            // we must provide a catch-all case since `JsonRejection` is marked
            // `#[non_exhaustive]`
            _ => Err((
                StatusCode::INTERNAL_SERVER_ERROR,
                "Unknown error".to_string(),
            )),
        },
    }
}

// attempt to extract the inner `serde_path_to_error::Error<serde_json::Error>`, if that succeeds we
// can provide a more specific error.
//
// `Json` uses `serde_path_to_error` so the error will be wrapped in `serde_path_to_error::Error`.
fn serde_json_error_response<E>(err: E) -> (StatusCode, String) where E: Error + 'static,
{
    if let Some(err) = find_error_source::<serde_path_to_error::Error<serde_json::Error>>(&err) {
        let serde_json_err = err.inner();
        (
            StatusCode::BAD_REQUEST,
            format!(
                "Invalid JSON at line {} column {}",
                serde_json_err.line(),
                serde_json_err.column()
            ),
        )
    } else {
        (StatusCode::BAD_REQUEST, "Unknown error".to_string())
    }
}

// attempt to downcast `err` into a `T` and if that fails recursively try and downcast `err`'s source
fn find_error_source<'a, T>(err: &'a (dyn Error + 'static)) -> Option<&'a T>
where
    T: Error + 'static,
{
    if let Some(err) = err.downcast_ref::<T>() {
        Some(err)
    } else if let Some(source) = err.source() {
        find_error_source(source)
    } else {
        None
    }
}
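`find_error_source` tries a downcast and otherwise recurses into `source()`. The std-only sketch below runs the same function against two hypothetical error types, `Outer` wrapping `Inner`. This makes its behavior checkable without axum, serde_json or serde_path_to_error.

```rust
use std::error::Error;
use std::fmt;

/// Hypothetical low-level error.
#[derive(Debug)]
pub struct Inner {
    pub line: usize,
}

impl fmt::Display for Inner {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "syntax error at line {}", self.line)
    }
}

impl Error for Inner {}

/// Hypothetical wrapper that reports `Inner` as its source.
#[derive(Debug)]
pub struct Outer(pub Inner);

impl fmt::Display for Outer {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "failed to parse body")
    }
}

impl Error for Outer {
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        Some(&self.0)
    }
}

// Same logic as above: downcast, otherwise walk down the source() chain.
pub fn find_error_source<'a, T>(err: &'a (dyn Error + 'static)) -> Option<&'a T>
where
    T: Error + 'static,
{
    if let Some(err) = err.downcast_ref::<T>() {
        Some(err)
    } else if let Some(source) = err.source() {
        find_error_source(source)
    } else {
        None
    }
}
```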

Custom extractors
#

Implement FromRequestParts or FromRequest to write your own extractor.

Implementing FromRequestParts:

use axum::{
    async_trait,
    extract::FromRequestParts,
    routing::get,
    Router,
    http::{
        StatusCode,
        header::{HeaderValue, USER_AGENT},
        request::Parts,
    },
};

// By convention an extractor is a tuple struct (often a generic one)
struct ExtractUserAgent(HeaderValue);

// S is the state type
#[async_trait]
impl<S> FromRequestParts<S> for ExtractUserAgent where S: Send + Sync,
{
    // The error type returned when extraction fails.
    type Rejection = (StatusCode, &'static str);

    // Returns a Result whose Err variant is the rejection.
    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
        if let Some(user_agent) = parts.headers.get(USER_AGENT) {
            Ok(ExtractUserAgent(user_agent.clone()))
        } else {
            Err((StatusCode::BAD_REQUEST, "`User-Agent` header is missing"))
        }
    }
}

// Function parameters are patterns, so user_agent binds the destructured HeaderValue.
async fn handler(ExtractUserAgent(user_agent): ExtractUserAgent) {
    // use user_agent
}

let app = Router::new().route("/foo", get(handler));

Implementing FromRequest:

use axum::{
    async_trait,
    extract::{Request, FromRequest},
    response::{Response, IntoResponse},
    body::{Bytes, Body},
    routing::get,
    Router,
    http::{
        StatusCode,
        header::{HeaderValue, USER_AGENT},
    },
};

struct ValidatedBody(Bytes);

#[async_trait]
impl<S> FromRequest<S> for ValidatedBody where  Bytes: FromRequest<S>, S: Send + Sync,
{
    type Rejection = Response;

    // Returns the Rejection type when extraction fails
    async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
        let body = Bytes::from_request(req, state)
            .await
            .map_err(IntoResponse::into_response)?;

        // do validation...

        Ok(Self(body))
    }
}

async fn handler(ValidatedBody(body): ValidatedBody) {
    // use the body data
}

let app = Router::new().route("/foo", get(handler));

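A custom type usually implements only one of the two traits. The exception is an extractor that wraps another extractor. Such a wrapper can implement both, bounding its type parameter by the matching trait in each impl, so it can be used in the same argument positions as the extractor it wraps.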

use axum::{
    Router,
    body::Body,
    routing::get,
    extract::{Request, FromRequest, FromRequestParts},
    http::{HeaderMap, request::Parts},
    async_trait,
};
use std::time::{Instant, Duration};

// an extractor that wraps another and measures how long time it takes to run
struct Timing<E> {
    extractor: E,
    duration: Duration,
}

// we must implement both `FromRequestParts`
#[async_trait]
impl<S, T> FromRequestParts<S> for Timing<T> where S: Send + Sync, T: FromRequestParts<S>,
{
    type Rejection = T::Rejection;

    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
        let start = Instant::now();
        let extractor = T::from_request_parts(parts, state).await?;
        let duration = start.elapsed();
        Ok(Timing {
            extractor,
            duration,
        })
    }
}

// and `FromRequest`
#[async_trait]
impl<S, T> FromRequest<S> for Timing<T> where S: Send + Sync,  T: FromRequest<S>,
{
    type Rejection = T::Rejection;

    async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
        let start = Instant::now();
        let extractor = T::from_request(req, state).await?;
        let duration = start.elapsed();
        Ok(Timing {
            extractor,
            duration,
        })
    }
}

async fn handler(
    // this uses the `FromRequestParts` impl
    _: Timing<HeaderMap>,
    // this uses the `FromRequest` impl
    _: Timing<String>,
) {}

Extractors that read the request body enforce a default limit of 2 MB; use DefaultBodyLimit::max(size) to change it:

use axum::{
    Router,
    routing::post,
    body::Body,
    extract::{Request, DefaultBodyLimit},
};

let app = Router::new()
    .route("/", post(|request: Request| async {}))
    // change the default limit
    .layer(DefaultBodyLimit::max(1024));
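Conceptually, a body limit means reading at most `limit` bytes and rejecting the request if more arrive. In axum the rejection is a 413 Payload Too Large response. The std-only sketch below models the idea with `Read::take`. `read_limited` and `LengthLimitError` are hypothetical names, and axum applies its limit to streaming HTTP bodies rather than to `std::io::Read`.

```rust
use std::io::{self, Read};

/// Hypothetical error: the body was larger than the limit.
#[derive(Debug, PartialEq)]
pub struct LengthLimitError {
    pub limit: u64,
}

/// Reads the whole body, but never more than `limit + 1` bytes: one extra
/// byte is enough to tell "exactly at the limit" from "over the limit".
pub fn read_limited<R: Read>(
    body: R,
    limit: u64,
) -> io::Result<Result<Vec<u8>, LengthLimitError>> {
    let mut buf = Vec::new();
    body.take(limit.saturating_add(1)).read_to_end(&mut buf)?;
    if buf.len() as u64 > limit {
        Ok(Err(LengthLimitError { limit }))
    } else {
        Ok(Ok(buf))
    }
}
```

Reading one byte past the limit, instead of buffering everything first, keeps memory bounded even when a client sends a huge body.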

To log extraction rejections, enable axum's tracing feature (enabled by default) and set the RUST_LOG environment variable:

RUST_LOG=info,axum::rejection=trace

Common extractor types
#

The axum::extract module provides these commonly used extractors:

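  • Json
  • Form
  • Request
  • HeaderMap
  • Extension: extracts a value from the request extensions, such as one added with .layer(Extension(value)) or inserted by middleware (values passed to with_state() are extracted with State instead)
  • ConnectInfo
  • Host
  • MatchedPath
  • Multipart
  • NestedPath
  • OriginalUri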
  • Path
  • Query
  • RawForm
  • RawPathParams
  • RawQuery
  • State

Json
#

Json implements FromRequest and IntoResponse, so it can be both the input and the output of a handler:

pub struct Json<T>(pub T);

// Extractor example
use axum::{
    extract,
    routing::post,
    Router,
};
use serde::Deserialize;

#[derive(Deserialize)]
struct CreateUser {
    email: String,
    password: String,
}

// Use Json<serde_json::Value> as the payload type to deserialize into a generic JSON value.
async fn create_user(extract::Json(payload): extract::Json<CreateUser>) {
    // payload is a `CreateUser`
}

let app = Router::new().route("/users", post(create_user));

// Response example
use axum::{
    extract::Path,
    routing::get,
    Router,
    Json,
};
use serde::Serialize;
use uuid::Uuid;

#[derive(Serialize)]
struct User {
    id: Uuid,
    username: String,
}

async fn get_user(Path(user_id) : Path<Uuid>) -> Json<User> {
    let user = find_user(user_id).await;
    Json(user)
}

async fn find_user(user_id: Uuid) -> User {
    // ...
}

let app = Router::new().route("/users/:id", get(get_user));

Form
#

Form also implements FromRequest and IntoResponse, so it can be both a handler's input and its output:

// implements FromRequest and IntoResponse
pub struct Form<T>(pub T);

use axum::Form;
use serde::Deserialize;

#[derive(Deserialize)]
struct SignUp {
    username: String,
    password: String,
}

async fn accept_form(Form(sign_up): Form<SignUp>) {
    // ...
}


// Response
use axum::Form;
use serde::Serialize;

#[derive(Serialize)]
struct Payload {
    value: String,
}

async fn handler() -> Form<Payload> {
    Form(Payload { value: "foo".to_owned() })
}

Request
#

Request gives the whole request, for maximum control:

async fn request(request: Request) {}

HeaderMap
#

HeaderMap contains all request headers:

async fn headers(headers: HeaderMap) {}

Extension
#

Extension implements FromRequestParts, Layer<S> and IntoResponse, so it can be used as an extractor, as a layer (middleware), and as response data.

When writing middleware, Extension is often used to pass data to handlers through the HTTP request extensions.

// As an extractor: commonly used to pass shared state to handlers
use axum::{
    Router,
    Extension,
    routing::get,
};
use std::sync::Arc;

// Some shared state used throughout our application
struct State {
    // ...
}

// The state is shared by all handlers (Extension requires Clone), so handlers cannot get &mut access; wrap it in Arc.
async fn handler(state: Extension<Arc<State>>) {
    // ...
}

let state = Arc::new(State { /* ... */ });
let app = Router::new().route("/", get(handler))
    .layer(Extension(state)); // router level: applies to all routes added before this call


// In a response: adds Extension data to the response
use axum::{
    Extension,
    response::IntoResponse,
};

async fn handler() -> (Extension<Foo>, &'static str) {
    (
        Extension(Foo("foo")),
        "Hello, World!"
    )
}

#[derive(Clone)]
struct Foo(&'static str);

// Passing state from middleware to handlers
// State can be passed from middleware to handlers using request extensions:
use axum::{
    Router,
    http::StatusCode,
    routing::get,
    response::{IntoResponse, Response},
    middleware::{self, Next},
    extract::{Request, Extension},
};

#[derive(Clone)]
struct CurrentUser { /* ... */ }

async fn auth(mut req: Request, next: Next) -> Result<Response, StatusCode> {
    let auth_header = req.headers()
        .get(axum::http::header::AUTHORIZATION)
        .and_then(|header| header.to_str().ok());

    let auth_header = if let Some(auth_header) = auth_header {
        auth_header
    } else {
        return Err(StatusCode::UNAUTHORIZED);
    };

    if let Some(current_user) = authorize_current_user(auth_header).await {
        // insert the current user into a request extension so the handler can extract it
        req.extensions_mut().insert(current_user);
        Ok(next.run(req).await)
    } else {
        Err(StatusCode::UNAUTHORIZED)
    }
}

async fn authorize_current_user(auth_token: &str) -> Option<CurrentUser> {
    // ...
}

async fn handler(
    // extract the current user, set by the middleware
    Extension(current_user): Extension<CurrentUser>,
) {
    // ...
}

let app = Router::new()
    .route("/", get(handler))
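    // route_layer middleware runs before the handler, and only for requests that matched a route; layers run in the reverse order they were added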
    .route_layer(middleware::from_fn(auth));
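The request extensions that `auth` writes and `Extension<CurrentUser>` reads behave like a map keyed by type, holding at most one value per type. The std-only sketch below builds such a type map on `TypeId` and `Any`. This `Extensions` is a hypothetical simplification of `http::Extensions`, not its actual code.

```rust
use std::any::{Any, TypeId};
use std::collections::HashMap;

/// Hypothetical, simplified version of http::Extensions: one value per type.
#[derive(Default)]
pub struct Extensions {
    map: HashMap<TypeId, Box<dyn Any + Send + Sync>>,
}

impl Extensions {
    /// Stores `value`, returning the previous value of the same type, if any.
    pub fn insert<T: Send + Sync + 'static>(&mut self, value: T) -> Option<T> {
        self.map
            .insert(TypeId::of::<T>(), Box::new(value))
            .and_then(|old| old.downcast::<T>().ok())
            .map(|boxed| *boxed)
    }

    /// Looks up the value of type `T`, which is what `Extension<T>` does.
    pub fn get<T: Send + Sync + 'static>(&self) -> Option<&T> {
        self.map
            .get(&TypeId::of::<T>())
            .and_then(|boxed| (**boxed).downcast_ref::<T>())
    }
}

#[derive(Clone, Debug, PartialEq)]
pub struct CurrentUser {
    pub name: String,
}
```

Keying by type is why a handler asks for `Extension<CurrentUser>` rather than for a string key. It is also why inserting a second `CurrentUser` replaces the first.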

ConnectInfo
#

Extracts information about the client connection (for example its address). It must be used together with Router::into_make_service_with_connect_info().

use axum::{
    extract::connect_info::{ConnectInfo, Connected},
    routing::get,
    serve::IncomingStream,
    Router,
};

let app = Router::new().route("/", get(handler));

async fn handler(
    ConnectInfo(my_connect_info): ConnectInfo<MyConnectInfo>,
) -> String {
    format!("Hello {my_connect_info:?}")
}

Implementing Connected<IncomingStream<'_>> lets you use your own type as the ConnectInfo type parameter:

#[derive(Clone, Debug)]
struct MyConnectInfo {
    // ...
}

impl Connected<IncomingStream<'_>> for MyConnectInfo {
    fn connect_info(target: IncomingStream<'_>) -> Self {
        MyConnectInfo {
            // ...
        }
    }
}

let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
axum::serve(listener, app.into_make_service_with_connect_info::<MyConnectInfo>()).await.unwrap();

DefaultBodyLimit
#

DefaultBodyLimit implements Layer and limits the size of request bodies (so it is not actually an extractor).

use axum::{
    Router,
    routing::post,
    body::Body,
    extract::{Request, DefaultBodyLimit},
};

let app = Router::new()
    // this route has a different limit
    .route("/", post(|request: Request| async {}).layer(DefaultBodyLimit::max(1024)))
    // this route still has the default limit
    .route("/foo", post(|request: Request| async {}));
// Router::layer() only wraps routes added before it, so calling
// .layer(DefaultBodyLimit::max(...)) first, on an empty Router, would have no effect.

Host
#

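An extractor that resolves the host name of the request; it implements FromRequestParts.

The host name is resolved from the following sources, in order:

  • the Forwarded header
  • the X-Forwarded-Host header
  • the Host header
  • the request target / URI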
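The lookup order can be sketched as a plain function over a list of headers. This std-only approximation is for illustration only. `resolve_host` is hypothetical, and its `Forwarded` parsing is simplified: first element only, `host=` pair, surrounding quotes stripped. axum's extractor works on a `HeaderMap` and the request URI.

```rust
/// Resolves the host from Forwarded, then X-Forwarded-Host, then Host, then
/// the URI authority. Header names are compared case-insensitively.
pub fn resolve_host<'a>(
    headers: &'a [(&'a str, &'a str)],
    uri_authority: Option<&'a str>,
) -> Option<&'a str> {
    let header = |name: &str| {
        headers
            .iter()
            .find(|(k, _)| k.eq_ignore_ascii_case(name))
            .map(|(_, v)| *v)
    };

    // Forwarded: take the `host=` pair of the first element.
    let forwarded_host = header("forwarded").and_then(|value| {
        value
            .split(',')
            .next()?
            .split(';')
            .map(str::trim)
            .find_map(|pair| {
                let (key, val) = pair.split_once('=')?;
                key.trim()
                    .eq_ignore_ascii_case("host")
                    .then(|| val.trim().trim_matches('"'))
            })
    });

    forwarded_host
        .or_else(|| header("x-forwarded-host"))
        .or_else(|| header("host"))
        .or(uri_authority)
}
```

Because the proxy headers take precedence over Host, they should only be trusted when the server sits behind a proxy that sets them.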

MatchedPath
#

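An extractor for the route that matched the request; it implements FromRequestParts. The returned path is the route as written in the Router (the template), not the actual request path.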

use axum::{
    Router,
    extract::MatchedPath,
    routing::get,
};

let app = Router::new().route(
    "/users/:id",
    get(|path: MatchedPath| async move {
        let path = path.as_str();
        // `path` will be "/users/:id"
    })
);

Multipart
#

An extractor that parses multipart/form-data requests (typically used for file uploads). It implements FromRequest, so it must be the last argument of the handler.

use axum::{
    extract::Multipart,
    routing::post,
    Router,
};
use futures_util::stream::StreamExt;

async fn upload(mut multipart: Multipart) {
    while let Some(mut field) = multipart.next_field().await.unwrap() {
        let name = field.name().unwrap().to_string();
        let data = field.bytes().await.unwrap();

        println!("Length of `{}` is {} bytes", name, data.len());
    }
}

let app = Router::new().route("/upload", post(upload));

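A multipart form can contain both text fields and file fields. Each field is sent as a separate part of the request body, delimited by the boundary string.

  1. A multipart form with one text field and two file fields: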
<!doctype html>
<html>
    <head>
        <title>File Upload</title>
    </head>
    <body>
       <form action="http://localhost:8000" method="post" enctype="multipart/form-data">
         <p><input type="text" name="text" value="text default">
         <p><input type="file" name="file1">
         <p><input type="file" name="file2">
         <p><button type="submit">Submit</button>
       </form>
    </body>
</html>

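When the form is submitted, the POST request (headers and body) looks like this:

  • Content-Type: multipart/form-data; boundary=---------------------------9051914041544843365972754266 (in the body, each part starts with -- followed by this boundary)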
POST / HTTP/1.1
Host: localhost:8000
User-Agent: Mozilla/5.0 (X11; Ubuntu; Linux i686; rv:29.0) Gecko/20100101 Firefox/29.0
Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8
Accept-Language: en-US,en;q=0.5
Accept-Encoding: gzip, deflate
Cookie: __atuvc=34%7C7; permanent=0; _gitlab_session=226ad8a0be43681acf38c2fab9497240; __profilin=p%3Dt; request_method=GET
Connection: keep-alive
Content-Type: multipart/form-data; boundary=---------------------------9051914041544843365972754266
Content-Length: 554

-----------------------------9051914041544843365972754266
Content-Disposition: form-data; name="text"

text default
-----------------------------9051914041544843365972754266
Content-Disposition: form-data; name="file1"; filename="a.txt"
Content-Type: text/plain

Content of a.txt.

-----------------------------9051914041544843365972754266
Content-Disposition: form-data; name="file2"; filename="a.html"
Content-Type: text/html

<!DOCTYPE html><title>Content of a.html.</title>

-----------------------------9051914041544843365972754266--
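The std-only sketch below shows how such a body is put together by serializing text and file fields as `multipart/form-data` (RFC 7578). Each part opens with `--` plus the boundary, and a blank line separates the part headers from the content. The body closes with `--` plus the boundary plus `--`. `Field` and `build_multipart` are hypothetical helpers; real clients also pick a boundary that cannot occur in the data.

```rust
/// Hypothetical form field: a name, an optional file name (file fields only),
/// an optional content type, and the content.
pub struct Field<'a> {
    pub name: &'a str,
    pub filename: Option<&'a str>,
    pub content_type: Option<&'a str>,
    pub content: &'a str,
}

/// Serializes `fields` as a multipart/form-data body using `boundary`.
pub fn build_multipart(boundary: &str, fields: &[Field<'_>]) -> String {
    let mut body = String::new();
    for field in fields {
        // Each part opens with "--" + boundary.
        body.push_str(&format!("--{boundary}\r\n"));
        body.push_str(&format!("Content-Disposition: form-data; name=\"{}\"", field.name));
        if let Some(filename) = field.filename {
            body.push_str(&format!("; filename=\"{filename}\""));
        }
        body.push_str("\r\n");
        if let Some(content_type) = field.content_type {
            body.push_str(&format!("Content-Type: {content_type}\r\n"));
        }
        // A blank line separates the part headers from the part content.
        body.push_str("\r\n");
        body.push_str(field.content);
        body.push_str("\r\n");
    }
    // The closing delimiter carries an extra "--".
    body.push_str(&format!("--{boundary}--\r\n"));
    body
}
```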

NestedPath
#

Gives the path at which the router containing the matched route was nested; implements FromRequestParts:

use axum::{
    Router,
    extract::NestedPath,
    routing::get,
};

let api = Router::new().route(
    "/users",
    get(|path: NestedPath| async move {
        // `path` will be "/api" because that's what this router is nested at when we build `app`
        let path = path.as_str();
    })
);

let app = Router::new().nest("/api", api);

OriginalUri
#

Gives the original request URI, regardless of nesting:

use axum::{
    routing::get,
    Router,
    extract::OriginalUri,
    http::Uri
};

let api_routes = Router::new()
    .route(
        "/users",
        get(|uri: Uri, OriginalUri(original_uri): OriginalUri| async {
            // `uri` is `/users`
            // `original_uri` is `/api/users`
        }),
    );

let app = Router::new().nest("/api", api_routes);
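Inside a nested router, `Uri` is the request path with the nest prefix removed, while `OriginalUri` keeps the full path. The std-only sketch below models that prefix stripping. `strip_nest_prefix` is hypothetical, and requiring the prefix to end on a `/` boundary is an assumption of this sketch.

```rust
/// Returns the path a router nested at `prefix` would see, or None when the
/// request is not under `prefix`. The prefix must end at a `/` boundary.
pub fn strip_nest_prefix<'a>(prefix: &str, path: &'a str) -> Option<&'a str> {
    let rest = path.strip_prefix(prefix.trim_end_matches('/'))?;
    if rest.is_empty() {
        Some("/")
    } else if rest.starts_with('/') {
        Some(rest)
    } else {
        None // e.g. "/apix/users" is not under "/api"
    }
}
```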

Path
#

Deserializes the parameters in the request path with serde; multiple path parameters are represented as a tuple:

use axum::{
    extract::Path,
    routing::get,
    Router,
};
use uuid::Uuid;

async fn users_teams_show(
    Path((user_id, team_id)): Path<(Uuid, Uuid)>,
) {
    // ...
}

let app = Router::new().route("/users/:user_id/team/:team_id", get(users_teams_show));

Query
#

Deserializes the query string into a struct type. (Declare optional query parameters as Option fields.)

use axum::{
    extract::Query,
    routing::get,
    Router,
};
use serde::Deserialize;

#[derive(Deserialize)]
struct Pagination {
    page: usize,
    per_page: usize,
}

// This will parse query strings like `?page=2&per_page=30` into `Pagination` structs.
async fn list_things(pagination: Query<Pagination>) {
    let pagination: Pagination = pagination.0;

    // ...
}

let app = Router::new().route("/list_things", get(list_things));
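For `?page=2&per_page=30`, Query splits the query string into key/value pairs and deserializes them into the struct. axum delegates this step to serde_urlencoded. The sketch below is a hand-written, std-only version of the same mapping for this one struct. `parse_pagination` is hypothetical and skips percent-decoding.

```rust
#[derive(Debug, PartialEq)]
pub struct Pagination {
    pub page: usize,
    pub per_page: usize,
}

/// Parses a query string such as "page=2&per_page=30" (without the '?').
/// Unknown keys are ignored; a missing or non-numeric field is an error.
pub fn parse_pagination(query: &str) -> Result<Pagination, String> {
    let (mut page, mut per_page) = (None, None);
    for pair in query.split('&').filter(|p| !p.is_empty()) {
        let (key, value) = pair.split_once('=').unwrap_or((pair, ""));
        let slot = match key {
            "page" => &mut page,
            "per_page" => &mut per_page,
            _ => continue,
        };
        *slot = Some(value.parse::<usize>().map_err(|_| format!("invalid `{key}`"))?);
    }
    Ok(Pagination {
        page: page.ok_or("missing `page`")?,
        per_page: per_page.ok_or("missing `per_page`")?,
    })
}
```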

RawForm
#

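Form requires its type to implement serde::Deserialize, whereas RawForm performs no deserialization and gives the raw form data as bytes. For GET requests it reads the query string; for other methods it reads the URL-encoded request body.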

use axum::{
    extract::RawForm,
    routing::get,
    Router
};

async fn handler(RawForm(form): RawForm) {}

let app = Router::new().route("/", get(handler));

RawPathParams
#

Extracts the path parameters from the URL without deserializing them:

pub struct RawPathParams(/* private fields */);

impl<'a> IntoIterator for &'a RawPathParams
    type Item = (&'a str, &'a str)

use axum::{
    extract::RawPathParams,
    routing::get,
    Router,
};

async fn users_teams_show(params: RawPathParams) {
    for (key, value) in &params {
        println!("{key:?} = {value:?}");
    }
}

let app = Router::new().route("/users/:user_id/team/:team_id", get(users_teams_show));
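Take a template such as `/users/:user_id/team/:team_id`. Matching a request path against it yields the (name, value) pairs that `RawPathParams` iterates over and `Path` deserializes, while `MatchedPath` reports the template itself. The std-only sketch below shows segment-by-segment matching. `match_route` is hypothetical and ignores wildcards and percent-decoding, which axum's router (built on the matchit crate) handles.

```rust
/// Matches `path` against a template like "/users/:user_id/team/:team_id".
/// Returns the captured (name, value) pairs in template order, or None.
/// A trailing slash counts as an extra (empty) segment, so it does not match.
pub fn match_route<'t, 'p>(template: &'t str, path: &'p str) -> Option<Vec<(&'t str, &'p str)>> {
    let mut tmpl_segments = template.split('/');
    let mut path_segments = path.split('/');
    let mut params = Vec::new();
    loop {
        match (tmpl_segments.next(), path_segments.next()) {
            (None, None) => return Some(params),
            (Some(t), Some(p)) => {
                if let Some(name) = t.strip_prefix(':') {
                    if p.is_empty() {
                        return None; // a parameter must not be empty
                    }
                    params.push((name, p));
                } else if t != p {
                    return None; // static segment differs
                }
            }
            _ => return None, // different number of segments
        }
    }
}
```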

RawQuery
#

Extracts the raw query string from the request (None when the URI has no query string):

pub struct RawQuery(pub Option<String>);

// Extractor that extracts the raw query string, without parsing it.
// Example

use axum::{
    extract::RawQuery,
    routing::get,
    Router,
};
use futures_util::StreamExt;

async fn handler(RawQuery(query): RawQuery) {
    // ...
}

let app = Router::new().route("/users", get(handler));

State
#

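State is the application-wide state passed to the Router with with_state(). Handlers receive their own clone of it rather than &mut access, so the type must implement Clone. Data that handlers need to mutate is therefore usually wrapped in a type with interior mutability, such as Arc<Mutex<_>>: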

use axum::{Router, routing::get, extract::State};

// the application state
//
// here you can put configuration, database connection pools, or whatever state you need
//
// see "When states need to implement `Clone`" for more details on why we need `#[derive(Clone)]` here.
#[derive(Clone)]
struct AppState {}

let state = AppState {};

// create a `Router` that holds our state
let app = Router::new()
    .route("/", get(handler))
    // provide the state so the router can access it
    .with_state(state);

async fn handler(
    // access the state via the `State` extractor
    // extracting a state of the wrong type results in a compile error
    State(state): State<AppState>,
) {
    // use `state`...
}
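Each handler receives its own clone of the state. A plain field is therefore not shared between requests, whereas an `Arc<Mutex<_>>` field is, because every clone points at the same data. The std-only sketch below uses threads to stand in for concurrent requests. `AppState`, `handle_request` and `simulate` are hypothetical, and `state.clone()` stands in for the clone the `State` extractor makes.

```rust
use std::sync::{Arc, Mutex};
use std::thread;

#[derive(Clone, Default)]
pub struct AppState {
    /// Shared by every clone of the state.
    pub hits: Arc<Mutex<u64>>,
    /// Copied on clone, so changes made through a clone are not seen elsewhere.
    pub local_hits: u64,
}

/// Hypothetical handler body: it receives its own clone of the state.
pub fn handle_request(mut state: AppState) -> u64 {
    state.local_hits += 1; // only changes this clone
    let mut hits = state.hits.lock().unwrap();
    *hits += 1;
    *hits
}

/// Simulates `n` concurrent requests, each with its own clone of `state`.
pub fn simulate(state: &AppState, n: usize) {
    let workers: Vec<_> = (0..n)
        .map(|_| {
            let state = state.clone();
            thread::spawn(move || handle_request(state))
        })
        .collect();
    for worker in workers {
        worker.join().unwrap();
    }
}
```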

substate
#

A router has a single state type, but FromRef lets handlers extract a "substate" (a part of it):

use axum::{Router, routing::get, extract::{State, FromRef}};

// the application state
#[derive(Clone)]
struct AppState {
    // that holds some api specific state
    api_state: ApiState,
}

// the api specific state
#[derive(Clone)]
struct ApiState {}

// support converting an `AppState` into an `ApiState`
impl FromRef<AppState> for ApiState {
    fn from_ref(app_state: &AppState) -> ApiState {
        app_state.api_state.clone()
    }
}

let state = AppState {
    api_state: ApiState {},
};

let app = Router::new()
    .route("/", get(handler))
    .route("/api/users", get(api_users))
    .with_state(state);

async fn api_users(
    // access the api specific state
    State(api_state): State<ApiState>, // AppState is converted to ApiState through ApiState's FromRef impl
) {
}

async fn handler(
    // we can still access to top level state
    State(state): State<AppState>,
) {
}
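The substate lookup is driven by a trait bound: `State<T>` can be extracted when `T: FromRef<AppState>`. A blanket impl makes every `Clone` type `FromRef` of itself, which is why `State<AppState>` still works. The std-only sketch below uses a local copy of the trait with the same shape as axum's. `extract_state` is a hypothetical stand-in for the `State` extractor.

```rust
/// Same shape as axum's `FromRef` trait.
pub trait FromRef<T> {
    fn from_ref(input: &T) -> Self;
}

// Blanket impl: any Clone type can be obtained from itself.
impl<T: Clone> FromRef<T> for T {
    fn from_ref(input: &T) -> Self {
        input.clone()
    }
}

#[derive(Clone, Debug, PartialEq)]
pub struct ApiState {
    pub api_key: String,
}

#[derive(Clone, Debug, PartialEq)]
pub struct AppState {
    pub api_state: ApiState,
    pub name: String,
}

impl FromRef<AppState> for ApiState {
    fn from_ref(app_state: &AppState) -> ApiState {
        app_state.api_state.clone()
    }
}

/// Hypothetical stand-in for the State extractor: the requested type `T`
/// decides which conversion runs.
pub fn extract_state<T: FromRef<AppState>>(app_state: &AppState) -> T {
    T::from_ref(app_state)
}
```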

WebSocketUpgrade
#

Upgrades a client's HTTP connection to a WebSocket connection; implements FromRequestParts:

use axum::{
    extract::ws::{WebSocketUpgrade, WebSocket},
    routing::get,
    response::{IntoResponse, Response},
    Router,
};

let app = Router::new().route("/ws", get(handler));

async fn handler(ws: WebSocketUpgrade) -> Response {
    ws.on_upgrade(handle_socket)
}

async fn handle_socket(mut socket: WebSocket) {
    while let Some(msg) = socket.recv().await {
        let msg = if let Ok(msg) = msg {
            msg
        } else {
            // client disconnected
            return;
        };

        if socket.send(msg).await.is_err() {
            // client disconnected
            return;
        }
    }
}

// If you need to read and write concurrently from a WebSocket you can use StreamExt::split:
use axum::{Error, extract::ws::{WebSocket, Message}};
use futures_util::{sink::SinkExt, stream::{StreamExt, SplitSink, SplitStream}};

async fn handle_socket(mut socket: WebSocket) {
    let (mut sender, mut receiver) = socket.split();

    tokio::spawn(write(sender));
    tokio::spawn(read(receiver));
}

async fn read(receiver: SplitStream<WebSocket>) {
    // ...
}

async fn write(sender: SplitSink<WebSocket, Message>) {
    // ...
}

A complete WebSocket example:

// https://docs.shuttle.rs/examples/axum-websockets

use std::{sync::Arc, time::Duration};

use axum::{
    extract::{
        ws::{Message, WebSocket},
        WebSocketUpgrade,
    },
    response::IntoResponse,
    routing::get,
    Extension, Router,
};
use chrono::{DateTime, Utc};
use futures::{SinkExt, StreamExt};
use serde::Serialize;
use shuttle_axum::ShuttleAxum;
use tokio::{
    sync::{watch, Mutex},
    time::sleep,
};
use tower_http::services::ServeDir;

struct State {
    clients_count: usize,
    rx: watch::Receiver<Message>,
}

const PAUSE_SECS: u64 = 15;
const STATUS_URI: &str = "https://api.shuttle.rs";

#[derive(Serialize)]
struct Response {
    clients_count: usize,
    #[serde(rename = "dateTime")]
    date_time: DateTime<Utc>,
    is_up: bool,
}

#[shuttle_runtime::main]
async fn axum() -> ShuttleAxum {
    let (tx, rx) = watch::channel(Message::Text("{}".to_string()));

    let state = Arc::new(Mutex::new(State {
        clients_count: 0,
        rx,
    }));

    // Spawn a task to continually check the status of the API
    let state_send = state.clone();
    tokio::spawn(async move {
        let duration = Duration::from_secs(PAUSE_SECS);

        loop {
            let is_up = reqwest::get(STATUS_URI).await;
            let is_up = is_up.is_ok();

            let response = Response {
                clients_count: state_send.lock().await.clients_count,
                date_time: Utc::now(),
                is_up,
            };
            let msg = serde_json::to_string(&response).unwrap();

            if tx.send(Message::Text(msg)).is_err() {
                break;
            }

            sleep(duration).await;
        }
    });

    let router = Router::new()
        .route("/websocket", get(websocket_handler))
        .nest_service("/", ServeDir::new("static"))
        .layer(Extension(state));

    Ok(router.into())
}

async fn websocket_handler(ws: WebSocketUpgrade, Extension(state): Extension<Arc<Mutex<State>>>,) -> impl IntoResponse {
    ws.on_upgrade(|socket| websocket(socket, state))
}

async fn websocket(stream: WebSocket, state: Arc<Mutex<State>>) {
    // By splitting we can send and receive at the same time.
    let (mut sender, mut receiver) = stream.split();

    let mut rx = {
        let mut state = state.lock().await;
        state.clients_count += 1;
        state.rx.clone()
    };

    // This task will receive watch messages and forward it to this connected client.
    let mut send_task = tokio::spawn(async move {
        while let Ok(()) = rx.changed().await {
            let msg = rx.borrow().clone();

            if sender.send(msg).await.is_err() {
                break;
            }
        }
    });

    // This task will receive messages from this client.
    let mut recv_task = tokio::spawn(async move {
        while let Some(Ok(Message::Text(text))) = receiver.next().await {
            println!("this example does not read any messages, but got: {text}");
        }
    });

    // If any one of the tasks exit, abort the other.
    tokio::select! {
        _ = (&mut send_task) => recv_task.abort(),
        _ = (&mut recv_task) => send_task.abort(),
    };

    // This client disconnected
    state.lock().await.clients_count -= 1;
}

state

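State lets handlers share global data or state, such as a database connection pool or other clients.

There are three ways to share state:

  1. the State extractor;
  2. the Extension extractor;
  3. closure capture;

The State type must satisfy Clone + Send + Sync + 'static. Each handler receives its own owned clone, and handlers run on multiple threads. Shared data is therefore usually wrapped in an Arc, which is cheap to clone. Data that must be mutated goes in a type with interior mutability, such as Arc<Mutex<...>>.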
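The Clone bound matters because each handler call conceptually receives its own clone of the state. The std-only sketch below imitates that, with threads standing in for concurrent requests. The `AppState` fields and `handle_request` are hypothetical. Cloning the `Arc` only bumps a reference count, so every clone points at the same `Mutex`-protected counter.

```rust
use std::sync::{Arc, Mutex};
use std::thread;

// Hypothetical application state: cheap to clone because the data sits behind an Arc.
#[derive(Clone, Default)]
struct AppState {
    hits: Arc<Mutex<u64>>,
}

// Stand-in for a handler that receives an owned clone of the state.
fn handle_request(state: AppState) {
    let mut hits = state.hits.lock().unwrap();
    *hits += 1;
}

// Simulates `n` concurrent requests, each getting its own clone of the state.
fn simulate_requests(state: &AppState, n: usize) -> u64 {
    let workers: Vec<_> = (0..n)
        .map(|_| {
            let per_request = state.clone(); // one clone per "request"
            thread::spawn(move || handle_request(per_request))
        })
        .collect();
    for w in workers {
        w.join().unwrap();
    }
    let total = *state.hits.lock().unwrap();
    total
}
```

After all the threads finish, their clones are dropped and only the original `Arc` remains.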

The generic parameter S of Router<S> is the State type:

pub struct Router<S = ()> { /* private fields */ }

// Methods on Router<S>; S must satisfy Clone + Send + Sync + 'static
impl<S> Router<S> where S: Clone + Send + Sync + 'static
  // with_state<S2> is a generic method: if S2 is not specified explicitly, Rust infers it from context,
  // e.g. if into_make_service() is later called on the returned Router, S2 is inferred as ();
  pub fn with_state<S2>(self, state: S) -> Router<S2>

// Router is the same as Router<()> (S defaults to ()); the two methods below are defined only on Router<()>
impl Router
  pub fn into_make_service(self) -> IntoMakeService<Self>
  pub fn into_make_service_with_connect_info<C>(self) -> IntoMakeServiceWithConnectInfo<Self, C>
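The "missing state" idea is a type-state pattern. This std-only sketch reproduces it with a hypothetical `MiniRouter<S>`, which is not axum's type:

- `with_state` consumes the state the router was waiting for.
- It returns a router whose new parameter `S2` is picked by annotation or inference.
- `into_make_service` exists only on `MiniRouter<()>`, so calling it while a state is still missing is a compile error.

```rust
use std::marker::PhantomData;

// Hypothetical router that is still "missing" a state of type S.
struct MiniRouter<S = ()> {
    routes: Vec<String>,
    _missing: PhantomData<S>,
}

impl<S> MiniRouter<S> {
    fn new() -> Self {
        MiniRouter { routes: Vec::new(), _missing: PhantomData }
    }

    fn route(mut self, path: &str) -> Self {
        self.routes.push(path.to_string());
        self
    }

    // Supplies the missing S; the caller (or inference) chooses the next missing state S2.
    fn with_state<S2>(self, _state: S) -> MiniRouter<S2> {
        MiniRouter { routes: self.routes, _missing: PhantomData }
    }
}

// Only a router that is missing nothing can become a service.
impl MiniRouter<()> {
    fn into_make_service(self) -> Vec<String> {
        self.routes
    }
}

#[derive(Clone)]
struct AppState;
```

Calling `into_make_service()` on a `MiniRouter<AppState>` does not compile, which mirrors the axum error described in the next section.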

State Extractor

State can be provided at three levels (Router, MethodRouter and Handler) through their with_state() methods:

// Router:
// pub fn with_state<S2>(self, state: S) -> Router<S2>

use axum::{
    extract::State,
    routing::get,
    Router,
};
use std::sync::Arc;

struct AppState {
    // ...
}

let shared_state = Arc::new(AppState { /* ... */ });

let app = Router::new()
    .route("/", get(handler))
    .with_state(shared_state);

async fn handler(
    State(state): State<Arc<AppState>>,
) {
    // ...
}

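When a function returns a Router, it is generally better not to set the state inside the function. Do it in two steps instead:

  1. have the function return a Router<S>;
  2. call with_state() on the Router it returns;

This is because Router::into_make_service() is implemented only on Router<()>.

The example below fails because Router<AppState> does not provide Router::into_make_service():

// This won't work because we're returning a `Router<AppState>` i.e. we're saying we're still missing an `AppState`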
fn routes(state: AppState) -> Router<AppState> {
    Router::new()
        .route("/", get(|_: State<AppState>| async {}))
        .with_state(state)
}

// app is a Router<AppState>
let app = routes(AppState {});
let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
// Error: we can only call `Router::into_make_service` on a `Router<()>` but `app` is a `Router<AppState>`
axum::serve(listener, app).await.unwrap();

Solution: return a Router<()>:

// We've provided all the state necessary so return `Router<()>`
fn routes(state: AppState) -> Router<()> {
    Router::new()
        .route("/", get(|_: State<AppState>| async {}))
        .with_state(state)
}

// app is a Router<()>
let app = routes(AppState {});

// We can now call `Router::into_make_service`
let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
axum::serve(listener, app).await.unwrap();

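Router::with_state::<S2>() returns a new Router<S2>, and S2 is not always (). The compiler infers S2 from how the returned Router is used afterwards.

For example, if into_make_service() is called on the returned router, S2 is inferred as (), because that method is implemented only on Router<()>. S2 can also be chosen explicitly with a type annotation. This lets a router supply one state and then require another: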

let router: Router<AppState> = Router::new()
    .route("/", get(|_: State<AppState>| async {}));

// When we call `with_state` we're able to pick what the next missing state type is.
// Here we pick `String`.
let string_router: Router<String> = router.with_state(AppState {});

// That allows us to add new routes that uses `String` as the state type
let string_router = string_router
    .route("/needs-string", get(|_: State<String>| async {}));

// Provide the `String` and choose `()` as the new missing state.
let final_router: Router<()> = string_router.with_state("foo".to_owned());

// Since we have a `Router<()>` we can run it.
let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
axum::serve(listener, final_router).await.unwrap();

If you really need to set the state inside a function that returns a Router, declare the return type as Router with no generic parameter, which means Router<()>:

// Don't return `Router<AppState>`
fn routes(state: AppState) -> Router {
    Router::new()
        .route("/", get(|_: State<AppState>| async {}))
        .with_state(state)
}

let routes = routes(AppState {});

let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
axum::serve(listener, routes).await.unwrap();

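If the Router returned by a function is passed to nest(), the function should instead return a Router with an unconstrained generic parameter S. The compiler then infers S later; for example, it becomes () once the outer router is passed to axum::serve():

  • Without the generic parameter S, the function returns Router<()>, which may not match the state type of the Router that calls nest().
// The router passed to Router::nest() must have the same S as the router itself.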
impl<S> Router<S> where S: Clone + Send + Sync + 'static,
    pub fn nest(self, path: &str, router: Router<S>) -> Self


// Example:
fn routes<S>(state: AppState) -> Router<S> {
    Router::new()
        .route("/", get(|_: State<AppState>| async {}))
        .with_state(state)
}

let routes = Router::new().nest("/api", routes(AppState {}));
let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
axum::serve(listener, routes).await.unwrap();

Another example:


// A router that _needs_ an `AppState` to handle requests
let router: Router<AppState> = Router::new().route("/", get(|_: State<AppState>| async {}));

// Once we call `Router::with_state` the router isn't missing the state anymore, because we just provided it
//
// Therefore the router type becomes `Router<()>`, i.e. a router that is not missing any state
let router: Router<()> = router.with_state(AppState {});

// Only `Router<()>` has the `into_make_service` method.
//
// You cannot call `into_make_service` on a `Router<AppState>` because it is still missing an `AppState`.
let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
axum::serve(listener, router).await.unwrap();

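Performance tip: call .with_state(()) even when you need a Router that implements Service but has no state at all. This lets axum update some internals and reduce memory allocations, which improves performance: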

use axum::{Router, routing::get};

let app = Router::new()
    .route("/", get(|| async { /* ... */ }))
    // even though we don't need any state, call `with_state(())` anyway
    .with_state(());

Per-request State: Extensions

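State added at the Router level is available to every request handled by that Router.

State tied to one specific request needs Extension instead. An example is authentication or authorization data produced by a middleware.

Extension is a mechanism for passing state to handlers through http request extensions.

Extension implements FromRequestParts, Layer and IntoResponse. It can therefore be used as an extractor, as layer middleware, and as response data;

See the Extension section earlier for details.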
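Request extensions are essentially a map keyed by type. The std-only sketch below models that with a hypothetical `Extensions` type built on `TypeId` and `Box<dyn Any>`. The real `http::Extensions` offers a similar insert/get-by-type interface, but this is not its implementation. A middleware step inserts a per-request `CurrentUser` (also hypothetical), and the handler step reads it back.

```rust
use std::any::{Any, TypeId};
use std::collections::HashMap;

// Hypothetical type-keyed map modelling request extensions.
#[derive(Default)]
struct Extensions {
    map: HashMap<TypeId, Box<dyn Any>>,
}

impl Extensions {
    // At most one value per type: inserting again replaces the previous one.
    fn insert<T: 'static>(&mut self, value: T) {
        self.map.insert(TypeId::of::<T>(), Box::new(value));
    }

    fn get<T: 'static>(&self) -> Option<&T> {
        self.map.get(&TypeId::of::<T>())?.downcast_ref::<T>()
    }
}

// Hypothetical per-request data produced by an auth middleware.
#[derive(Debug, Clone, PartialEq)]
struct CurrentUser {
    name: String,
}

struct Request {
    auth_header: Option<String>,
    extensions: Extensions,
}

// "Middleware": derive request-specific state and attach it to this request only.
fn auth_middleware(mut req: Request) -> Request {
    if let Some(name) = req.auth_header.clone() {
        req.extensions.insert(CurrentUser { name });
    }
    req
}

// "Handler": look the per-request state up by its type.
fn handler(req: &Request) -> String {
    match req.extensions.get::<CurrentUser>() {
        Some(user) => format!("hello, {}", user.name),
        None => "anonymous".to_string(),
    }
}
```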

Passing State via closures

State can also be captured directly by handler closures:

use axum::{
    Json,
    extract::{Extension, Path},
    routing::{get, post},
    Router,
};
use std::sync::Arc;
use serde::Deserialize;

struct AppState {
    // ...
}

let shared_state = Arc::new(AppState { /* ... */ });

let app = Router::new()
    .route(
        "/users",
        post({
            let shared_state = Arc::clone(&shared_state);
            move |body| create_user(body, shared_state)
        }),
    )
    .route(
        "/users/:id",
        get({
            let shared_state = Arc::clone(&shared_state);
            move |path| get_user(path, shared_state)
        }),
    );

async fn get_user(Path(user_id): Path<String>, state: Arc<AppState>) {
    // ...
}

async fn create_user(Json(payload): Json<CreateUserPayload>, state: Arc<AppState>) {
    // ...
}

#[derive(Deserialize)]
struct CreateUserPayload {
    // ...
}

Layer middleware

Layer middleware can be added at three levels (Router, MethodRouter and Handler) with layer()/route_layer():

  • whole routers: Router::layer() and Router::route_layer()
  • a single method router: MethodRouter::layer() and MethodRouter::route_layer()
  • a single handler: Handler::layer()

A Layer wraps a Service and returns a new Service:

pub trait Layer<S> {
    type Service;

    // S is the Service being wrapped
    fn layer(&self, inner: S) -> Self::Service;
}
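To see what `layer(&self, inner: S) -> Self::Service` means in practice, here is a synchronous, std-only model with three parts:

- a simplified `Service` trait; tower's real one is asynchronous and also has `poll_ready`;
- a `Layer` trait with the same shape as above;
- a hypothetical `PrefixLayer` that wraps any inner service and post-processes its response.

```rust
// Simplified, synchronous stand-in for tower::Service.
trait Service<Req> {
    type Response;
    fn call(&mut self, req: Req) -> Self::Response;
}

// Same shape as tower::Layer: take an inner service, return a wrapping service.
trait Layer<S> {
    type Service;
    fn layer(&self, inner: S) -> Self::Service;
}

// The innermost "handler" service.
struct Echo;

impl Service<String> for Echo {
    type Response = String;
    fn call(&mut self, req: String) -> String {
        req
    }
}

// Hypothetical layer and the service it produces.
struct PrefixLayer {
    prefix: &'static str,
}

struct Prefix<S> {
    inner: S,
    prefix: &'static str,
}

impl<S> Layer<S> for PrefixLayer {
    type Service = Prefix<S>;
    fn layer(&self, inner: S) -> Prefix<S> {
        Prefix { inner, prefix: self.prefix }
    }
}

impl<S: Service<String, Response = String>> Service<String> for Prefix<S> {
    type Response = String;
    fn call(&mut self, req: String) -> String {
        // The request could be inspected or rewritten here, before the inner service runs.
        let resp = self.inner.call(req);
        // ...and the response is post-processed after it returns.
        format!("{}{}", self.prefix, resp)
    }
}
```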

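A layer's service sees the request before the handler does, and the response after it. That makes layers the usual way to implement middleware logic: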

impl<S> Router<S> where S: Clone + Send + Sync + 'static

    pub fn layer<L>(self, layer: L) -> Router<S>
    where
        // Route implements Service; its request type is http::request::Request and its response type is http::response::Response
        L: Layer<Route> + Clone + Send + 'static,
        L::Service: Service<Request> + Clone + Send + 'static,
        <L::Service as Service<Request>>::Response: IntoResponse + 'static,
        <L::Service as Service<Request>>::Error: Into<Infallible> + 'static,
        <L::Service as Service<Request>>::Future: Send + 'static

    pub fn route_layer<L>(self, layer: L) -> Self
    where
        L: Layer<Route> + Clone + Send + 'static,
        L::Service: Service<Request> + Clone + Send + 'static,
        <L::Service as Service<Request>>::Response: IntoResponse + 'static,
        <L::Service as Service<Request>>::Error: Into<Infallible> + 'static,
        <L::Service as Service<Request>>::Future: Send + 'static,

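Difference between route_layer() and layer(): route_layer() runs the layer only when the request matches a route. layer() always runs it, including for requests that fall through to the fallback.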
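The difference shows up for requests that match no route. The std-only sketch below models it with plain functions, with a hypothetical `auth_mw` playing the middleware and a simplified `(status, body)` response. The layer-style version wraps routing as a whole, fallback included. The route_layer-style version wraps only the matched route.

```rust
// Simplified response: (status code, body).
type Response = (u16, String);

// Hypothetical middleware: appends a marker so we can see that it ran.
fn auth_mw(resp: Response) -> Response {
    (resp.0, format!("{} +mw", resp.1))
}

fn home_handler() -> Response {
    (200, "home".to_string())
}

fn fallback() -> Response {
    (404, "not found".to_string())
}

// Like Router::layer(): the middleware wraps routing as a whole,
// so it also runs when the request falls through to the fallback.
fn app_with_layer(path: &str) -> Response {
    let resp = match path {
        "/" => home_handler(),
        _ => fallback(),
    };
    auth_mw(resp)
}

// Like Router::route_layer(): the middleware wraps each matched route only.
fn app_with_route_layer(path: &str) -> Response {
    match path {
        "/" => auth_mw(home_handler()),
        _ => fallback(),
    }
}
```

This is why route_layer() suits middleware that returns early, such as authorization. With layer(), such middleware could turn a 404 Not Found for an unknown path into a 401 Unauthorized.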

Route implements Service. Its request type is http::request::Request and its response type is http::response::Response:

// Route
impl<B, E> Service<Request<B>> for Route<E>
where
    // B implements the http_body::Body trait, with Data = bytes::Bytes
    B: HttpBody<Data = Bytes> + Send + 'static,
    B::Error: Into<BoxError>

    // Responses given by the service: the struct http::response::Response
    type Response = Response<Body>
    type Error = E
    type Future = RouteFuture<E>

Layers added with Router::layer() run in the reverse order of addition, and the handler runs last. A layer can therefore act both before and after the handler:

use axum::{routing::get, Router};
async fn handler() {}
let app = Router::new()
    .route("/", get(handler))
    .layer(layer_one)
    .layer(layer_two)
    .layer(layer_three); // runs first

// layers added with layer() run in reverse order, then the handler is called

        requests
           |
           v
+----- layer_three -----+
| +---- layer_two ----+ |
| | +-- layer_one --+ | |
| | |               | | |
| | |    handler    | | |
| | |               | | |
| | +-- layer_one --+ | |
| +---- layer_two ----+ |
+----- layer_three -----+
           |
           v
        responses

Middleware added with the layer() method of tower::ServiceBuilder runs in the order it was added.

  • Using ServiceBuilder to combine several layers into one Layer is recommended:
use tower::ServiceBuilder;
use axum::{routing::get, Router};

async fn handler() {}

let app = Router::new()
    .route("/", get(handler))
    .layer(
        ServiceBuilder::new()
            .layer(layer_one)
            .layer(layer_two)
            .layer(layer_three), // the resulting ServiceBuilder itself implements Layer
            //.service(handler) // calling service() instead would return a Service, not a Layer
    );
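Both orderings can be checked with a small std-only model. Each hypothetical layer is a function that wraps an inner `Box<dyn Fn>` and logs before and after calling it.

- `router_style` applies layers the way repeated `.layer()` calls do: each new layer wraps everything added so far.
- `service_builder_style` folds the list in reverse, so the first layer listed ends up outermost. This matches `ServiceBuilder`'s documented order.

```rust
use std::cell::RefCell;

type Log = RefCell<Vec<String>>;
type Svc = Box<dyn Fn(&Log)>;

// The innermost handler just records that it ran.
fn handler() -> Svc {
    Box::new(|log: &Log| log.borrow_mut().push("handler".to_string()))
}

// Hypothetical middleware: logs on the way in and on the way out.
fn wrap(name: &'static str, inner: Svc) -> Svc {
    Box::new(move |log: &Log| {
        log.borrow_mut().push(format!("{name} in"));
        inner(log);
        log.borrow_mut().push(format!("{name} out"));
    })
}

// Router::layer() style: each call wraps everything added before it.
fn router_style(names: &[&'static str]) -> Svc {
    let mut svc = handler();
    for &name in names {
        svc = wrap(name, svc);
    }
    svc
}

// ServiceBuilder style: the first layer listed becomes the outermost.
fn service_builder_style(names: &[&'static str]) -> Svc {
    let mut svc = handler();
    for &name in names.iter().rev() {
        svc = wrap(name, svc);
    }
    svc
}

// Runs a service once and returns the recorded call order.
fn run(svc: &Svc) -> Vec<String> {
    let log = RefCell::new(Vec::new());
    svc(&log);
    log.into_inner()
}
```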

Layers provided by the tower_http crate

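  • TraceLayer: tracing/logging;
  • CorsLayer: handling CORS;
  • CompressionLayer: automatic response compression;
  • SetRequestIdLayer and PropagateRequestIdLayer: setting and propagating request ids;
  • TimeoutLayer: timeouts;

For example, ValidateRequestHeaderLayer checks a request header (here, a bearer token) before the handler runs: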
use axum::{
    routing::get,
    Router,
};
use tower_http::validate_request::ValidateRequestHeaderLayer;

let app = Router::new().route(
    "/foo",
    get(|| async {})
    .route_layer(ValidateRequestHeaderLayer::bearer("password"))
);

// SetRequestIdLayer / PropagateRequestIdLayer example
use http::{Request, Response, header::HeaderName};
use tower::{Service, ServiceExt, ServiceBuilder};
use tower_http::request_id::{
    SetRequestIdLayer, PropagateRequestIdLayer, MakeRequestId, RequestId,
};
use http_body_util::Full;
use bytes::Bytes;
use std::sync::{Arc, atomic::{AtomicU64, Ordering}};

// A `MakeRequestId` that increments an atomic counter
#[derive(Clone, Default)]
struct MyMakeRequestId {
    counter: Arc<AtomicU64>,
}

impl MakeRequestId for MyMakeRequestId {
    fn make_request_id<B>(&mut self, request: &Request<B>) -> Option<RequestId> {
        let request_id = self.counter
            .fetch_add(1, Ordering::SeqCst)
            .to_string()
            .parse()
            .unwrap();

        Some(RequestId::new(request_id))
    }
}

let x_request_id = HeaderName::from_static("x-request-id");

let mut svc = ServiceBuilder::new()
    // set `x-request-id` header on all requests
    .layer(SetRequestIdLayer::new(
        x_request_id.clone(),
        MyMakeRequestId::default(),
    ))
    // propagate `x-request-id` headers from request to response
    .layer(PropagateRequestIdLayer::new(x_request_id))
    .service(handler);

let request = Request::new(Full::default());
let response = svc.ready().await?.call(request).await?;

assert_eq!(response.headers()["x-request-id"], "0");
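The counter logic in MyMakeRequestId can be checked on its own. This std-only sketch keeps only that part, as a hypothetical `CounterIds` generator. The generator shares an `Arc<AtomicU64>`, so clones of it draw from one sequence; this matters because layers and services get cloned. `fetch_add` returns the value before the increment, which is why the first request gets the id "0" asserted above.

```rust
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;

// Hypothetical id generator mirroring MyMakeRequestId's counter.
#[derive(Clone, Default)]
struct CounterIds {
    counter: Arc<AtomicU64>,
}

impl CounterIds {
    fn next_id(&mut self) -> String {
        // fetch_add returns the previous value, so ids start at "0"
        self.counter.fetch_add(1, Ordering::SeqCst).to_string()
    }
}
```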

Another example: recording the request ID in tracing logs:

use tower_http::{
    ServiceBuilderExt,
    trace::{TraceLayer, DefaultMakeSpan, DefaultOnResponse},
};

let svc = ServiceBuilder::new()
    // make sure to set request ids before the request reaches `TraceLayer`
    .set_x_request_id(MyMakeRequestId::default())
    // log requests and responses
    .layer(
        TraceLayer::new_for_http()
            .make_span_with(DefaultMakeSpan::new().include_headers(true))
            .on_response(DefaultOnResponse::new().include_headers(true))
    )
    // propagate the header to the response before the response reaches `TraceLayer`
    .propagate_x_request_id()
    .service(handler);

The methods provided by tower_http::ServiceBuilderExt make creating and using these middleware layers much simpler:

use tower_http::ServiceBuilderExt;

let mut svc = ServiceBuilder::new()
    .set_x_request_id(MyMakeRequestId::default())
    .propagate_x_request_id() // these helper methods add layers to the builder
    .service(handler); // service() wraps the handler and returns the final Service

let request = Request::new(Full::default());
let response = svc.ready().await?.call(request).await?;

assert_eq!(response.headers()["x-request-id"], "0");

Creating a custom Layer

There are four ways:

  1. axum::middleware::from_fn/from_fn_with_state: create a layer from an async function
  2. axum::middleware::from_extractor: create a layer from an existing extractor
  3. tower combinators, such as:
    1. ServiceBuilder::map_request
    2. ServiceBuilder::map_response
    3. ServiceBuilder::then
    4. ServiceBuilder::and_then
  4. tower::Service with Pin<Box<dyn Future>>: the most flexible (and most complex) approach

Using from_fn()

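axum::middleware::from_fn()/from_fn_with_state() create a Layer from a function f, which must meet these requirements:

  • be an async function;
  • take zero or more FromRequestParts extractors, then exactly one FromRequest extractor (such as Request) as the second-to-last argument;
  • take Next as the last argument;
  • return something that implements IntoResponse;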
pub fn from_fn<F, T>(f: F) -> FromFnLayer<F, (), T>

// Example:
use axum::{
    Router,
    http,
    routing::get,
    response::Response,
    middleware::{self, Next},
    extract::Request,
};

async fn my_middleware(request: Request, next: Next,) -> Response {
    // do something with `request`...
    let response = next.run(request).await;
    // do something with `response`...
    response
}

let app = Router::new()
    .route("/", get(|| async { /* ... */ }))
    .layer(middleware::from_fn(my_middleware));
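Conceptually, `Next` is a handle to the rest of the stack, and the middleware decides whether and when to call `next.run(request)`. The std-only sketch below models that with a hypothetical `Next` that borrows a function, plus synchronous `Request`/`Response` structs. axum's real `Next` is asynchronous and wraps a service. The header names and the `/blocked` path are made up for illustration.

```rust
// Simplified request/response types for the model.
struct Request {
    path: String,
    headers: Vec<(String, String)>,
}

struct Response {
    status: u16,
    headers: Vec<(String, String)>,
}

// Hypothetical stand-in for axum::middleware::Next: "run everything inside me".
struct Next<'a> {
    inner: &'a dyn Fn(Request) -> Response,
}

impl<'a> Next<'a> {
    fn run(self, req: Request) -> Response {
        (self.inner)(req)
    }
}

// The handler succeeds only if the middleware added its header.
fn handler(req: Request) -> Response {
    let status = if req.headers.iter().any(|(k, _)| k == "x-trace") { 200 } else { 500 };
    Response { status, headers: Vec::new() }
}

// Same shape as `my_middleware(request, next)` above.
fn my_middleware(mut req: Request, next: Next<'_>) -> Response {
    if req.path == "/blocked" {
        // Short-circuit: the handler never runs.
        return Response { status: 403, headers: Vec::new() };
    }
    req.headers.push(("x-trace".to_string(), "1".to_string())); // before the handler
    let mut resp = next.run(req);
    resp.headers.push(("x-mw".to_string(), "done".to_string())); // after the handler
    resp
}
```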

If the middleware function needs State, create it with axum::middleware::from_fn_with_state():

pub fn from_fn_with_state<F, S, T>(state: S, f: F) -> FromFnLayer<F, S, T>

use axum::{
    Router,
    http::StatusCode,
    routing::get,
    response::{IntoResponse, Response},
    middleware::{self, Next},
    extract::{Request, State},
};

#[derive(Clone)]
struct AppState { /* ... */ }

async fn my_middleware(
    State(state): State<AppState>,
    // you can add more extractors here but the last extractor must implement `FromRequest` which `Request` does
    request: Request,
    next: Next,
) -> Response {
    // do something with `request`...
    let response = next.run(request).await;
    // do something with `response`...
    response
}

let state = AppState { /* ... */ };

let app = Router::new()
    .route("/", get(|| async { /* ... */ }))
    .route_layer(middleware::from_fn_with_state(state.clone(), my_middleware))
    .with_state(state);

Using from_extractor

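axum::middleware::from_extractor()/from_extractor_with_state() create middleware from an extractor type:

  • typically used to validate requests, reusing an existing extractor type;
  • if the extractor succeeds, processing continues (the extracted value is discarded); otherwise its rejection is returned as the response;
  • if the extractor consumes the request body, the services that run after it see an empty body;
// E is the extractor type
pub fn from_extractor<E>() -> FromExtractorLayer<E, ()>

// Example: a RequireAuth extractor that validates the Authorization header: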

use axum::{
    extract::FromRequestParts,
    middleware::from_extractor,
    routing::{get, post},
    Router,
    http::{header, StatusCode, request::Parts},
};
use async_trait::async_trait;

// An extractor that performs authorization.
struct RequireAuth;

#[async_trait]
impl<S> FromRequestParts<S> for RequireAuth where S: Send + Sync,
{
    type Rejection = StatusCode;

    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
        let auth_header = parts
            .headers
            .get(header::AUTHORIZATION)
            .and_then(|value| value.to_str().ok());

        match auth_header {
            Some(auth_header) if token_is_valid(auth_header) => {
                Ok(Self)
            }
            _ => Err(StatusCode::UNAUTHORIZED),
        }
    }
}

fn token_is_valid(token: &str) -> bool {
    // ... validate the token here
    todo!()
}

async fn handler() {}
async fn other_handler() {}

let app = Router::new()
    .route("/", get(handler))
    .route("/foo", post(other_handler))
    // The extractor will run before all routes
    .route_layer(from_extractor::<RequireAuth>());
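The wrapping that from_extractor adds around the routes can be sketched without axum. It runs the extractor against the request parts first. If extraction fails, the rejection is returned as the response; otherwise the inner handler is called. The hypothetical `Parts`, `FromParts` trait and `with_extractor` function below are std-only stand-ins for `http::request::Parts`, `FromRequestParts` and the middleware. The token comparison is only a placeholder.

```rust
use std::collections::HashMap;

// Simplified request parts: just the headers.
struct Parts {
    headers: HashMap<String, String>,
}

// Stand-in for FromRequestParts: build Self from the parts, or reject with a status code.
trait FromParts: Sized {
    fn from_parts(parts: &Parts) -> Result<Self, u16>;
}

struct RequireAuth;

impl FromParts for RequireAuth {
    fn from_parts(parts: &Parts) -> Result<Self, u16> {
        match parts.headers.get("authorization") {
            // placeholder check; real code would validate the token properly
            Some(token) if token == "Bearer secret" => Ok(RequireAuth),
            _ => Err(401),
        }
    }
}

// Model of the from_extractor::<E>() middleware: the extracted value is discarded,
// only success or failure matters.
fn with_extractor<E: FromParts, H: Fn(&Parts) -> u16>(parts: &Parts, handler: H) -> u16 {
    match E::from_parts(parts) {
        Ok(_) => handler(parts),
        Err(rejection) => rejection,
    }
}
```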

Using map_request/map_response

axum::middleware also provides helper functions for creating Layers. They are meant for simple request or response transformations, and each returns a type that implements Layer:

  • map_request()
  • map_request_with_state()
  • map_response()
  • map_response_with_state()

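map_request() transforms the request. Its argument is an async function F taking 1 to 16 extractor arguments (T1..T16). F returns the transformed request as one of the following types; if it returns a Result, an Err rejects the request early: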

  • Request
  • Result<Request, E> where E: IntoResponse
pub fn map_request<F, T>(f: F) -> MapRequestLayer<F, (), T>

pub struct MapRequestLayer<F, S, T> { /* private fields */ }

impl<S, I, F, T> Layer<I> for MapRequestLayer<F, S, T> where    F: Clone,    S: Clone,
    type Service = MapRequest<F, S, I, T>
    fn layer(&self, inner: I) -> Self::Service

pub struct MapRequest<F, S, I, T> { /* private fields */ }

// axum implements Service for MapRequest when F is an async function taking
// 1 to 16 extractor arguments (T1 to T16); the one-argument case is shown here.
impl<F, Fut, S, I, B, T1> Service<Request<B>> for MapRequest<F, S, I, (T1,)>
where
    F: FnMut(T1) -> Fut + Clone + Send + 'static,
    T1: FromRequest<S> + Send,
    Fut: Future + Send + 'static,
    Fut::Output: IntoMapRequestResult<B> + Send + 'static,
    I: Service<Request<B>, Error = Infallible> + Clone + Send + 'static,
    I::Response: IntoResponse,
    I::Future: Send + 'static,
    B: HttpBody<Data = Bytes> + Send + 'static,
    B::Error: Into<BoxError>,
    S: Clone + Send + Sync + 'static,

    type Response = Response<Body>
    type Error = Infallible
    type Future = ResponseFuture
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>
    fn call(&mut self, req: Request<B>) -> Self::Future

// The output of F must implement the IntoMapRequestResult trait
pub trait IntoMapRequestResult<B>: Sealed<B> {
    // Required method
    fn into_map_request_result(self) -> Result<Request<B>, Response>;
}

// axum implements IntoMapRequestResult for Request<B> and for Result<Request<B>, E>
impl<B> IntoMapRequestResult<B> for Request<B>
impl<B, E> IntoMapRequestResult<B> for Result<Request<B>, E> where E: IntoResponse,


// Example:
use axum::{
    Router,
    http::{Request, StatusCode},
    routing::get,
    middleware::map_request,
};

async fn set_header<B>(mut request: Request<B>) -> Request<B> {
    request.headers_mut().insert("x-foo", "foo".parse().unwrap());
    request
}

async fn handler<B>(request: Request<B>) {
    // `request` will have an `x-foo` header
}

let app = Router::new()
    .route("/", get(handler))
    .layer(map_request(set_header));

// Returning a Result
async fn auth<B>(request: Request<B>) -> Result<Request<B>, StatusCode> {
    let auth_header = request.headers()
        .get(axum::http::header::AUTHORIZATION)
        .and_then(|header| header.to_str().ok());

    match auth_header {
        Some(auth_header) if token_is_valid(auth_header) => Ok(request),
        _ => Err(StatusCode::UNAUTHORIZED),
    }
}

fn token_is_valid(token: &str) -> bool {
    // ... validate the token here
    todo!()
}

let app = Router::new()
    .route("/", get(|| async { /* ... */ }))
    .route_layer(map_request(auth));

// Besides Request, a map_request() function can also take other extractors as arguments
use axum::{
    Router,
    routing::get,
    middleware::map_request,
    extract::Path,
    http::Request,
};
use std::collections::HashMap;

async fn log_path_params<B>(
    Path(path_params): Path<HashMap<String, String>>,
    request: Request<B>,
) -> Request<B> {
    tracing::debug!(?path_params);
    request
}

let app = Router::new()
    .route("/", get(|| async { /* ... */ }))
    .layer(map_request(log_path_params));
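The IntoMapRequestResult trait is what lets map_request accept a function returning either a bare Request or a Result. Below is a std-only sketch of the same pattern, using hypothetical `Request`/`Response` types and a local trait. axum's real trait is sealed, generic over the body type, and cannot be implemented outside axum. Here the error type is a plain status code, whereas axum accepts any `E: IntoResponse`.

```rust
#[derive(Debug, PartialEq)]
struct Request {
    path: String,
}

#[derive(Debug, PartialEq)]
struct Response {
    status: u16,
}

// Local version of the conversion trait: normalise either return type to a Result.
trait IntoMapRequestResult {
    fn into_map_request_result(self) -> Result<Request, Response>;
}

// A bare Request always continues to the inner service.
impl IntoMapRequestResult for Request {
    fn into_map_request_result(self) -> Result<Request, Response> {
        Ok(self)
    }
}

// A Result continues on Ok and short-circuits on Err.
impl IntoMapRequestResult for Result<Request, u16> {
    fn into_map_request_result(self) -> Result<Request, Response> {
        self.map_err(|status| Response { status })
    }
}

// Model of the map_request middleware: transform, then call the inner service or reject.
fn map_request<R: IntoMapRequestResult>(
    req: Request,
    f: impl Fn(Request) -> R,
    inner: impl Fn(Request) -> Response,
) -> Response {
    match f(req).into_map_request_result() {
        Ok(req) => inner(req),
        Err(rejection) => rejection,
    }
}
```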

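map_response() transforms the response. The function F takes zero to 16 extractors (T1..T16) followed by the response, and returns the transformed response, which can be any type that implements IntoResponse;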

pub fn map_response<F, T>(f: F) -> MapResponseLayer<F, (), T>

pub struct MapResponseLayer<F, S, T> { /* private fields */ }

impl<S, I, F, T> Layer<I> for MapResponseLayer<F, S, T> where    F: Clone,    S: Clone,
    type Service = MapResponse<F, S, I, T>
    fn layer(&self, inner: I) -> Self::Service

pub struct MapResponse<F, S, I, T> { /* private fields */ }

impl<F, Fut, S, I, B, ResBody> Service<Request<B>> for MapResponse<F, S, I, ()>
where
    F: FnMut(Response<ResBody>) -> Fut + Clone + Send + 'static,
    Fut: Future + Send + 'static,
    Fut::Output: IntoResponse + Send + 'static,
    I: Service<Request<B>, Response = Response<ResBody>, Error = Infallible> + Clone + Send + 'static,
    I::Future: Send + 'static,
    B: Send + 'static,
    ResBody: Send + 'static,
    S: Clone + Send + Sync + 'static,
// F can take 0 to 16 extractor arguments (T1 to T16) before the response; the two-extractor case is shown here.
impl<F, Fut, S, I, B, ResBody, T1, T2> Service<Request<B>> for MapResponse<F, S, I, (T1, T2)>
where
    F: FnMut(T1, T2, Response<ResBody>) -> Fut + Clone + Send + 'static,
    T1: FromRequestParts<S> + Send,
    T2: FromRequestParts<S> + Send,
    Fut: Future + Send + 'static,
    Fut::Output: IntoResponse + Send + 'static,
    I: Service<Request<B>, Response = Response<ResBody>, Error = Infallible> + Clone + Send + 'static,
    I::Future: Send + 'static,
    B: Send + 'static,
    ResBody: Send + 'static,
    S: Clone + Send + Sync + 'static,

// Example:
use axum::{
    Router,
    routing::get,
    middleware::map_response,
    response::Response,
};

async fn set_header<B>(mut response: Response<B>) -> Response<B> {
    response.headers_mut().insert("x-foo", "foo".parse().unwrap());
    response
}

let app = Router::new()
    .route("/", get(|| async { /* ... */ }))
    .layer(map_response(set_header));

// As with map_request(), extractor arguments can be used too
use axum::{
    Router,
    routing::get,
    middleware::map_response,
    extract::Path,
    response::Response,
};
use std::collections::HashMap;

async fn log_path_params<B>(
    Path(path_params): Path<HashMap<String, String>>,
    response: Response<B>,
) -> Response<B> {
    tracing::debug!(?path_params);
    response
}

let app = Router::new()
    .route("/", get(|| async { /* ... */ }))
    .layer(map_response(log_path_params));

// The function can return any type that implements the IntoResponse trait
use axum::{
    Router,
    routing::get,
    middleware::map_response,
    response::{Response, IntoResponse},
};
use std::collections::HashMap;

async fn set_header(response: Response) -> impl IntoResponse {
    (
        [("x-foo", "foo")],
        response,
    )
}

let app = Router::new()
    .route("/", get(|| async { /* ... */ }))
    .layer(map_response(set_header));

Using tower::Service and Pin<Box<dyn Future>>

Implementing a Layer and its associated Service directly with tower::Service and Pin<Box<dyn Future>> gives the most flexibility:

use axum::{
    response::Response,
    body::Body,
    extract::Request,
};
use futures_util::future::BoxFuture;
use tower::{Service, Layer};
use std::task::{Context, Poll};

// When defining a Layer, you usually also define an associated Service for it
#[derive(Clone)]
struct MyLayer;

impl<S> Layer<S> for MyLayer {
    type Service = MyMiddleware<S>; // the associated Service type

    fn layer(&self, inner: S) -> Self::Service {
        MyMiddleware { inner }
    }
}

#[derive(Clone)]
struct MyMiddleware<S> {
    inner: S,
}

// To satisfy Router::layer(), the Service must accept axum's Request (http::request::Request) and its response must implement IntoResponse; here it is http::response::Response.
impl<S> Service<Request> for MyMiddleware<S>
where
    S: Service<Request, Response = Response> + Send + 'static,
    S::Future: Send + 'static,
{
    type Response = S::Response;
    type Error = S::Error;

    // `BoxFuture` is a type alias for `Pin<Box<dyn Future + Send + 'a>>`
    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    fn call(&mut self, request: Request) -> Self::Future {
        let future = self.inner.call(request);
        Box::pin(async move {
            let response: Response = future.await?;
            Ok(response)
        })
    }
}

error_handling

For Result<T, E>, if both T and E implement IntoResponse, the Result implements IntoResponse as well:

// Result implements IntoResponse when both the Ok and Err types implement IntoResponse
impl<T, E> IntoResponse for Result<T, E> where T: IntoResponse, E: IntoResponse

// axum::response::ErrorResponse implements From<T> for every T: IntoResponse, so any IntoResponse value can be converted into it
impl<T> IntoResponse for Result<T, ErrorResponse> where T: IntoResponse

impl<T, U> IntoResponse for Chain<T, U> where
    T: Buf + Unpin + Send + 'static,
    U: Buf + Unpin + Send + 'static,

For example, axum::http::StatusCode implements IntoResponse, so it can be used as the error type of a handler's Result:

use axum::http::StatusCode;

async fn handler() -> Result<String, StatusCode> {
    // ...
    todo!()
}
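The rule that a Result<T, E> is a response whenever both T and E are can be reproduced with a local trait. In this std-only sketch, `IntoResponse`, `Response` and `StatusCode` are simplified stand-ins defined here, not axum's types. The `handler` flag is hypothetical.

```rust
#[derive(Debug, PartialEq)]
struct Response {
    status: u16,
    body: String,
}

// Local stand-in for axum's IntoResponse.
trait IntoResponse {
    fn into_response(self) -> Response;
}

// A String becomes a 200 response with that body.
impl IntoResponse for String {
    fn into_response(self) -> Response {
        Response { status: 200, body: self }
    }
}

#[derive(Debug, Clone, Copy)]
struct StatusCode(u16);

// A bare status code becomes an empty response with that status.
impl IntoResponse for StatusCode {
    fn into_response(self) -> Response {
        Response { status: self.0, body: String::new() }
    }
}

// Same shape as axum's impl: both arms must already be convertible into responses.
impl<T: IntoResponse, E: IntoResponse> IntoResponse for Result<T, E> {
    fn into_response(self) -> Response {
        match self {
            Ok(value) => value.into_response(),
            Err(err) => err.into_response(),
        }
    }
}

fn handler(found: bool) -> Result<String, StatusCode> {
    if found {
        Ok("hello".to_string())
    } else {
        Err(StatusCode(404))
    }
}
```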

To return an error message, define your own error type and implement IntoResponse for it:

// https://github.com/tokio-rs/axum/blob/main/examples/anyhow-error-response/src/main.rs
use axum::{
    http::StatusCode,
    response::{IntoResponse, Response},
    routing::get,
    Router,
};

#[tokio::main]
async fn main() {
    let app = Router::new().route("/", get(handler));
    let listener = tokio::net::TcpListener::bind("127.0.0.1:3000").await.unwrap();
    println!("listening on {}", listener.local_addr().unwrap());
    axum::serve(listener, app).await.unwrap();
}


// The Result returned by the handler is converted into a Response via IntoResponse and sent to the client
async fn handler() -> Result<(), AppError> {
    try_thing()?;
    Ok(())
}

fn try_thing() -> Result<(), anyhow::Error> {
    anyhow::bail!("it failed!")
}

// Make our own error that wraps `anyhow::Error`.
struct AppError(anyhow::Error);

// Tell axum how to convert `AppError` into a response.
impl IntoResponse for AppError {
    fn into_response(self) -> Response {
        // axum implements IntoResponse for (StatusCode, String)
        (
            StatusCode::INTERNAL_SERVER_ERROR,
            format!("Something went wrong: {}", self.0),
        )
            .into_response()
    }
}

// This enables using `?` on functions that return `Result<_, anyhow::Error>` to turn them into
// `Result<_, AppError>`. That way you don't need to do that manually.
impl<E> From<E> for AppError
where
    E: Into<anyhow::Error>,
{
    fn from(err: E) -> Self {
        Self(err.into())
    }
}
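The blanket `From<E>` impl is what lets `?` convert any compatible error into `AppError`. Below is a std-only sketch of the same mechanics. It uses `std::error::Error` instead of anyhow, and a hypothetical `(u16, String)` pair as the response. `parse_id` is a made-up handler helper.

```rust
use std::error::Error;

// Wraps any error, like AppError wraps anyhow::Error.
#[derive(Debug)]
struct AppError(String);

// Every std error converts into AppError, so `?` works on it.
// This does not clash with the reflexive `impl From<T> for T`
// because AppError itself does not implement Error.
impl<E: Error> From<E> for AppError {
    fn from(err: E) -> Self {
        AppError(err.to_string())
    }
}

impl AppError {
    // Stand-in for IntoResponse::into_response: map to a 500 with a message.
    fn into_response(self) -> (u16, String) {
        (500, format!("Something went wrong: {}", self.0))
    }
}

fn parse_id(raw: &str) -> Result<u32, AppError> {
    let id: u32 = raw.parse()?; // ParseIntError -> AppError via From
    Ok(id)
}
```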

When axum uses a Service, it requires the Service's associated Error type to be Infallible, meaning the service can never fail.

For example, Router::route_service() uses a Service as the handler for a route. That Service is bounded by Service<Request, Error = Infallible> + Clone + Send + 'static:

// Bounds on the service parameter of Router::route_service():
pub fn route_service<T>(self, path: &str, service: T) -> Self
where
    // Error is Infallible, i.e. the service can never return an error
    T: Service<Request, Error = Infallible> + Clone + Send + 'static,
    // Response must implement IntoResponse
    T::Response: IntoResponse,
    T::Future: Send + 'static,


// Example:
use axum::{
    Router,
    body::Body,
    routing::{any_service, get_service},
    extract::Request,
    http::StatusCode,
    error_handling::HandleErrorLayer,
};
use tower_http::services::ServeFile;
use http::Response;
use std::{convert::Infallible, io};
use tower::service_fn;

let app = Router::new()
    .route(
        "/",
        // service_fn() creates a Service from a closure that takes a Request and returns a future resolving to a Result.
        // To satisfy any_service, the Result's error type must be Infallible.
        any_service(service_fn(|_: Request| async {
            let res = Response::new(Body::from("Hi from `GET /`"));
            Ok::<_, Infallible>(res)
        }))
    )
    .route_service(
        "/foo",
        service_fn(|req: Request| async move {
            let body = Body::from(format!("Hi from `{} /foo`", req.method()));
            let res = Response::new(body);
            Ok::<_, Infallible>(res)
        })
    )
    .route_service(
        "/static/Cargo.toml",
        ServeFile::new("Cargo.toml"),
    );

A Service created with tower::service_fn(f), by contrast, only requires f: FnMut(Request) -> Future<Output = Result<R, E>>. E may not be Infallible, so the types may not match.

axum provides axum::error_handling::HandleError for this case. It lets a service_fn() service that returns Result<R, E> be used with Router::route_service() by converting its errors into responses.

HandleError turns a Service whose Result can be Err into a Service<Request, Error = Infallible>. You supply a function that converts the Err value into something that implements IntoResponse:

use axum::{
    Router,
    body::Body,
    http::{Response, StatusCode},
    error_handling::HandleError,
};

async fn thing_that_might_fail() -> Result<(), anyhow::Error> {
    // ... (the real work goes here)
    Ok(())
}

// service_fn() turns the async closure into a ServiceFn. ServiceFn implements Service for any
// FnMut(Request) -> Future<Output = Result<R, E>>, so the closure may return a Result carrying an error.
let some_fallible_service = tower::service_fn(|_req| async {
    thing_that_might_fail().await?;
    Ok::<_, anyhow::Error>(Response::new(Body::empty()))
});

// route_service() requires Service<Request, Error = Infallible> + Clone + Send + 'static.
// The service built by tower::service_fn() has Error = anyhow::Error, which does not match,
// so HandleError::new() is needed to convert it.
let app = Router::new().route_service(
    "/",
    HandleError::new(some_fallible_service, handle_anyhow_error),
);
// Convert err into a type that implements IntoResponse
async fn handle_anyhow_error(err: anyhow::Error) -> (StatusCode, String) {
    (
        StatusCode::INTERNAL_SERVER_ERROR,
        format!("Something went wrong: {err}"),
    )
}
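The core idea of HandleError can be sketched the same way. This is a hypothetical, synchronous model (SimpleService, MightFail and this HandleError are illustrations, not axum's types): the wrapper calls the fallible inner service and routes any Err through the user-supplied closure, so its own Error type can be Infallible.

```rust
use std::convert::Infallible;

#[derive(Debug, PartialEq)]
pub struct Response {
    pub status: u16,
    pub body: String,
}

/// Simplified, synchronous model of tower's Service trait.
pub trait SimpleService<Req> {
    type Response;
    type Error;
    fn call(&mut self, req: Req) -> Result<Self::Response, Self::Error>;
}

/// Hypothetical fallible service: fails when the request path is "/boom".
pub struct MightFail;

impl SimpleService<String> for MightFail {
    type Response = Response;
    type Error = String;
    fn call(&mut self, req: String) -> Result<Response, String> {
        if req == "/boom" {
            Err("disk on fire".to_string())
        } else {
            Ok(Response { status: 200, body: String::new() })
        }
    }
}

/// Simplified model of HandleError: wraps a fallible service and turns
/// every error into a response via f, so the wrapper's Error is Infallible.
pub struct HandleError<S, F> {
    inner: S,
    f: F,
}

impl<S, F> HandleError<S, F> {
    pub fn new(inner: S, f: F) -> Self {
        HandleError { inner, f }
    }
}

impl<S, F, Req> SimpleService<Req> for HandleError<S, F>
where
    S: SimpleService<Req, Response = Response>,
    F: FnMut(S::Error) -> Response,
{
    type Response = Response;
    type Error = Infallible;
    fn call(&mut self, req: Req) -> Result<Response, Infallible> {
        // Ok passes through unchanged; Err is converted by the handler into a normal response.
        Ok(self.inner.call(req).unwrap_or_else(|e| (self.f)(e)))
    }
}
```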

HandleError<S, F, T> implements Service. The closure F takes a variable number of arguments: the last one must be the error type, and every argument before it must implement FromRequestParts<()> + Send, so F can also take extractors:

pub struct HandleError<S, F, T> { /* private fields */ }

impl<S, F, T> HandleError<S, F, T>
    pub fn new(inner: S, f: F) -> Self

impl<S, F, B, Fut, Res> Service<Request<B>> for HandleError<S, F, ()>
where
    S: Service<Request<B>> + Clone + Send + 'static,
    S::Response: IntoResponse + Send,
    S::Error: Send,
    S::Future: Send,
    F: FnOnce(S::Error) -> Fut + Clone + Send + 'static,
    Fut: Future<Output = Res> + Send,
    Res: IntoResponse,
    B: Send + 'static,

// The number of arguments F takes can vary; this impl covers one extractor followed by the error:
impl<S, F, B, Res, Fut, T1> Service<Request<B>> for HandleError<S, F, (T1,)>
where
    S: Service<Request<B>> + Clone + Send + 'static,
    S::Response: IntoResponse + Send,
    S::Error: Send,
    S::Future: Send,
    F: FnOnce(T1, S::Error) -> Fut + Clone + Send + 'static,
    Fut: Future<Output = Res> + Send,
    Res: IntoResponse,
    T1: FromRequestParts<()> + Send,
    B: Send + 'static
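The variable-arity support rests on the extra type parameter T (() versus (T1,)): each arity gets its own impl, and because the impls differ in T they do not overlap, so the compiler infers T from the closure's signature. A minimal std-only sketch of that technique follows; Parts, FromParts, Method, ErrorHandler and run_handler are hypothetical names standing in for axum's request parts, FromRequestParts and HandleError's impls.

```rust
/// Hypothetical request metadata (stands in for http::request::Parts).
pub struct Parts {
    pub method: String,
}

/// Hypothetical extractor trait, modelling FromRequestParts<()>.
pub trait FromParts: Sized {
    fn from_parts(parts: &Parts) -> Self;
}

/// Example extractor: the request method.
pub struct Method(pub String);

impl FromParts for Method {
    fn from_parts(parts: &Parts) -> Self {
        Method(parts.method.clone())
    }
}

/// T is a marker describing the extractor arguments; it keeps the impls below
/// from overlapping, the same trick HandleError<S, F, T> relies on.
pub trait ErrorHandler<T, E> {
    fn handle(self, parts: &Parts, err: E) -> String;
}

// Closure taking only the error.
impl<F, E> ErrorHandler<(), E> for F
where
    F: FnOnce(E) -> String,
{
    fn handle(self, _parts: &Parts, err: E) -> String {
        self(err)
    }
}

// Closure taking one extractor followed by the error.
impl<F, E, T1> ErrorHandler<(T1,), E> for F
where
    F: FnOnce(T1, E) -> String,
    T1: FromParts,
{
    fn handle(self, parts: &Parts, err: E) -> String {
        self(T1::from_parts(parts), err)
    }
}

/// Generic entry point: T is inferred from the closure's arity.
pub fn run_handler<T, E, H: ErrorHandler<T, E>>(h: H, parts: &Parts, err: E) -> String {
    h.handle(parts, err)
}
```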

axum::error_handling::HandleErrorLayer provides the same error conversion for errors produced by middleware.

HandleErrorLayer implements Layer<S>: its layer() method wraps the inner Service in a HandleError, which implements Service and turns a Service that returns errors into the Service<Request, Error = Infallible> that axum expects:

  • new(f): f is a closure of the form FnOnce(T1, ..., Tn, S::Error) -> Fut + Clone + Send + 'static. It takes zero or more extractors followed by the error, and its future Fut resolves to a value implementing IntoResponse.
pub struct HandleErrorLayer<F, T> { /* private fields */ }

impl<F, T> HandleErrorLayer<F, T>
    // f takes the error and returns something implementing IntoResponse
    pub fn new(f: F) -> Self

impl<S, F, T> Layer<S> for HandleErrorLayer<F, T>
where
    F: Clone,
{
    type Service = HandleError<S, F, T>;

    fn layer(&self, inner: S) -> Self::Service {
        // self.f is passed to HandleError::new(), so F must satisfy that function's bounds
        HandleError::new(inner, self.f.clone())
    }
}

impl<S, F, B, Fut, Res> Service<Request<B>> for HandleError<S, F, ()>
where
    S: Service<Request<B>> + Clone + Send + 'static,
    S::Response: IntoResponse + Send,
    S::Error: Send,
    S::Future: Send,
    F: FnOnce(S::Error) -> Fut + Clone + Send + 'static,
    Fut: Future<Output = Res> + Send,
    Res: IntoResponse,
    B: Send + 'static
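What HandleErrorLayer adds on top of HandleError is small: it stores the error handler and, each time layer() is called, wraps the inner service with a clone of it. The sketch below models that with simplified, hypothetical types (a cut-down Layer trait, plain closures as services, u32 requests and String responses), not tower's or axum's real definitions.

```rust
/// Simplified model of tower's Layer trait: wrap one service in another.
pub trait Layer<S> {
    type Service;
    fn layer(&self, inner: S) -> Self::Service;
}

/// Hypothetical wrapper produced by the layer (stands in for HandleError).
pub struct HandleError<S, F> {
    pub inner: S,
    pub f: F,
}

/// Model of HandleErrorLayer: it only stores the error handler f.
pub struct HandleErrorLayer<F> {
    f: F,
}

impl<F> HandleErrorLayer<F> {
    pub fn new(f: F) -> Self {
        HandleErrorLayer { f }
    }
}

impl<S, F: Clone> Layer<S> for HandleErrorLayer<F> {
    type Service = HandleError<S, F>;
    fn layer(&self, inner: S) -> Self::Service {
        // layer() only has &self, so f is cloned: this is why F: Clone is required.
        HandleError { inner, f: self.f.clone() }
    }
}

impl<S, F> HandleError<S, F>
where
    S: FnMut(u32) -> Result<String, String>,
    F: FnMut(String) -> String,
{
    /// Errors from inner become ordinary responses, so this call cannot fail.
    pub fn call(&mut self, req: u32) -> String {
        (self.inner)(req).unwrap_or_else(|e| (self.f)(e))
    }
}
```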

// Example:
use axum::{
    Router,
    BoxError,
    routing::get,
    http::StatusCode,
    error_handling::HandleErrorLayer,
};
use std::time::Duration;
use tower::ServiceBuilder;

let app = Router::new()
    .route("/", get(|| async {}))
    .layer(
        ServiceBuilder::new()
            // `timeout` will produce an error if the handler takes too long, so we must handle
            // those errors: the closure passed to new() takes the error and returns an IntoResponse value
            .layer(HandleErrorLayer::new(handle_timeout_error)) // converts Err into a response
            .timeout(Duration::from_secs(30)) // layer whose service may return Err
            //.layer(TimeoutLayer::new(Duration::from_secs(10)))
    );

// The handler takes the error and returns something implementing IntoResponse
async fn handle_timeout_error(err: BoxError) -> (StatusCode, String) {
    if err.is::<tower::timeout::error::Elapsed>() {
        (
            StatusCode::REQUEST_TIMEOUT,
            "Request took too long".to_string(),
        )
    } else {
        (
            StatusCode::INTERNAL_SERVER_ERROR,
            format!("Unhandled internal error: {err}"),
        )
    }
}
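handle_timeout_error relies on downcasting: BoxError is a boxed trait object, and is::<T>() checks which concrete error is inside. A std-only sketch of the same classification, where Elapsed is a hypothetical stand-in for tower::timeout::error::Elapsed and plain u16 codes replace StatusCode:

```rust
use std::error::Error;
use std::fmt;

/// Same shape as axum::BoxError.
pub type BoxError = Box<dyn Error + Send + Sync>;

/// Hypothetical stand-in for tower::timeout::error::Elapsed.
#[derive(Debug)]
pub struct Elapsed;

impl fmt::Display for Elapsed {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "request timed out")
    }
}

impl Error for Elapsed {}

/// Mirrors handle_timeout_error, returning a numeric status plus a message.
pub fn classify(err: BoxError) -> (u16, String) {
    // is::<T>() checks the concrete type behind the trait object.
    if err.is::<Elapsed>() {
        (408, "Request took too long".to_string())
    } else {
        (500, format!("Unhandled internal error: {err}"))
    }
}
```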

validator and garde crates
#

Note: validator is no longer actively maintained; it has been superseded by https://github.com/jprochazk/garde

The validator crate provides declarative validation for structs:

#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize, Validate, JsonSchema)]
#[serde(rename_all = "camelCase")]
pub struct ApplicationIn {
    #[validate(
        length(min = 1, message = "Application names must be at least one character"),
        custom = "validate_no_control_characters"
    )]
    #[schemars(example = "application_name_example")]
    pub name: String,

    #[validate(range(min = 1, message = "Application rate limits must be at least 1 if set"))]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub rate_limit: Option<u16>,
    /// Optional unique identifier for the application
    #[validate]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub uid: Option<ApplicationUid>,

    #[serde(default)]
    pub metadata: Metadata,
}
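To make the attributes concrete, here is roughly what the rules on ApplicationIn check, written by hand with the standard library. This is an illustration, not the code validator generates: the struct is trimmed to two fields, and since the body of validate_no_control_characters is not shown in the source, rejecting char::is_control characters (with a made-up message) is an assumption.

```rust
/// Hypothetical, trimmed-down version of ApplicationIn (no serde/schemars).
pub struct ApplicationIn {
    pub name: String,
    pub rate_limit: Option<u16>,
}

/// Hand-written equivalent of the #[validate(...)] rules above; collects all messages.
pub fn validate(app: &ApplicationIn) -> Result<(), Vec<String>> {
    let mut errors = Vec::new();
    // length(min = 1)
    if app.name.chars().count() < 1 {
        errors.push("Application names must be at least one character".to_string());
    }
    // custom = "validate_no_control_characters" (assumed behaviour and message)
    if app.name.chars().any(char::is_control) {
        errors.push("Control characters are not allowed".to_string());
    }
    // range(min = 1) on an Option: only checked when a value is present.
    if let Some(limit) = app.rate_limit {
        if limit < 1 {
            errors.push("Application rate limits must be at least 1 if set".to_string());
        }
    }
    if errors.is_empty() { Ok(()) } else { Err(errors) }
}
```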

openapi
#

Related projects:

  1. oasgen: https://crates.io/crates/oasgen
  2. utoipa and utoipa-swagger-ui: crates that provide OpenAPI support for axum.
  3. aide: https://docs.rs/aide/latest/aide/
       • aide example: https://github.com/svix/svix-webhooks/blob/main/server/svix-server/src/v1/endpoints/admin.rs
  4. poem_openapi: https://docs.rs/poem-openapi/latest/poem_openapi/index.html

utoipa example: https://github.com/tokio-rs/axum/issues/50#issuecomment-2592149658

#[utoipa::path(
    get,
    path = "/",
    responses(
        (status = 200, description = "Send a salute from Axum")
    )
)]
pub async fn hello_axum() -> impl IntoResponse {
    (StatusCode::OK, "Hello, Axum")
}

// add use utoipa::OpenApi;
#[derive(OpenApi)]
#[openapi(paths(hello_axum))]
pub struct ApiDoc;

// add use utoipa_swagger_ui::SwaggerUi;
let app = Router::new()
    .route("/", get(hello_axum))
    .merge(SwaggerUi::new("/swagger-ui").url("/api-doc/openapi.json", ApiDoc::openapi()));

axum tracing log
#

Configure the project's Cargo.toml: enable axum's tracing feature and tower-http's trace feature, so that a tracing middleware can later be added to the axum Router:

[dependencies]
axum = { version = "0.7", features = ["tracing"] }
tokio = { version = "1.0", features = ["full"] }
tower-http = { version = "0.5.0", features = ["trace"] }
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }

Example:

// https://github.com/tokio-rs/axum/blob/main/examples/tracing-aka-logging/src/main.rs
use axum::{
    body::Bytes,
    extract::MatchedPath,
    http::{HeaderMap, Request},
    response::{Html, Response},
    routing::get,
    Router,
};
use std::time::Duration;
use tokio::net::TcpListener;
use tower_http::{classify::ServerErrorsFailureClass, trace::TraceLayer};
use tracing::{info_span, Span};
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};

#[tokio::main]
async fn main() {
    // Install the global tracing subscriber
    tracing_subscriber::registry()
        .with(
            // If RUST_LOG (e.g. RUST_LOG=debug) is not set at runtime, fall back to these per-crate defaults
            tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| {
                // axum logs rejections from built-in extractors with the `axum::rejection`
                // target, at `TRACE` level. `axum::rejection=trace` enables showing those events
                "example_tracing_aka_logging=debug,tower_http=debug,axum::rejection=trace".into()
            }),
        )
        .with(tracing_subscriber::fmt::layer())
        .init();

    let app = Router::new()
        .route("/", get(handler))
        .layer(
            // Key point: HTTP requests are only logged once TraceLayer is added.
            TraceLayer::new_for_http()
                .make_span_with(|request: &Request<_>| {
                    let matched_path = request
                        .extensions()
                        .get::<MatchedPath>()
                        .map(MatchedPath::as_str);
                    info_span!(
                        "http_request",
                        method = ?request.method(),
                        matched_path,
                        some_other_field = tracing::field::Empty,
                    )
                })
                .on_request(|_request: &Request<_>, _span: &Span| {
                    // You can use `_span.record("some_other_field", value)` in one of these
                    // closures to attach a value to the initially empty field in the info_span
                    // created above.
                })
                .on_response(|_response: &Response, _latency: Duration, _span: &Span| {
                    // ...
                })
                .on_body_chunk(|_chunk: &Bytes, _latency: Duration, _span: &Span| {
                    // ...
                })
                .on_eos(
                    |_trailers: Option<&HeaderMap>, _stream_duration: Duration, _span: &Span| {
                        // ...
                    },
                )
                .on_failure(
                    |_error: ServerErrorsFailureClass, _latency: Duration, _span: &Span| {
                        // ...
                    },
                ),
        );

    let listener = TcpListener::bind("127.0.0.1:3000").await.unwrap();
    tracing::debug!("listening on {}", listener.local_addr().unwrap());
    axum::serve(listener, app).await.unwrap();
}

async fn handler() -> Html<&'static str> {
    Html("<h1>Hello, World!</h1>")
}
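The fallback passed to EnvFilter is a comma-separated list of target=level directives. As a rough, simplified model (not tracing-subscriber's implementation, which also handles span and field filters and more), an event is shown when the most specific directive whose target is a prefix of the event's target allows its level. Level, parse_directives and enabled below are hypothetical helpers.

```rust
/// Ordered from least to most verbose, so "level <= max" means "allowed".
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum Level {
    Error,
    Warn,
    Info,
    Debug,
    Trace,
}

fn parse_level(s: &str) -> Option<Level> {
    match s.to_ascii_lowercase().as_str() {
        "error" => Some(Level::Error),
        "warn" => Some(Level::Warn),
        "info" => Some(Level::Info),
        "debug" => Some(Level::Debug),
        "trace" => Some(Level::Trace),
        _ => None,
    }
}

/// Parses "target=level,target=level" into pairs, skipping malformed entries.
pub fn parse_directives(spec: &str) -> Vec<(String, Level)> {
    spec.split(',')
        .filter_map(|d| {
            let (target, level) = d.trim().split_once('=')?;
            Some((target.to_string(), parse_level(level)?))
        })
        .collect()
}

/// The longest matching target prefix decides; no matching directive means disabled.
pub fn enabled(directives: &[(String, Level)], target: &str, level: Level) -> bool {
    directives
        .iter()
        .filter(|(t, _)| target.starts_with(t.as_str()))
        .max_by_key(|(t, _)| t.len())
        .map_or(false, |(_, max)| level <= *max)
}
```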

JWT Authentication
#

// https://docs.shuttle.rs/examples/axum-jwt-authentication

use axum::{
    async_trait,
    extract::FromRequestParts,
    http::{request::Parts, StatusCode},
    response::{IntoResponse, Response},
    routing::{get, post},
    Json, RequestPartsExt, Router,
};
use axum_extra::{
    headers::{authorization::Bearer, Authorization},
    TypedHeader,
};
use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header, Validation};
use once_cell::sync::Lazy;
use serde::{Deserialize, Serialize};
use serde_json::json;
use std::fmt::Display;
use std::time::SystemTime;

static KEYS: Lazy<Keys> = Lazy::new(|| {
    // note that in production, you will probably want to use a random SHA-256 hash or similar
    let secret = "JWT_SECRET".to_string();
    Keys::new(secret.as_bytes())
});

#[shuttle_runtime::main]
async fn main() -> shuttle_axum::ShuttleAxum {
    let app = Router::new()
        .route("/public", get(public))
        .route("/private", get(private))
        .route("/login", post(login));

    Ok(app.into())
}

async fn public() -> &'static str {
    // A public endpoint that anyone can access
    "Welcome to the public area :)"
}

// Claims implements FromRequestParts, so it can extract the authentication info from the request
async fn private(claims: Claims) -> Result<String, AuthError> {
    // Send the protected data to the user
    Ok(format!(
        "Welcome to the protected area :)\nYour data:\n{claims}",
    ))
}

async fn login(Json(payload): Json<AuthPayload>) -> Result<Json<AuthBody>, AuthError> {
    // Check if the user sent the credentials
    if payload.client_id.is_empty() || payload.client_secret.is_empty() {
        return Err(AuthError::MissingCredentials);
    }
    // Here you can check the user credentials from a database
    if payload.client_id != "foo" || payload.client_secret != "bar" {
        return Err(AuthError::WrongCredentials);
    }

    // add 5 minutes to current unix epoch time as expiry date/time
    let exp = SystemTime::now()
        .duration_since(SystemTime::UNIX_EPOCH)
        .unwrap()
        .as_secs()
        + 300;

    let claims = Claims {
        sub: "b@b.com".to_owned(),
        company: "ACME".to_owned(),
        // Mandatory expiry time as UTC timestamp - takes unix epoch
        exp: usize::try_from(exp).unwrap(),
    };

    // Create the authorization token
    let token = encode(&Header::default(), &claims, &KEYS.encoding)
        .map_err(|_| AuthError::TokenCreation)?;

    // Send the authorized token
    Ok(Json(AuthBody::new(token)))
}

// allow us to print the claim details for the private route
impl Display for Claims {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "Email: {}\nCompany: {}", self.sub, self.company)
    }
}

// implement a method to create a response type containing the JWT
impl AuthBody {
    fn new(access_token: String) -> Self {
        Self {
            access_token,
            token_type: "Bearer".to_string(),
        }
    }
}

// implement FromRequestParts for Claims (the JWT struct)
// FromRequestParts allows us to use Claims without consuming the request
#[async_trait]
impl<S> FromRequestParts<S> for Claims
where
    S: Send + Sync,
{
    type Rejection = AuthError;

    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
        // Extract the token from the authorization header
        let TypedHeader(Authorization(bearer)) = parts
            .extract::<TypedHeader<Authorization<Bearer>>>()
            .await
            .map_err(|_| AuthError::InvalidToken)?;
        // Decode the user data
        let token_data = decode::<Claims>(bearer.token(), &KEYS.decoding, &Validation::default())
            .map_err(|_| AuthError::InvalidToken)?;

        Ok(token_data.claims)
    }
}

// implement IntoResponse for AuthError so we can use it as an Axum response type
impl IntoResponse for AuthError {
    fn into_response(self) -> Response {
        let (status, error_message) = match self {
            AuthError::WrongCredentials => (StatusCode::UNAUTHORIZED, "Wrong credentials"),
            AuthError::MissingCredentials => (StatusCode::BAD_REQUEST, "Missing credentials"),
            AuthError::TokenCreation => (StatusCode::INTERNAL_SERVER_ERROR, "Token creation error"),
            AuthError::InvalidToken => (StatusCode::BAD_REQUEST, "Invalid token"),
        };
        let body = Json(json!({
            "error": error_message,
        }));
        (status, body).into_response()
    }
}

// encoding/decoding keys - set in the static `once_cell` above
struct Keys {
    encoding: EncodingKey,
    decoding: DecodingKey,
}

impl Keys {
    fn new(secret: &[u8]) -> Self {
        Self {
            encoding: EncodingKey::from_secret(secret),
            decoding: DecodingKey::from_secret(secret),
        }
    }
}

// the JWT claim
#[derive(Debug, Serialize, Deserialize)]
struct Claims {
    sub: String,
    company: String,
    exp: usize,
}

// the response that we pass back to HTTP client once successfully authorised
#[derive(Debug, Serialize)]
struct AuthBody {
    access_token: String,
    token_type: String,
}

// the request type - "client_id" is analogous to a username, client_secret can also be interpreted as a password
#[derive(Debug, Deserialize)]
struct AuthPayload {
    client_id: String,
    client_secret: String,
}

// error types for auth errors
#[derive(Debug)]
enum AuthError {
    WrongCredentials,
    MissingCredentials,
    TokenCreation,
    InvalidToken,
}
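Two pieces of the JWT flow above can be sketched without any crates: pulling the token out of an Authorization: Bearer <token> header (the job TypedHeader<Authorization<Bearer>> does for the Claims extractor), and the expiry check that decode() performs with Validation::default(), which applies a leeway (60 seconds by default in jsonwebtoken). bearer_token and is_expired are hypothetical helpers, and the expiry logic is simplified.

```rust
/// Hypothetical helper: extracts the token from an "Authorization: Bearer <token>" header value.
pub fn bearer_token(header_value: &str) -> Option<&str> {
    let (scheme, token) = header_value.trim().split_once(' ')?;
    // HTTP auth schemes are case-insensitive, so accept "bearer" as well as "Bearer".
    if scheme.eq_ignore_ascii_case("bearer") && !token.trim().is_empty() {
        Some(token.trim())
    } else {
        None
    }
}

/// Simplified model of the exp check: the token is rejected once now
/// passes exp plus the leeway (all values are Unix timestamps in seconds).
pub fn is_expired(exp: u64, now: u64, leeway: u64) -> bool {
    exp.saturating_add(leeway) < now
}
```

With the login handler above, exp is set to now + 300, so is_expired(exp, now, 60) stays false for a little over six minutes.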

References
#
