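axum is an asynchronous HTTP server library built on hyper. The typical workflow:
- Create a Router: use Router::route() to define a PATH and its associated Service. The Service is usually a MethodRouter; functions such as get()/post()/patch() return a MethodRouter, and their argument type is a Handler;
- Implement a Handler: usually an async function or closure:
  - Input: Extractors, which pull the relevant information out of the request;
  - Output: a value that implements the IntoResponse trait (not a Result as such; a Result qualifies only when both its Ok and Err types implement IntoResponse). axum implements this trait for many basic Rust types.
All three levels, Router/MethodRouter/Handler, support:
- adding middleware with layer(), which runs before the Handler is called;
- attaching a state object with with_state().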
use axum::{Router, routing::get};
let app = Router::new()
.route("/", get(root))
// get() handles GET requests and returns a MethodRouter (which implements Service), so calls can be chained.
.route("/foo", get(get_foo).post(post_foo))
.route("/foo/bar", get(foo_bar));
// async functions that implement the Handler trait
async fn root() {}
async fn get_foo() {}
async fn post_foo() {}
async fn foo_bar() {}
let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
axum::serve(listener, app).await.unwrap();
axum::serve #
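axum::serve() is the entry point of axum. Its second parameter, make_service, is a Service factory (a Service of Services).
The inner Service takes an http::request::Request<B>, where B implements http_body::Body<Data = bytes::Bytes>, and responds with an http::response::Response<axum::body::Body>; in other words, it turns an HTTP Request into a Response.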
pub fn serve<M, S>(tcp_listener: TcpListener, make_service: M) -> Serve<M, S>
where
// Outer Service:
// its Request is an IncomingStream (described later) and its Response is another Service
M: for<'a> Service<IncomingStream<'a>, Error = Infallible, Response = S>,
// Inner Service:
// its Request is http::request::Request<B> with B: http_body::Body<Data = bytes::Bytes>
// its Response is http::response::Response<axum::body::Body>
S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
S::Future: Send,
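axum's Router and MethodRouter, and a Handler once it is wrapped by into_make_service() or HandlerService, satisfy the bounds on M. They are Service factories, so their values can be passed to serve().
Taking Router as an example, it implements the Service trait at both levels and therefore meets the bounds on M:
// Router implements Service<IncomingStream<'_>>, and the corresponding Response is again a Router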
impl Service<IncomingStream<'_>> for Router<()>
type Response = Router // the response is a Router
type Error = Infallible
type Future = Ready<Result<<Router as Service<IncomingStream<'_, L>>>::Response, <Router as Service<IncomingStream<'_, L>>>::Error>>
// Router also implements Service<Request<B>>, which satisfies the bound on S in `Response = S` above:
// its Request is http::request::Request<B> with B: http_body::Body<Data = bytes::Bytes>
// its Response is http::response::Response<axum::body::Body>
impl<B> Service<Request<B>> for Router<()>
where
B: HttpBody<Data = Bytes> + Send + 'static,
B::Error: Into<BoxError>
type Response = Response<Body> // Body is axum::body::Body
type Error = Infallible
type Future = RouteFuture<Infallible>
Examples:
// Serving a Router:
use axum::{Router, routing::get};
let router = Router::new().route("/", get(|| async { "Hello, World!" }));
let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
axum::serve(listener, router).await.unwrap();
// Serving a MethodRouter:
use axum::routing::get;
let router = get(|| async { "Hello, World!" });
let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
axum::serve(listener, router).await.unwrap();
// Serving a Handler: call the handler's into_make_service() method
use axum::handler::HandlerWithoutStateExt;
async fn handler() -> &'static str { "Hello, World!"}
let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
axum::serve(listener, handler.into_make_service()).await.unwrap();
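The two-level Service shape described in this section can be modeled without axum or tower. The sketch below is a simplified, synchronous stand-in: the `Service` trait, `Incoming`, `Request`, `Response`, `App`, and `serve` are all hypothetical names defined here for illustration, not axum's real types. It only shows how one value can act as both the outer factory (connection in, service out) and the inner service (request in, response out).

```rust
use std::convert::Infallible;

/// Simplified, synchronous stand-in for `tower::Service` (hypothetical: the
/// real trait is asynchronous and also has `poll_ready`).
trait Service<Req> {
    type Response;
    type Error;
    fn call(&mut self, req: Req) -> Result<Self::Response, Self::Error>;
}

/// Hypothetical stand-ins for `IncomingStream`, the HTTP request and response.
struct Incoming;
struct Request {
    path: String,
}
#[derive(Debug, PartialEq)]
struct Response {
    status: u16,
    body: String,
}

/// A toy application playing the role of `Router<()>`.
#[derive(Clone)]
struct App;

// Inner level: one HTTP request in, one HTTP response out.
impl Service<Request> for App {
    type Response = Response;
    type Error = Infallible;
    fn call(&mut self, req: Request) -> Result<Response, Infallible> {
        let (status, body) = match req.path.as_str() {
            "/" => (200, "Hello, World!".to_string()),
            _ => (404, String::new()),
        };
        Ok(Response { status, body })
    }
}

// Outer level: one accepted connection in, a per-connection service out.
impl Service<&Incoming> for App {
    type Response = App;
    type Error = Infallible;
    fn call(&mut self, _conn: &Incoming) -> Result<App, Infallible> {
        Ok(self.clone())
    }
}

/// Toy accept loop: build a service for the connection, then feed it the
/// requests that arrive on that connection.
fn serve<M, S>(mut make_service: M, conn: &Incoming, paths: &[&str]) -> Vec<Response>
where
    M: for<'a> Service<&'a Incoming, Response = S, Error = Infallible>,
    S: Service<Request, Response = Response, Error = Infallible>,
{
    let mut svc = make_service.call(conn).unwrap();
    paths
        .iter()
        .map(|p| svc.call(Request { path: p.to_string() }).unwrap())
        .collect()
}

fn main() {
    for res in serve(App, &Incoming, &["/", "/missing"]) {
        println!("{} {:?}", res.status, res.body);
    }
}
```

Because `App` implements both levels, it can be handed to `serve` directly, which mirrors why a `Router<()>` can be passed to `axum::serve()` without calling `into_make_service()` first.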
Router #
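Router<S = ()> defines PATHs and their handling logic. S is the type of the State; it can be any custom type that implements Clone (usually something wrapped in an Arc).
// S is the Router's State type; it defaults to ().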
pub struct Router<S = ()> { /* private fields */ }
Handling logic can be supplied in several ways:
- route(): takes a MethodRouter<S>, usually created with the functions in the axum::routing::method_routing module that are named after HTTP methods, for example:
  - axum::routing::method_routing::get(handler: Handler<T, S>): the handler is usually an async function or closure;
  - axum::routing::method_routing::get_service(svc: Service<Request>): svc is usually built from a closure with tower::service_fn();
  - axum::routing::method_routing::any_service(svc: Service<Request>): a svc that accepts any request method, usually built from a closure with tower::service_fn();
  - axum::routing::method_routing::on_service<T, S>(filter: MethodFilter, svc: T): takes both a method filter and a svc, which is usually built from a closure with tower::service_fn();
- route_service(): takes a tower::Service, usually built from a closure with tower::service_fn() or reused from the tower_http crate, such as tower_http::services::ServeFile;
- layer()/route_layer(): takes a Layer<Route>, usually built from an async function with axum::middleware::from_fn()/from_fn_with_state() or reused from the tower_http crate, such as tower_http::trace::TraceLayer;
  - tower::layer_fn() is not used directly, because its Request/Response types are arbitrary and may not match the http::request::Request/http::response::Response types that axum requires.
In short, every kind of handling logic above (Handler, Service, Layer) can be implemented quickly from a closure.
// S is the type of the state object later passed to with_state(state: S). It must implement Clone, so an Arc is typically used.
impl<S> Router<S> where S: Clone + Send + Sync + 'static
// The following methods are defined on Router<S>:
// Create a new Router; the state type S is inferred from how the Router is used later.
pub fn new() -> Self
// Add a MethodRouter for path; any_service() can build one from a closure-based Service
pub fn route(self, path: &str, method_router: MethodRouter<S>) -> Self
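// Add a Service for path. The Service can be built from a closure with tower::service_fn() or taken from the ready-made services in tower_http.
//
// Note: the Service must have Error = Infallible. If the closure passed to tower::service_fn() returns a Result whose Err type is not Infallible, the Service bound is not met and axum::error_handling::HandleError is needed to convert the error.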
pub fn route_service<T>(self, path: &str, service: T) -> Self
where
T: Service<Request, Error = Infallible> + Clone + Send + 'static,
T::Response: IntoResponse,
T::Future: Send + 'static
// Nest another Router under path
pub fn nest(self, path: &str, router: Router<S>) -> Self
pub fn nest_service<T>(self, path: &str, service: T) -> Self
where
T: Service<Request, Error = Infallible> + Clone + Send + 'static,
T::Response: IntoResponse,
T::Future: Send + 'static
// Merge several Routers into one
pub fn merge<R>(self, other: R) -> Self where R: Into<Router<S>>
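// Apply a layer (middleware) to all routes of the Router; it runs whether or not the request matches a route.
// Layers run in the reverse order of their addition, and the handler runs last.
// layer() takes ownership and returns a new Router<S>, so calls can be chained.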
pub fn layer<L>(self, layer: L) -> Router<S>
where
L: Layer<Route> + Clone + Send + 'static,
L::Service: Service<Request> + Clone + Send + 'static,
<L::Service as Service<Request>>::Response: IntoResponse + 'static,
<L::Service as Service<Request>>::Error: Into<Infallible> + 'static,
<L::Service as Service<Request>>::Future: Send + 'static
// Apply a layer that runs only when the request matches a route.
pub fn route_layer<L>(self, layer: L) -> Self
where
L: Layer<Route> + Clone + Send + 'static,
L::Service: Service<Request> + Clone + Send + 'static,
<L::Service as Service<Request>>::Response: IntoResponse + 'static,
<L::Service as Service<Request>>::Error: Into<Infallible> + 'static,
<L::Service as Service<Request>>::Future: Send + 'static
// The handler to fall back to when no route matches
pub fn fallback<H, T>(self, handler: H) -> Self where H: Handler<T, S>, T: 'static
pub fn fallback_service<T>(self, service: T) -> Self
where
T: Service<Request, Error = Infallible> + Clone + Send + 'static,
T::Response: IntoResponse,
T::Future: Send + 'static
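// Provide state to the Router; handlers later obtain it through the State extractor. The state can be any custom type that implements Clone.
// This supplies global data to every request of the Router; for data specific to a single request, use Extension instead.
// The type S2 is inferred from how the Router is used afterwards.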
pub fn with_state<S2>(self, state: S) -> Router<S2>
// Calling tower::ServiceExt methods directly on a Router fails because the body type cannot be inferred; call this method first.
pub fn as_service<B>(&mut self) -> RouterAsService<'_, B, S>
// !!! The following two methods are defined on Router<()>. They convert the Router into a MakeService, a Service that creates another Service; the main use is as the argument to axum::serve.
pub fn into_make_service(self) -> IntoMakeService<Self>
pub fn into_make_service_with_connect_info<C>( self,) -> IntoMakeServiceWithConnectInfo<Self, C>
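The comments on layer() and route_layer() above describe two behaviours that are easy to mix up: layers added later run first, and route_layer() only runs for requests that match a route. The sketch below is a std-only toy model, not axum's implementation; `Router`, `Svc`, and `wrap` are hypothetical names defined here, and each "layer" merely records its label in a log so the call order becomes visible.

```rust
/// A boxed service: receives the path, appends to a call log, returns a status.
type Svc = Box<dyn Fn(&str, &mut Vec<String>) -> u16>;

/// Toy router (hypothetical): exact-match routes plus a fallback.
struct Router {
    routes: Vec<(&'static str, Svc)>,
    fallback: Svc,
}

/// Wrap `inner` so that `label` is logged before `inner` runs.
fn wrap(label: &'static str, inner: Svc) -> Svc {
    Box::new(move |path: &str, log: &mut Vec<String>| {
        log.push(label.to_string());
        inner(path, log)
    })
}

impl Router {
    fn new() -> Self {
        let fallback: Svc = Box::new(|_path: &str, _log: &mut Vec<String>| 404);
        Router { routes: Vec::new(), fallback }
    }

    /// Register a handler that logs "handler" and answers 200.
    fn route(mut self, path: &'static str) -> Self {
        let handler: Svc = Box::new(|_path: &str, log: &mut Vec<String>| {
            log.push("handler".to_string());
            200
        });
        self.routes.push((path, handler));
        self
    }

    /// Wrap only the routes registered so far.
    fn route_layer(mut self, label: &'static str) -> Self {
        self.routes = self
            .routes
            .into_iter()
            .map(|(p, s)| (p, wrap(label, s)))
            .collect();
        self
    }

    /// Wrap the routes registered so far and the fallback.
    fn layer(self, label: &'static str) -> Self {
        let mut this = self.route_layer(label);
        this.fallback = wrap(label, this.fallback);
        this
    }

    fn call(&self, path: &str) -> (u16, Vec<String>) {
        let mut log = Vec::new();
        let svc = self
            .routes
            .iter()
            .find(|(p, _)| *p == path)
            .map(|(_, s)| s)
            .unwrap_or(&self.fallback);
        let status = svc(path, &mut log);
        (status, log)
    }
}

fn main() {
    let app = Router::new().route("/foo").route_layer("auth").layer("trace");
    println!("{:?}", app.call("/foo"));
    println!("{:?}", app.call("/not-found"));
}
```

In this model the "trace" layer, added last, is the outermost wrapper and runs first, while "auth" never runs for a path that matches no route.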
Although Router implements Service<Request<B>>, that is, Service<http::request::Request<B>> with B: http_body::Body<Data = bytes::Bytes>, calling tower::ServiceExt methods on it directly fails to compile (for example service.ready().await?.call(request)).
The fix is to use the Service returned by Router's as_service() method:
use axum::{
Router,
routing::get,
http::Request,
body::Body,
};
use tower::{Service, ServiceExt};
let mut router = Router::new().route("/", get(|| async {}));
let request = Request::new(Body::empty());
// let response = router.ready().await?.call(request).await?;
// ^^^^^ cannot infer type for type parameter `B`
// OK
let response = router.as_service().ready().await?.call(request).await?;
Router<()> provides into_make_service()/into_make_service_with_connect_info() to create a type that implements the MakeService trait:
// Returns a Service factory; the generic parameter Self is the Router type
pub fn into_make_service(self) -> IntoMakeService<Self>
// When S is Router (because Router::into_make_service() was called), Response is Router as well,
// and Router implements Service<Request>. IntoMakeService<Router> is therefore a Service that produces a Service, i.e. a MakeService.
impl<S, T> Service<T> for IntoMakeService<S> where S: Clone
type Response = S
type Error = Infallible
type Future = IntoMakeServiceFuture<S>
// Example
use axum::{
routing::get,
Router,
};
let app = Router::new().route("/", get(|| async { "Hi!" }));
let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
// axum::serve(listener, app).await.unwrap();
axum::serve(listener, app.into_make_service()).await.unwrap();
More examples:
use axum::{
Router,
body::Body,
routing::{get, delete, any_service, get_service, MethodFilter, on_service},
extract::{Request, Path, State},
http::StatusCode,
error_handling::HandleErrorLayer,};
use tower::{Service, ServiceExt};
use tower::service_fn;
use tower_http::services::ServeFile;
use tower_http::trace::TraceLayer;
use tower_http::validate_request::ValidateRequestHeaderLayer;
use http::Response;
use std::{convert::Infallible, io};
let app = Router::new()
.route("/", get(root))
// get() returns a MethodRouter, so calls can be chained
.route("/users", get(list_users).post(create_user))
.route("/users/:id", get(show_user))
.route("/api/:version/users/:id/action", delete(do_users_action))
// /assets/*path does not match /assets/
.route("/assets/*path", get(serve_asset));
async fn root() {}
async fn list_users() {}
async fn create_user() {}
async fn show_user(Path(id): Path<u64>) {}
// Several Path parameters are extracted as a tuple
async fn do_users_action(Path((version, id)): Path<(String, u64)>) {}
async fn serve_asset(Path(path): Path<String>) {}
let app = Router::new()
.route( "/", any_service( // any_service() builds a MethodRouter from a Service
// Build the Service from a closure. The Err type of the Result returned by the closure must be Infallible to satisfy
// the trait bound; otherwise use axum::error_handling::HandleError to convert the Err into Infallible.
service_fn(|_: Request| async {
// First create an axum::body::Body, then use it to build an http::Response.
let res = Response::new(Body::from("Hi from `GET /`"));
Ok::<_, Infallible>(res)
})))
.route_service( "/foo", service_fn(|req: Request| async move {
// Uses axum::body::Body, which implements the http_body::Body trait
let body = Body::from(format!("Hi from `{} /foo`", req.method()));
let res = Response::new(body);
Ok::<_, Infallible>(res)
}))
.route_service( "/static/Cargo.toml", ServeFile::new("Cargo.toml"), );
// on_service() is the general function; it needs an explicit HTTP method filter.
// any_service() is even more general and matches any HTTP method.
let service = tower::service_fn(|request: Request| async { Ok::<_, Infallible>(Response::new(Body::empty()))});
let app = Router::new().route("/", on_service(MethodFilter::DELETE, service));
// Merge two Routers
let user_routes = Router::new() .route("/users", get(users_list)) .route("/users/:id", get(users_show));
let team_routes = Router::new() .route("/teams", get(teams_list));
let app = Router::new().merge(user_routes) .merge(team_routes);
let app = Router::new()
.route("/foo", get(|| async {}))
.route("/bar", get(|| async {}))
.layer(TraceLayer::new_for_http()); // applies to both routes above
let app = Router::new()
.route("/foo", get(|| async {}))
.route_layer(ValidateRequestHeaderLayer::bearer("password")); // applies only to /foo
// `GET /foo` with a valid token will receive `200 OK`
// `GET /foo` with an invalid token will receive `401 Unauthorized`
// `GET /not-found` with an invalid token will receive `404 Not Found`, because no route matches /not-found
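let app = Router::new()
.route("/foo", get(|| async { /* ... */ }))
.fallback(fallback); // default handler for paths that match no route
// Requires `use axum::http::Uri;` in addition to the imports above.
async fn fallback(uri: Uri) -> (StatusCode, String) {
    (StatusCode::NOT_FOUND, format!("No route for {uri}"))
}
// The state type must implement the Clone trait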
#[derive(Clone)]
struct AppState {}
let routes = Router::new()
// Use axum::extract::State to access the global state inside a handler
.route("/", get(|State(state): State<AppState>| async {
// use state
}))
.with_state(AppState {});
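The signature with_state(self, state: S) turns a router that is still "missing" a state of type S into one that is not. A rough std-only model of that idea follows. It is a sketch under simplifying assumptions: `Router`, `Handler`, and `AppState` here are hypothetical toy definitions, handlers are synchronous, and this with_state() always returns Router<()> whereas axum's returns a generic Router<S2>.

```rust
use std::sync::Arc;

/// A boxed handler that still needs a state value of type `S`.
type Handler<S> = Box<dyn Fn(S) -> String>;

/// Toy router (hypothetical) whose handlers are waiting for state `S`.
struct Router<S> {
    routes: Vec<(&'static str, Handler<S>)>,
}

impl<S: Clone + 'static> Router<S> {
    fn new() -> Self {
        Router { routes: Vec::new() }
    }

    fn route(mut self, path: &'static str, handler: impl Fn(S) -> String + 'static) -> Self {
        let boxed: Handler<S> = Box::new(handler);
        self.routes.push((path, boxed));
        self
    }

    /// Bake `state` into every handler; the result needs no further state.
    fn with_state(self, state: S) -> Router<()> {
        let mut routes: Vec<(&'static str, Handler<()>)> = Vec::new();
        for (path, handler) in self.routes {
            let state = state.clone();
            // Each call hands the handler its own clone of the state.
            let bound: Handler<()> = Box::new(move |_unit: ()| handler(state.clone()));
            routes.push((path, bound));
        }
        Router { routes }
    }
}

impl Router<()> {
    /// Only a router with no missing state can answer requests.
    fn call(&self, path: &str) -> Option<String> {
        self.routes
            .iter()
            .find(|(p, _)| *p == path)
            .map(|(_, h)| h(()))
    }
}

/// Shared application state; cloning only bumps the Arc reference count.
#[derive(Clone)]
struct AppState {
    greeting: Arc<String>,
}

fn main() {
    let state = AppState { greeting: Arc::new("Hello".to_string()) };
    let app = Router::new()
        .route("/", |s: AppState| format!("{}, World!", s.greeting))
        .with_state(state);
    println!("{:?}", app.call("/"));
}
```

The model also hints at why the state must be Clone: every request receives its own copy, so wrapping shared data in an Arc keeps those copies cheap.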
IncomingStream #
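The struct axum::serve::IncomingStream is the Request type that MethodRouter, HandlerService, and similar types accept in their outer Service implementation, for example impl Service<IncomingStream<'_>> for MethodRouter<()>.
IncomingStream wraps a connection accepted from a TCP/Unix listener and exposes io() and remote_addr(); accept() and local_addr() belong to the Listener trait shown further down: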
pub struct IncomingStream<'a, L> where L: Listener, { /* private fields */ }
// IncomingStream provides the following methods:
// The TCP/Unix stream, from which details such as the local address can be read
pub fn io(&self) -> &L::Io
// Address of the connected client
pub fn remote_addr(&self) -> &L::Addr
// https://docs.rs/axum/latest/axum/serve/trait.Listener.html
pub trait Listener: Send + 'static {
type Io: AsyncRead + AsyncWrite + Unpin + Send + 'static;
type Addr: Send;
// Required methods
fn accept(&mut self) -> impl Future<Output = (Self::Io, Self::Addr)> + Send;
fn local_addr(&self) -> Result<Self::Addr>;
}
// tokio::net::TcpListener and tokio::net::UnixListener implement the Listener trait:
impl Listener for TcpListener
type Io = TcpStream
type Addr = SocketAddr
async fn accept(&mut self) -> (Self::Io, Self::Addr)
fn local_addr(&self) -> Result<Self::Addr>
impl Listener for UnixListener
type Io = UnixStream
type Addr = SocketAddr // tokio::net::unix::SocketAddr, not std::net::SocketAddr
async fn accept(&mut self) -> (Self::Io, Self::Addr)
fn local_addr(&self) -> Result<Self::Addr>
ConnectInfo #
Router<()>'s into_make_service_with_connect_info<C>() method exists to supply client connection information, such as the socket address, to the ConnectInfo extractor used in a Handler (it is passed along through HTTP request extensions):
// Method defined on Router<()>:
pub fn into_make_service_with_connect_info<C>(self) -> IntoMakeServiceWithConnectInfo<Self, C>
// Example
use axum::{
extract::ConnectInfo,
routing::get,
Router,
};
use std::net::SocketAddr; // supported out of the box
// new() creates a Router without fixing the state type, so the type is inferred from how app is used later.
// app.into_make_service_with_connect_info() is defined on Router<()>, so app is inferred to be Router<()>.
let app = Router::new().route("/", get(handler));
async fn handler(ConnectInfo(addr): ConnectInfo<SocketAddr>) -> String {
format!("Hello {addr}")
}
let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
axum::serve(listener, app.into_make_service_with_connect_info::<SocketAddr>()).await.unwrap();
A custom ConnectInfo type can be defined by implementing the Connected<IncomingStream<'_>> trait.
axum already implements Connected<IncomingStream<'_>> for std::net::SocketAddr, so that type can be used directly.
// Example:
use axum::{
extract::connect_info::{ConnectInfo, Connected},
routing::get,
serve::IncomingStream,
Router,
};
let app = Router::new().route("/", get(handler));
async fn handler(ConnectInfo(my_connect_info): ConnectInfo<MyConnectInfo>, ) -> String {
format!("Hello {my_connect_info:?}")
}
#[derive(Clone, Debug)]
struct MyConnectInfo {
// ...
}
impl Connected<IncomingStream<'_>> for MyConnectInfo {
fn connect_info(target: IncomingStream<'_>) -> Self {
// target wraps the TCP/Unix stream and carries the local_addr/remote_addr information
MyConnectInfo {
// ...
}
}
}
let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
axum::serve(listener, app.into_make_service_with_connect_info::<MyConnectInfo>()).await.unwrap();
body #
A Handler must return a value that implements the IntoResponse trait. http::Response<B> is generic over a body type B: http_body::Body, so a concrete body type has to be chosen.
pub trait IntoResponse {
// Required method
fn into_response(self) -> http::Response<axum::body::Body>;
}
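axum uses its own struct axum::body::Body, which implements http_body::Body<Data = bytes::Bytes>, as the body type of Handler responses.
axum::body::Body implements the FromRequest trait, so it can be used as an extractor, and also the IntoResponse trait, so it can be returned from a Handler.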
// Struct axum::body::Body
pub struct Body(/* private fields */);
// Create an axum::body::Body from any value that implements http_body::Body. new() is generic: B must implement
// http_body::Body with Data = bytes::Bytes. In practice B is usually http_body_util::Full, http_body_util::Empty, or a similar type.
pub fn new<B>(body: B) -> Body where B: Body<Data = Bytes> + Send + 'static, <B as Body>::Error: Into<Box<dyn Error + Sync + Send>>
// Convert the data frames of the Body into a Stream, discarding any trailers. To keep the trailers, use http_body_util::BodyStream,
// which yields https://docs.rs/http-body/1.0.1/http_body/struct.Frame.html values instead.
pub fn into_data_stream(self) -> BodyDataStream
// Other quick ways to create an axum::body::Body
impl From<&'static [u8]> for Body
impl From<&'static str> for Body
impl From<()> for Body
impl From<Bytes> for Body
impl From<Cow<'static, [u8]>> for Body
impl From<Cow<'static, str>> for Body
impl From<String> for Body
impl From<Vec<u8>> for Body
// axum::body::Body implements the FromRequest trait, so it can be used as an extractor
impl<S> FromRequest<S> for Body where S: Send + Sync
type Rejection = Infallible
fn from_request<'life0, 'async_trait>(req: Request<Body>, _: &'life0 S)
-> Pin<Box<dyn Future<Output = Result<Body, <Body as FromRequest<S>>::Rejection>> + Send + 'async_trait>>
where 'life0: 'async_trait, Body: 'async_trait
// axum::body::Body implements IntoResponse, so it can be returned from a Handler
impl IntoResponse for Body
fn into_response(self) -> Response<Body>
// Example
let app = Router::new()
.route( "/", any_service(
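// Build the Service from a closure. The Err type of the Result returned by the closure must be Infallible to satisfy
// the trait bound; otherwise use axum::error_handling::HandleError to convert the Err into Infallible.
service_fn(|_: Request| async {
// First create an axum::body::Body, then use it to build an http::Response<T>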
let res = Response::new(Body::from("Hi from `GET /`"));
Ok::<_, Infallible>(res)
})))
MethodRouter #
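MethodRouter is the parameter type of Router::route(path, method_router); it supplies the handling logic for a path.
MethodRouter<S, Infallible> bundles request methods with their Handlers. Its methods can be chained, and it dispatches to a different Handler depending on the request method.
// S is the State type, normally passed down from Router<S>.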
impl<S> MethodRouter<S, Infallible> where S: Clone
The axum::routing module provides shortcut functions such as any()/any_service()/get()/get_service()/delete()/delete_service()/put()/post() that create a MethodRouter for the corresponding HTTP method:
// Re-exports
pub use self::method_routing::any;
pub use self::method_routing::any_service;
pub use self::method_routing::delete;
pub use self::method_routing::delete_service;
pub use self::method_routing::get;
pub use self::method_routing::get_service;
pub use self::method_routing::head;
pub use self::method_routing::head_service;
pub use self::method_routing::on;
pub use self::method_routing::on_service;
pub use self::method_routing::options;
pub use self::method_routing::options_service;
pub use self::method_routing::patch;
pub use self::method_routing::patch_service;
pub use self::method_routing::post;
pub use self::method_routing::post_service;
pub use self::method_routing::put;
pub use self::method_routing::put_service;
pub use self::method_routing::trace;
pub use self::method_routing::trace_service;
pub use self::method_routing::MethodFilter; // method filter used by on()/on_service() (associated constants)
// MethodRouter methods:
//
// on()/on_service() are the general-purpose methods. MethodFilter is a struct whose associated constants name the standard HTTP methods.
// Request is http::Request<axum::body::Body>.
pub fn on<H, T>(self, filter: MethodFilter, handler: H) -> Self
where H: Handler<T, S>, T: 'static, S: Send + Sync + 'static
pub fn on_service<T>(self, filter: MethodFilter, svc: T) -> Self
where
T: Service<Request, Error = E> + Clone + Send + 'static,
T::Response: IntoResponse + 'static,
T::Future: Send + 'static
// any()/any_service() are free functions that accept every HTTP method.
pub fn any<H, T, S>(handler: H) -> MethodRouter<S, Infallible> where
H: Handler<T, S>,
T: 'static,
S: Clone + Send + Sync + 'static
pub fn any_service<T, S>(svc: T) -> MethodRouter<S, T::Error> where
T: Service<Request> + Clone + Send + Sync + 'static,
T::Response: IntoResponse + 'static,
T::Future: Send + 'static,
S: Clone,
// Other methods that return a MethodRouter; they wrap on() and can be chained
pub fn delete<H, T>(self, handler: H) -> Self where H: Handler<T, S>, T: 'static, S: Send + Sync + 'static
pub fn get<H, T>(self, handler: H) -> Self where H: Handler<T, S>, T: 'static, S: Send + Sync + 'static
pub fn head<H, T>(self, handler: H) -> Self where H: Handler<T, S>, T: 'static, S: Send + Sync + 'static
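The chaining described above (a free function such as get() starts a MethodRouter, and further methods add handlers for other HTTP methods) can be pictured as a small per-path method table. The following is a toy std-only sketch: `Method`, `HandlerFn`, `MethodRouter`, and the free `get` function are hypothetical definitions for illustration and are much simpler than axum's real types.

```rust
use std::collections::HashMap;

/// A few HTTP methods (hypothetical stand-in for a method filter).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum Method {
    Get,
    Post,
    Delete,
}

/// Handlers are plain functions in this model.
type HandlerFn = fn() -> &'static str;

/// Toy per-path method table, loosely modeled on `MethodRouter`.
#[derive(Default)]
struct MethodRouter {
    handlers: HashMap<Method, HandlerFn>,
}

impl MethodRouter {
    /// General registration, in the spirit of `on(filter, handler)`.
    fn on(mut self, method: Method, handler: HandlerFn) -> Self {
        self.handlers.insert(method, handler);
        self
    }

    /// Method-specific shortcuts are thin wrappers around `on()`.
    fn get(self, handler: HandlerFn) -> Self {
        self.on(Method::Get, handler)
    }

    fn post(self, handler: HandlerFn) -> Self {
        self.on(Method::Post, handler)
    }

    /// Dispatch by method; an unregistered method yields 405.
    fn call(&self, method: Method) -> (u16, &'static str) {
        match self.handlers.get(&method) {
            Some(handler) => (200, handler()),
            None => (405, "Method Not Allowed"),
        }
    }
}

/// Free function that starts a chain, like `axum::routing::get`.
fn get(handler: HandlerFn) -> MethodRouter {
    MethodRouter::default().get(handler)
}

fn list_users() -> &'static str {
    "list"
}

fn create_user() -> &'static str {
    "create"
}

fn main() {
    let users = get(list_users).post(create_user);
    for method in [Method::Get, Method::Post, Method::Delete] {
        println!("{:?} -> {:?}", method, users.call(method));
    }
}
```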
The argument f of tower::service_fn(f: T) is a closure implementing FnMut(Request) -> Future<Output = Result<R, E>>:
- With axum, Request must be http::Request<axum::body::Body>, or its type alias axum::extract::Request<axum::body::Body>.
// Definition of tower::service_fn()
pub fn service_fn<T>(f: T) -> ServiceFn<T> // T is a function or closure that returns a Future.
// ServiceFn<T> implements the Service trait; both Request and the response R are generic.
// With axum, Request must be http::Request<axum::body::Body>, or its type alias axum::extract::Request<axum::body::Body>.
impl<T, F, Request, R, E> Service<Request> for ServiceFn<T> where T: FnMut(Request) -> F, F: Future<Output = Result<R, E>>,
{
type Response = R;
type Error = E;
type Future = F;
fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), E>> {
Ok(()).into()
}
fn call(&mut self, req: Request) -> Self::Future {
(self.f)(req)
}
}
Taking get()/get_service() as an example:
// `get()`: the handling logic is a `Handler`;
pub fn get<H, T, S>(handler: H) -> MethodRouter<S, Infallible>
where
H: Handler<T, S>,
T: 'static,
S: Clone + Send + Sync + 'static
// `get_service()`: the handling logic is a `Service`, usually built from a closure with `tower::service_fn()`;
// Request is axum::extract::Request<axum::body::Body>, a type alias for http::Request<axum::body::Body>.
pub fn get_service<T, S>(svc: T) -> MethodRouter<S, T::Error>
where
T: Service<Request> + Clone + Send + 'static,
T::Response: IntoResponse + 'static,
T::Future: Send + 'static,
S: Clone,
// Example:
use axum::{
routing::{get, get_service, MethodFilter},
extract::Request,
response::Response,
body::Body,
Router,
};
use std::convert::Infallible;
async fn handler() {}
async fn other_handler() {}
let service = tower::service_fn(|request: Request| async {
// The Err type of the returned Result must be Infallible
Ok::<_, Infallible>(Response::new(Body::empty()))
});
let app = Router::new()
.route("/", get(handler))
.route("/svc", get_service(service).on(MethodFilter::DELETE, other_handler));
MethodRouter supports adding state and layers, which apply only to the Handlers of that MethodRouter:
// Attach State
pub fn with_state<S2>(self, state: S) -> MethodRouter<S2, E>
// Attach a Layer; the Layer runs before the Handler
pub fn route_layer<L>(self, layer: L) -> MethodRouter<S, E>
where
L: Layer<Route<E>> + Clone + Send + 'static,
L::Service: Service<Request, Error = E> + Clone + Send + 'static,
<L::Service as Service<Request>>::Response: IntoResponse + 'static,
<L::Service as Service<Request>>::Future: Send + 'static,
E: 'static,
S: 'static,
// Example
use axum::{ routing::get, Router, };
use tower_http::validate_request::ValidateRequestHeaderLayer;
let app = Router::new().route(
"/foo",
get(|| async {}).route_layer(ValidateRequestHeaderLayer::bearer("password"))
);
// `GET /foo` with a valid token will receive `200 OK`
// `GET /foo` with an invalid token will receive `401 Unauthorized`
// `POST /foo` with an invalid token will receive `405 Method Not Allowed`
MethodRouter implements the Service trait only when its state type is (). For any other state type, with_state() must be called with a value of that type before the MethodRouter implements Service:
impl<S, E> MethodRouter<S, E> where S: Clone,
pub fn with_state<S2>(self, state: S) -> MethodRouter<S2, E>
// MethodRouter implements the Service trait only when the state type is ()
impl<B, E> Service<Request<B>> for MethodRouter<(), E>
where
B: HttpBody<Data = Bytes> + Send + 'static,
B::Error: Into<BoxError>
impl<L> Service<IncomingStream<'_, L>> for MethodRouter<()>
where
L: Listener,
// Example:
use tower::Service;
use axum::{routing::get, extract::{State, Request}, body::Body};
// this `MethodRouter` doesn't require any state, i.e. the state is `()`,
let method_router = get(|| async {});
// and thus it implements `Service`
assert_service(method_router);
// this requires a `String` and doesn't implement `Service`
let method_router = get(|_: State<String>| async {});
// until you provide the `String` with `.with_state(...)`
let method_router_with_state = method_router.with_state(String::new());
// and then it implements `Service`
assert_service(method_router_with_state);
// helper to check that a value implements `Service`
fn assert_service<S>(service: S) where S: Service<Request>, {}
Handler #
The Handler trait is the handling logic passed to MethodRouter's on/get/delete/put/post(handler) methods; it is usually implemented by an async function or closure.
Request is a type alias: pub type Request<T = axum::body::Body> = http::Request<T>;
Response is a type alias: pub type Response<T = axum::body::Body> = http::Response<T>;
pub trait Handler<T, S>: Clone + Send + Sized + 'static {
type Future: Future<Output = Response> + Send + 'static;
// Required method
fn call(self, req: Request, state: S) -> Self::Future;
// Provided methods
// The returned Layered also implements Handler; the layer passed in runs first, then this Handler
fn layer<L>(self, layer: L) -> Layered<L, Self, T, S> where L: Layer<HandlerService<Self, T, S>> + Clone, L::Service: Service<Request>
fn with_state(self, state: S) -> HandlerService<Self, T, S> { ... }
}
// MethodRouter implements Handler
impl<S> Handler<(), S> for MethodRouter<S> where S: Clone + 'static
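Handler provides layer() and with_state() to add middleware (which runs before the handler itself) and state to that Handler:
- The State must satisfy Clone + Send + Sync + 'static, so the Handler generally needs to own it (for example through an Arc).
// Example: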
use axum::{routing::get, handler::Handler, Router, };
use tower::limit::{ConcurrencyLimitLayer, ConcurrencyLimit};
async fn handler() { /* ... */ }
// The returned layered_handler implements Handler; the layer passed in runs first, then handler.
let layered_handler = handler.layer(ConcurrencyLimitLayer::new(64));
let app = Router::new().route("/", get(layered_handler));
with_state() returns a HandlerService<Self, T, S>, which provides into_make_service()/into_make_service_with_connect_info<C>().
HandlerService turns a Handler into a Service of Services that can be passed to axum::serve():
// An adapter that makes a Handler into a Service.
// Created with Handler::with_state or HandlerWithoutStateExt::into_service.
pub struct HandlerService<H, T, S> { /* private fields */ }
impl<H, T, S> HandlerService<H, T, S>
pub fn state(&self) -> &S
pub fn into_make_service(self) -> IntoMakeService<HandlerService<H, T, S>>
pub fn into_make_service_with_connect_info<C>(self)->IntoMakeServiceWithConnectInfo<HandlerService<H,T,S>, C>
// HandlerService is itself a Service factory, so it can be passed to axum::serve() directly.
impl<H, T, S> Service<IncomingStream<'_>> for HandlerService<H, T, S>
where
H: Clone,
S: Clone
type Response = HandlerService<H, T, S> // returns its own type
type Error = Infallible
impl<H, T, S, B> Service<Request<B>> for HandlerService<H, T, S>
where
H: Handler<T, S> + Clone + Send + 'static,
B: HttpBody<Data = Bytes> + Send + 'static,
B::Error: Into<BoxError>,
S: Clone + Send + Sync
type Response = Response<Body>
type Error = Infallible
// with_state() example
use axum::{
handler::Handler,
response::IntoResponse,
extract::{ConnectInfo, State},
};
use std::net::SocketAddr;
// The State must implement Clone
#[derive(Clone)]
struct AppState {}
async fn handler( ConnectInfo(addr): ConnectInfo<SocketAddr>, State(state): State<AppState>, ) -> String {
format!("Hello {addr}")
}
let app = handler.with_state(AppState {}); // attach State to handler
let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
// app is a HandlerService, which provides into_make_service()/into_make_service_with_connect_info(); for handlers without state, axum::handler::HandlerWithoutStateExt offers the same methods plus into_service().
axum::serve(listener, app.into_make_service_with_connect_info::<SocketAddr>(),).await.unwrap();
The axum::handler::HandlerWithoutStateExt trait gives a stateless Handler<T, ()> the into_service()/into_make_service() methods, which wrap it in a HandlerService.
pub trait HandlerWithoutStateExt<T>: Handler<T, ()> {
// Required methods
fn into_service(self) -> HandlerService<Self, T, ()>;
fn into_make_service(self) -> IntoMakeService<HandlerService<Self, T, ()>>;
fn into_make_service_with_connect_info<C>(self, ) -> IntoMakeServiceWithConnectInfo<HandlerService<Self, T, ()>, C>;
}
// Example:
// Serving a Handler: call the handler's into_make_service() method
use axum::handler::HandlerWithoutStateExt;
async fn handler() -> &'static str { "Hello, World!"}
let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
axum::serve(listener, handler.into_make_service()).await.unwrap();
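By default axum implements the Handler trait for async FnOnce functions and closures that take up to 16 arguments:
- The arguments are extractors. There can be several, but all except the last must implement the FromRequestParts trait, and the last must implement the FromRequest trait;
- The function or closure itself must be Clone + Send + 'static and its returned Future must be Send, so it generally holds no borrows and uses move to take ownership of what it captures.
These async functions or closures return a Future whose Output implements IntoResponse; the Output does not have to be a Result. To return error information, define a custom error type and implement the IntoResponse trait for it.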
impl<F, Fut, Res, S> Handler<((),), S> for F
where
F: FnOnce() -> Fut + Clone + Send + 'static, // the function or closure
Fut: Future<Output = Res> + Send,
Res: IntoResponse
// With a single argument, that argument must implement FromRequest
impl<F, Fut, S, Res, M, T1> Handler<(M, T1), S> for F
where
F: FnOnce(T1) -> Fut + Clone + Send + 'static,
Fut: Future<Output = Res> + Send,
S: Send + Sync + 'static,
Res: IntoResponse,
T1: FromRequest<S, M> + Send
// With several arguments, all but the last implement FromRequestParts and the last implements FromRequest
impl<F, Fut, S, Res, M, T1, T2> Handler<(M, T1, T2), S> for F
where
F: FnOnce(T1, T2) -> Fut + Clone + Send + 'static,
Fut: Future<Output = Res> + Send,
S: Send + Sync + 'static,
Res: IntoResponse,
T1: FromRequestParts<S> + Send,
T2: FromRequest<S, M> + Send
// And so on, up to closures with 16 arguments
impl<F, Fut, S, Res, M, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16> Handler<(M, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16), S> for F
where
F: FnOnce(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16) -> Fut + Clone + Send + 'static,
Fut: Future<Output = Res> + Send,
S: Send + Sync + 'static,
Res: IntoResponse,
T1: FromRequestParts<S> + Send,
T2: FromRequestParts<S> + Send,
T3: FromRequestParts<S> + Send,
T4: FromRequestParts<S> + Send,
T5: FromRequestParts<S> + Send,
T6: FromRequestParts<S> + Send,
T7: FromRequestParts<S> + Send,
T8: FromRequestParts<S> + Send,
T9: FromRequestParts<S> + Send,
T10: FromRequestParts<S> + Send,
T11: FromRequestParts<S> + Send,
T12: FromRequestParts<S> + Send,
T13: FromRequestParts<S> + Send,
T14: FromRequestParts<S> + Send,
T15: FromRequestParts<S> + Send,
T16: FromRequest<S, M> + Send
Examples of Handlers implemented as async functions:
use axum::{body::Bytes, http::StatusCode};
// An empty return value produces 200 OK with an empty body
async fn unit_handler() {}
// Returning a String produces 200 OK with a plain-text body
async fn string_handler() -> String {
"Hello, World!".to_string()
}
// The echo handler has a single argument, which must implement FromRequest (Bytes does).
// String and StatusCode both implement IntoResponse, so Result<String, StatusCode> implements it too.
async fn echo(body: Bytes) -> Result<String, StatusCode> {
if let Ok(string) = String::from_utf8(body.to_vec()) {
Ok(string)
} else {
Err(StatusCode::BAD_REQUEST)
}
}
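The blanket impls listed above rely on one trick: the first type parameter of Handler is a marker tuple of the argument types, which keeps the impls for different arities from overlapping. A synchronous std-only model of that mechanism follows. Every name in it (`Request`, `Response`, `FromRequestParts`, `FromRequest`, `Path`, `Utf8Body`, `Handler`, `run`) is a hypothetical simplification defined here, and it leaves out state, futures, and axum's rule that a FromRequestParts type may also be used as the last argument.

```rust
/// Hypothetical request: a path plus a raw body.
struct Request {
    path: String,
    body: Vec<u8>,
}

/// (status, body) pair standing in for an HTTP response.
type Response = (u16, String);

/// Extractors that only look at the request without consuming it.
trait FromRequestParts: Sized {
    fn from_parts(req: &Request) -> Result<Self, String>;
}

/// Extractors that consume the request (and its body); must come last.
trait FromRequest: Sized {
    fn from_request(req: Request) -> Result<Self, String>;
}

struct Path(String);
impl FromRequestParts for Path {
    fn from_parts(req: &Request) -> Result<Self, String> {
        Ok(Path(req.path.clone()))
    }
}

struct Utf8Body(String);
impl FromRequest for Utf8Body {
    fn from_request(req: Request) -> Result<Self, String> {
        String::from_utf8(req.body)
            .map(Utf8Body)
            .map_err(|_| "body is not valid UTF-8".to_string())
    }
}

/// `T` is a marker tuple of argument types; it keeps the impls below disjoint.
trait Handler<T> {
    fn call(self, req: Request) -> Response;
}

impl<F> Handler<()> for F
where
    F: FnOnce() -> String,
{
    fn call(self, _req: Request) -> Response {
        (200, self())
    }
}

impl<F, T1> Handler<(T1,)> for F
where
    F: FnOnce(T1) -> String,
    T1: FromRequest,
{
    fn call(self, req: Request) -> Response {
        match T1::from_request(req) {
            Ok(t1) => (200, self(t1)),
            Err(rejection) => (400, rejection),
        }
    }
}

impl<F, T1, T2> Handler<(T1, T2)> for F
where
    F: FnOnce(T1, T2) -> String,
    T1: FromRequestParts,
    T2: FromRequest,
{
    fn call(self, req: Request) -> Response {
        let t1 = match T1::from_parts(&req) {
            Ok(v) => v,
            Err(rejection) => return (400, rejection),
        };
        match T2::from_request(req) {
            Ok(t2) => (200, self(t1, t2)),
            Err(rejection) => (400, rejection),
        }
    }
}

/// Run any handler; the marker `T` is inferred from the handler's signature.
fn run<T, H: Handler<T>>(handler: H, req: Request) -> Response {
    handler.call(req)
}

fn request(body: &[u8]) -> Request {
    Request { path: "/notes".to_string(), body: body.to_vec() }
}

fn hello() -> String {
    "Hello, World!".to_string()
}

fn echo(Utf8Body(text): Utf8Body) -> String {
    text
}

fn describe(Path(path): Path, Utf8Body(text): Utf8Body) -> String {
    format!("{path}: {text}")
}

fn main() {
    println!("{:?}", run(hello, request(b"")));
    println!("{:?}", run(echo, request(b"hi")));
    println!("{:?}", run(describe, request(b"hi")));
    println!("{:?}", run(echo, request(&[0xff])));
}
```

In this model a failed extraction short-circuits into a 400 response before the handler body runs, which is roughly the role that an extractor's Rejection type plays.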
IntoResponse #
The return value of a Handler function must implement the IntoResponse trait:
pub trait IntoResponse {
// Required method
fn into_response(self) -> Response<Body>; // Body is the struct axum::body::Body
}
For example, the impl for a two-argument closure:
impl<F, Fut, S, Res, M, T1, T2> Handler<(M, T1, T2), S> for F
where
F: FnOnce(T1, T2) -> Fut + Clone + Send + 'static,
Fut: Future<Output = Res> + Send,
S: Send + Sync + 'static,
Res: IntoResponse, // must implement IntoResponse
T1: FromRequestParts<S> + Send,
T2: FromRequest<S, M> + Send
axum implements the IntoResponse trait for basic Rust types:
impl<const N: usize> IntoResponse for &'static [u8; N]
impl IntoResponse for &'static [u8]
impl<const N: usize> IntoResponse for [u8; N]
impl IntoResponse for &'static str
impl IntoResponse for Cow<'static, str>
impl IntoResponse for Cow<'static, [u8]>
impl IntoResponse for Infallible
impl IntoResponse for ()
impl IntoResponse for Box<str>
impl IntoResponse for Box<[u8]>
impl IntoResponse for String
impl IntoResponse for Vec<u8>
impl IntoResponse for Bytes
impl IntoResponse for BytesMut
impl IntoResponse for Extensions
impl IntoResponse for HeaderMap
impl IntoResponse for Parts
impl IntoResponse for StatusCode
// Result implements IntoResponse when both its Ok and Err types do
impl<T, E> IntoResponse for Result<T, E> where T: IntoResponse, E: IntoResponse
// The struct axum::response::ErrorResponse implements From<T> for any T: IntoResponse
impl<T> IntoResponse for Result<T, ErrorResponse> where T: IntoResponse
impl<T, U> IntoResponse for Chain<T, U> where
T: Buf + Unpin + Send + 'static,
U: Buf + Unpin + Send + 'static,
// http::response::Response<B> implements IntoResponse when B is an http_body::Body with Data = Bytes
impl<B> IntoResponse for Response<B>
where
B: Body<Data = Bytes> + Send + 'static,
<B as Body>::Error: Into<Box<dyn Error + Send + Sync>>
// Arrays of (K, V) pairs also implement IntoResponse: K converts to HeaderName, V converts to HeaderValue
impl<K, V, const N: usize> IntoResponse for [(K, V); N]
where
K: TryInto<HeaderName>,
<K as TryInto<HeaderName>>::Error: Display,
V: TryInto<HeaderValue>,
<V as TryInto<HeaderValue>>::Error: Display
// Certain tuple types implement IntoResponse; a tuple may contain up to 16 optional parts T1..T16.
// Parts is http::response::Parts, which holds the response headers and status code.
// ResponseParts is the struct axum::response::ResponseParts, used to set headers and extensions.
impl<R> IntoResponse for (Parts, R) where R: IntoResponse
impl<R, T1> IntoResponse for (Parts, T1, R) where T1: IntoResponseParts, R: IntoResponse
// Response<()> is http::response::Response<()>
impl<R> IntoResponse for (Response<()>, R) where R: IntoResponse
impl<R, T1> IntoResponse for (Response<()>, T1, R) where T1: IntoResponseParts, R: IntoResponse
// Setting a custom response StatusCode
impl<R> IntoResponse for (StatusCode, R) where R: IntoResponse
impl<R, T1> IntoResponse for (StatusCode, T1, R) where T1: IntoResponseParts, R: IntoResponse
// Setting custom response headers
impl<R, T1> IntoResponse for (T1, R) where T1: IntoResponseParts, R: IntoResponse
impl<R> IntoResponse for (R,) where R: IntoResponse
// The axum::response::ResponseParts type
pub struct ResponseParts { /* private fields */ }
impl ResponseParts
pub fn headers(&self) -> &HeaderMap
pub fn headers_mut(&mut self) -> &mut HeaderMap
pub fn extensions(&self) -> &Extensions
pub fn extensions_mut(&mut self) -> &mut Extensions
// IntoResponseParts adds headers and extensions to the ResponseParts passed in
pub trait IntoResponseParts {
type Error: IntoResponse;
// Required method
fn into_response_parts( self, res: ResponseParts, ) -> Result<ResponseParts, Self::Error>;
}
impl IntoResponseParts for ()
impl IntoResponseParts for Extensions
impl IntoResponseParts for HeaderMap
impl<K, V, const N: usize> IntoResponseParts for [(K, V); N]
// Tuples of up to 16 elements (T1..T16) are supported
impl<T1> IntoResponseParts for (T1,) where T1: IntoResponseParts,
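The way a `(StatusCode, T1, R)` tuple is assembled into a response can be sketched with a small std-only model. Everything below (`Response`, `StatusCode`, and both traits) is a hypothetical, simplified stand-in rather than axum's actual code: the real `into_response_parts` works on `ResponseParts` and returns a `Result`. The sketch only shows the order of composition: the last element builds the body, the middle element adds headers, and the first element overrides the status.

```rust
use std::collections::BTreeMap;

// Hypothetical, simplified stand-ins; the real types live in axum::response and http.
#[derive(Debug, Default)]
struct Response {
    status: u16,
    headers: BTreeMap<String, String>,
    body: String,
}

struct StatusCode(u16);

trait IntoResponse {
    fn into_response(self) -> Response;
}

// Simplified: the real trait takes ResponseParts and returns Result<ResponseParts, Self::Error>.
trait IntoResponseParts {
    fn into_response_parts(self, res: Response) -> Response;
}

impl IntoResponse for &'static str {
    fn into_response(self) -> Response {
        let mut headers = BTreeMap::new();
        headers.insert(
            "content-type".to_string(),
            "text/plain; charset=utf-8".to_string(),
        );
        Response { status: 200, headers, body: self.to_string() }
    }
}

impl<const N: usize> IntoResponseParts for [(&'static str, &'static str); N] {
    fn into_response_parts(self, mut res: Response) -> Response {
        for (name, value) in self {
            res.headers.insert(name.to_string(), value.to_string());
        }
        res
    }
}

impl<T1: IntoResponseParts, R: IntoResponse> IntoResponse for (StatusCode, T1, R) {
    fn into_response(self) -> Response {
        let (status, parts, body) = self;
        // Body first, then the extra parts, then the status override.
        let mut res = parts.into_response_parts(body.into_response());
        res.status = status.0;
        res
    }
}

fn main() {
    let res = (StatusCode(404), [("x-custom", "custom")], "foo").into_response();
    println!("{res:?}");
}
```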
Other types that implement IntoResponse, such as Json/Form/Html/Body:
// Rejection associated types of the extractor traits
impl IntoResponse for MultipartRejection
impl IntoResponse for BytesRejection
impl IntoResponse for ExtensionRejection
impl IntoResponse for FailedToBufferBody
// ...
// Other types defined by axum
impl IntoResponse for Body // struct axum::body::Body
impl IntoResponse for Redirect
impl<T> IntoResponse for Extension<T> where T: Clone + Send + Sync + 'static,
impl<T> IntoResponse for Form<T> where T: Serialize
impl<T> IntoResponse for Json<T> where T: Serialize // axum's own Json type: struct axum::Json
impl<T> IntoResponse for Html<T> where T: Into<Body>
Examples:
//https://docs.rs/axum/latest/axum/response/index.html
use axum::{
Extension, Json,
response::{Html, IntoResponse},
http::{StatusCode, Uri, header::{self, HeaderMap, HeaderName}},
};
// `()` gives an empty response
async fn empty() {}
// String will get a `text/plain; charset=utf-8` content-type
async fn plain_text(uri: Uri) -> String {
format!("Hi from {}", uri.path())
}
// Bytes will get a `application/octet-stream` content-type
async fn bytes() -> Vec<u8> {
vec![1, 2, 3, 4]
}
// `Json` will get a `application/json` content-type and work with anything that implements `serde::Serialize`
async fn json() -> Json<Vec<String>> {
Json(vec!["foo".to_owned(), "bar".to_owned()])
}
// `Html` will get a `text/html` content-type
async fn html() -> Html<&'static str> {
Html("<p>Hello, World!</p>")
}
// `StatusCode` gives an empty response with that status code
async fn status() -> StatusCode {
StatusCode::NOT_FOUND
}
// `HeaderMap` gives an empty response with some headers
async fn headers() -> HeaderMap {
let mut headers = HeaderMap::new();
headers.insert(header::SERVER, "axum".parse().unwrap());
headers
}
// An array of tuples also gives headers
async fn array_headers() -> [(HeaderName, &'static str); 2] {
[
(header::SERVER, "axum"),
(header::CONTENT_TYPE, "text/plain")
]
}
// Use `impl IntoResponse` to avoid writing the whole type
async fn impl_trait() -> impl IntoResponse {
[
(header::SERVER, "axum"),
(header::CONTENT_TYPE, "text/plain")
]
}
// `(StatusCode, impl IntoResponse)` will override the status code of the response
async fn with_status(uri: Uri) -> (StatusCode, String) {
(StatusCode::NOT_FOUND, format!("Not Found: {}", uri.path()))
}
// Use `impl IntoResponse` to avoid having to type the whole type
async fn impl_trait_with_status(uri: Uri) -> impl IntoResponse {
(StatusCode::NOT_FOUND, format!("Not Found: {}", uri.path()))
}
// `(HeaderMap, impl IntoResponse)` to add additional headers
async fn with_headers() -> impl IntoResponse {
let mut headers = HeaderMap::new();
headers.insert(header::CONTENT_TYPE, "text/plain".parse().unwrap());
(headers, "foo")
}
// Or an array of tuples to more easily build the headers
async fn with_array_headers() -> impl IntoResponse {
([(header::CONTENT_TYPE, "text/plain")], "foo")
}
// Use string keys for custom headers
async fn with_array_headers_custom() -> impl IntoResponse {
([("x-custom", "custom")], "foo")
}
// `(StatusCode, headers, impl IntoResponse)` to set status and add headers
// `headers` can be either a `HeaderMap` or an array of tuples
async fn with_status_and_array_headers() -> impl IntoResponse {
(
StatusCode::NOT_FOUND,
[(header::CONTENT_TYPE, "text/plain")],
"foo",
)
}
// `(Extension<_>, impl IntoResponse)` to set response extensions
async fn with_status_extensions() -> impl IntoResponse {
(
Extension(Foo("foo")),
"foo",
)
}
#[derive(Clone)]
struct Foo(&'static str);
// Or mix and match all the things
async fn all_the_things(uri: Uri) -> impl IntoResponse {
let mut header_map = HeaderMap::new();
if uri.path() == "/" {
header_map.insert(header::SERVER, "axum".parse().unwrap());
}
(
// set status code
StatusCode::NOT_FOUND,
// headers with an array
[("x-custom", "custom")],
// some extensions
Extension(Foo("foo")),
Extension(Foo("bar")),
// more headers, built dynamically
header_map,
// and finally the body
"foo",
)
}
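Although the documentation does not call it out, any type that implements IntoResponse (and is Clone + Send + 'static) also implements Handler, so it can be passed directly to the MethodRouter functions: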
mod private {
// Marker type for `impl<T: IntoResponse> Handler for T`
#[allow(missing_debug_implementations)]
pub enum IntoResponseHandler {}
}
impl<T, S> Handler<private::IntoResponseHandler, S> for T
where
T: IntoResponse + Clone + Send + 'static,
{
type Future = std::future::Ready<Response>;
fn call(self, _req: Request, _state: S) -> Self::Future {
std::future::ready(self.into_response())
}
}
// Example: using a tuple as a Handler:
use axum::{
Router,
routing::{get, post},
Json,
http::StatusCode,
};
use serde_json::json;
let app = Router::new()
// &str implements IntoResponse
.route("/", get("Hello, World!"))
.route("/users", post(
// every member of the tuple implements IntoResponse, so the tuple does too
(StatusCode::CREATED, Json(json!({ "id": 1, "username": "alice" })),)));
extractor #
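An extractor is a value of a type that implements the FromRequest or FromRequestParts trait. Extractors are the inputs of a Handler function and pull information out of the request.
For a Handler with several extractor arguments, such as FnOnce(T1, T2, T3), every argument except the last must implement FromRequestParts, and the last argument must implement FromRequest.
- This is because FromRequestParts does not consume the request body, while FromRequest does.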
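The ordering rule follows from ownership, which a std-only model can illustrate. The names below (`Parts`, `Request`, `FromParts`, `FromWhole`, `call_handler`) are hypothetical stand-ins and not axum's API: the first trait only borrows the request head, while the second takes the whole request by value, so it can run only once and has to run last.

```rust
// Hypothetical stand-ins for http::request::Parts and the full request.
struct Parts {
    headers: Vec<(&'static str, &'static str)>,
}

struct Request {
    parts: Parts,
    body: Vec<u8>,
}

// Models FromRequestParts: borrows the head and leaves the body untouched.
trait FromParts: Sized {
    fn from_parts(parts: &mut Parts) -> Option<Self>;
}

// Models FromRequest: takes the request by value, so it can run only once.
trait FromWhole: Sized {
    fn from_whole(req: Request) -> Option<Self>;
}

struct UserAgent(String);

impl FromParts for UserAgent {
    fn from_parts(parts: &mut Parts) -> Option<Self> {
        parts
            .headers
            .iter()
            .find(|(name, _)| name.eq_ignore_ascii_case("user-agent"))
            .map(|(_, value)| UserAgent(value.to_string()))
    }
}

struct BodyText(String);

impl FromWhole for BodyText {
    fn from_whole(req: Request) -> Option<Self> {
        String::from_utf8(req.body).ok().map(BodyText)
    }
}

// Runs the extractors in argument order: the head first, the body last.
fn call_handler<A, B>(mut req: Request, handler: impl FnOnce(A, B) -> String) -> Option<String>
where
    A: FromParts,
    B: FromWhole,
{
    let first = A::from_parts(&mut req.parts)?;
    let last = B::from_whole(req)?; // `req` is moved here and cannot be used again
    Some(handler(first, last))
}

fn describe(UserAgent(agent): UserAgent, BodyText(body): BodyText) -> String {
    format!("{agent} sent {body}")
}

fn sample_request() -> Request {
    Request {
        parts: Parts { headers: vec![("User-Agent", "curl")] },
        body: b"hello".to_vec(),
    }
}

fn main() {
    println!("{:?}", call_handler(sample_request(), describe));
}
```

Swapping the two extraction calls in `call_handler` would not compile, because `req` has already been moved. That is the same constraint axum expresses through its two traits.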
use axum::{
// Request/Json/Path/Extension/Query are all extractors
extract::{Request, Json, Path, Extension, Query},
routing::post,
http::header::HeaderMap,
body::{Bytes, Body},
Router,
};
use serde_json::Value;
use serde::Deserialize;
use std::collections::HashMap;
#[derive(Deserialize)]
struct CreateUser {
email: String,
password: String,
}
// Function parameters are patterns, so in Path((user_id, user_name)) the names user_id/user_name bind the destructured values.
// Path extracts path parameters from the request (use a tuple for several parameters)
async fn path_single(Path(user_id): Path<u32>) {}
async fn path(Path((user_id, user_name)): Path<(u32, String)>) {}
// Query deserializes the query string into the given type
async fn query(Query(params): Query<HashMap<String, String>>) {}
// HeaderMap contains all HTTP headers of the request
async fn headers(headers: HeaderMap) {}
// String contains the request body and requires it to be valid UTF-8
async fn string(body: String) {}
// Bytes contains the raw request body
async fn bytes(body: Bytes) {}
// Json deserializes the request body into the given type (serde_json::Value accepts any JSON)
async fn json(Json(payload): Json<Value>) {}
// Json implements both FromRequest (extractor) and IntoResponse, so it can also be a Handler return value.
async fn create_user(Json(payload): Json<CreateUser>) {}
// Request gives access to the whole request
async fn request(request: Request) {}
// Extension extracts shared state from the request extensions
async fn extension(Extension(state): Extension<State>) {}
#[derive(Clone)]
struct State { /* ... */ }
let app = Router::new()
.route("/path/:user_id/:user_name", post(path))
.route("/query", post(query))
.route("/string", post(string))
.route("/bytes", post(bytes))
.route("/json", post(json))
.route("/request", post(request))
.route("/extension", post(extension))
.route("/users", post(create_user));
Every extractor argument of a Handler is required by default, and the request is rejected if extraction fails. Wrap an extractor in Option to make it optional:
use axum::{
extract::{Path, Query},
routing::get,
Router,
};
use uuid::Uuid;
use serde::Deserialize;
let app = Router::new().route("/users/:id/things", get(get_user_things));
#[derive(Deserialize)]
struct Pagination {
page: usize,
per_page: usize,
}
impl Default for Pagination {
fn default() -> Self {
Self { page: 1, per_page: 30 }
}
}
// A Handler can use several extractors at once
async fn get_user_things(
// required argument
Path(user_id): Path<Uuid>,
// optional argument: deserialized from the query string into a Pagination value
pagination: Option<Query<Pagination>>,
) {
let Query(pagination) = pagination.unwrap_or_default();
// ...
}
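Getting the reason an extractor failed #
Wrap the extractor in Result to get the failure reason. The error type is the Rejection associated type that the extractor declares in its FromRequestParts/FromRequest implementation: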
use axum::{
extract::{Json, rejection::JsonRejection},
routing::post,
Router,
};
use serde_json::Value; // generic type for deserialized JSON
// Each extractor defines its own Rejection type, which is what Result::Err carries.
async fn create_user(payload: Result<Json<Value>, JsonRejection>) {
match payload {
Ok(payload) => {
// We got a valid JSON payload
}
Err(JsonRejection::MissingJsonContentType(_)) => {
// Request didn't have `Content-Type: application/json` header
}
Err(JsonRejection::JsonDataError(_)) => {
// Couldn't deserialize the body into the target type
}
Err(JsonRejection::JsonSyntaxError(_)) => {
// Syntax error in the body
}
Err(JsonRejection::BytesRejection(_)) => {
// Failed to extract the request body
}
Err(_) => {
// `JsonRejection` is marked `#[non_exhaustive]` so match must include a catch-all case.
}
}
}
let app = Router::new().route("/users", post(create_user));
A more involved example of reporting extractor errors:
use std::error::Error;
use axum::{
extract::{Json, rejection::JsonRejection},
response::IntoResponse,
http::StatusCode,
};
use serde_json::{json, Value};
async fn handler(result: Result<Json<Value>, JsonRejection>,) -> Result<Json<Value>, (StatusCode, String)> {
match result {
// if the client sent valid JSON then we're good
Ok(Json(payload)) => Ok(Json(json!({ "payload": payload }))),
Err(err) => match err {
JsonRejection::JsonDataError(err) => {
Err(serde_json_error_response(err))
}
JsonRejection::JsonSyntaxError(err) => {
Err(serde_json_error_response(err))
}
// handle other rejections from the `Json` extractor
JsonRejection::MissingJsonContentType(_) => Err((
StatusCode::BAD_REQUEST,
"Missing `Content-Type: application/json` header".to_string(),
)),
JsonRejection::BytesRejection(_) => Err((
StatusCode::INTERNAL_SERVER_ERROR,
"Failed to buffer request body".to_string(),
)),
// we must provide a catch-all case since `JsonRejection` is marked
// `#[non_exhaustive]`
_ => Err((
StatusCode::INTERNAL_SERVER_ERROR,
"Unknown error".to_string(),
)),
},
}
}
// attempt to extract the inner `serde_path_to_error::Error<serde_json::Error>`, if that succeeds we
// can provide a more specific error.
//
// `Json` uses `serde_path_to_error` so the error will be wrapped in `serde_path_to_error::Error`.
fn serde_json_error_response<E>(err: E) -> (StatusCode, String) where E: Error + 'static,
{
if let Some(err) = find_error_source::<serde_path_to_error::Error<serde_json::Error>>(&err) {
let serde_json_err = err.inner();
(
StatusCode::BAD_REQUEST,
format!(
"Invalid JSON at line {} column {}",
serde_json_err.line(),
serde_json_err.column()
),
)
} else {
(StatusCode::BAD_REQUEST, "Unknown error".to_string())
}
}
// attempt to downcast `err` into a `T` and if that fails recursively try and downcast `err`'s source
fn find_error_source<'a, T>(err: &'a (dyn Error + 'static)) -> Option<&'a T>
where
T: Error + 'static,
{
if let Some(err) = err.downcast_ref::<T>() {
Some(err)
} else if let Some(source) = err.source() {
find_error_source(source)
} else {
None
}
}
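The source-chain walk used by `find_error_source` can be tried in isolation with only the standard library. In the sketch below, `SyntaxError` and `Rejection` are hypothetical error types standing in for `serde_json::Error` and `JsonRejection`. The function is an iterative variant of the recursive one above, under the assumption that both behave the same for finite chains.

```rust
use std::error::Error;
use std::fmt;

// Hypothetical innermost error, standing in for serde_json::Error.
#[derive(Debug)]
struct SyntaxError {
    line: usize,
}

impl fmt::Display for SyntaxError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "syntax error at line {}", self.line)
    }
}

impl Error for SyntaxError {}

// Hypothetical wrapper, standing in for a rejection that wraps the inner error.
#[derive(Debug)]
struct Rejection {
    source: SyntaxError,
}

impl fmt::Display for Rejection {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "request rejected")
    }
}

impl Error for Rejection {
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        Some(&self.source)
    }
}

// Walks `err` and its sources until one of them downcasts to `T`.
fn find_error_source<'a, T>(err: &'a (dyn Error + 'static)) -> Option<&'a T>
where
    T: Error + 'static,
{
    let mut current: Option<&'a (dyn Error + 'static)> = Some(err);
    while let Some(e) = current {
        if let Some(found) = e.downcast_ref::<T>() {
            return Some(found);
        }
        current = e.source();
    }
    None
}

fn main() {
    let rejection = Rejection { source: SyntaxError { line: 3 } };
    if let Some(inner) = find_error_source::<SyntaxError>(&rejection) {
        println!("found inner error: {inner}");
    }
}
```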
Custom extractors #
Implement the FromRequestParts or FromRequest trait to define a custom extractor.
Implementing FromRequestParts:
use axum::{
async_trait,
extract::FromRequestParts,
routing::get,
Router,
http::{
StatusCode,
header::{HeaderValue, USER_AGENT},
request::Parts,
},
};
// By convention an extractor is a tuple struct (usually generic)
struct ExtractUserAgent(HeaderValue);
// S is the State type
#[async_trait]
impl<S> FromRequestParts<S> for ExtractUserAgent where S: Send + Sync,
{
// The error type returned when extraction fails.
type Rejection = (StatusCode, &'static str);
// Returns a Result whose Err carries the rejection value.
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
if let Some(user_agent) = parts.headers.get(USER_AGENT) {
Ok(ExtractUserAgent(user_agent.clone()))
} else {
Err((StatusCode::BAD_REQUEST, "`User-Agent` header is missing"))
}
}
}
// Function parameters are patterns, so user_agent binds the destructured value.
async fn handler(ExtractUserAgent(user_agent): ExtractUserAgent) {
// use user_agent
}
let app = Router::new().route("/foo", get(handler));
Implementing FromRequest:
use axum::{
async_trait,
extract::{Request, FromRequest},
response::{Response, IntoResponse},
body::{Bytes, Body},
routing::get,
Router,
http::{
StatusCode,
header::{HeaderValue, USER_AGENT},
},
};
struct ValidatedBody(Bytes);
#[async_trait]
impl<S> FromRequest<S> for ValidatedBody where Bytes: FromRequest<S>, S: Send + Sync,
{
type Rejection = Response;
// Returns the Rejection type when extraction fails
async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
let body = Bytes::from_request(req, state)
.await
.map_err(IntoResponse::into_response)?;
// do validation...
Ok(Self(body))
}
}
async fn handler(ValidatedBody(body): ValidatedBody) {
// use body
}
let app = Router::new().route("/foo", get(handler));
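A custom type should normally implement only one of the two traits. The exception is an extractor that wraps another extractor, where each impl is gated by a trait bound on the wrapped type: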
use axum::{
Router,
body::Body,
routing::get,
extract::{Request, FromRequest, FromRequestParts},
http::{HeaderMap, request::Parts},
async_trait,
};
use std::time::{Instant, Duration};
// an extractor that wraps another and measures how long time it takes to run
struct Timing<E> {
extractor: E,
duration: Duration,
}
// we must implement both `FromRequestParts`
#[async_trait]
impl<S, T> FromRequestParts<S> for Timing<T> where S: Send + Sync, T: FromRequestParts<S>,
{
type Rejection = T::Rejection;
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
let start = Instant::now();
let extractor = T::from_request_parts(parts, state).await?;
let duration = start.elapsed();
Ok(Timing {
extractor,
duration,
})
}
}
// and `FromRequest`
#[async_trait]
impl<S, T> FromRequest<S> for Timing<T> where S: Send + Sync, T: FromRequest<S>,
{
type Rejection = T::Rejection;
async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
let start = Instant::now();
let extractor = T::from_request(req, state).await?;
let duration = start.elapsed();
Ok(Timing {
extractor,
duration,
})
}
}
async fn handler(
// this uses the `FromRequestParts` impl
_: Timing<HeaderMap>,
// this uses the `FromRequest` impl
_: Timing<String>,
) {}
Extractors that read the request body limit it to 2 MB by default. Use DefaultBodyLimit::max(size) to change the limit:
use axum::{
Router,
routing::post,
body::Body,
extract::{Request, DefaultBodyLimit},
};
let app = Router::new()
.route("/", post(|request: Request| async {}))
// change the default limit
.layer(DefaultBodyLimit::max(1024));
To log extraction rejections, enable axum's tracing feature (on by default) and set the RUST_LOG environment variable:
RUST_LOG=info,axum::rejection=trace
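Common extractor types #
The axum::extract module provides commonly used extractor types:
- Json
- Form
- Request
- HeaderMap
- Extension: extracts values from the request extensions, such as values added with .layer(Extension(...)); values passed to with_state() are extracted with State instead.
- ConnectInfo
- Host
- MatchedPath
- Multipart
- NestedPath
- OriginalUri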
- Path
- Query
- RawForm
- RawPathParams
- RawQuery
- State
Json #
Json implements both FromRequest and IntoResponse, so it can be used as the input and the output type of a Handler:
pub struct Json<T>(pub T);
// Extractor example
use axum::{
extract,
routing::post,
Router,
};
use serde::Deserialize;
#[derive(Deserialize)]
struct CreateUser {
email: String,
password: String,
}
// Declaring the payload as Json<serde_json::Value> instead deserializes into a generic JSON value.
async fn create_user(extract::Json(payload): extract::Json<CreateUser>) {
// payload is a `CreateUser`
}
let app = Router::new().route("/users", post(create_user));
// Response example
use axum::{
extract::Path,
routing::get,
Router,
Json,
};
use serde::Serialize;
use uuid::Uuid;
#[derive(Serialize)]
struct User {
id: Uuid,
username: String,
}
async fn get_user(Path(user_id) : Path<Uuid>) -> Json<User> {
let user = find_user(user_id).await;
Json(user)
}
async fn find_user(user_id: Uuid) -> User {
todo!("look up the user")
}
let app = Router::new().route("/users/:id", get(get_user));
Form #
Form implements both FromRequest and IntoResponse, so it can be used as the input and the output type of a Handler:
// Extractor example (Form implements FromRequest)
pub struct Form<T>(pub T);
use axum::Form;
use serde::Deserialize;
#[derive(Deserialize)]
struct SignUp {
username: String,
password: String,
}
async fn accept_form(Form(sign_up): Form<SignUp>) {
// ...
}
// Response
use axum::Form;
use serde::Serialize;
#[derive(Serialize)]
struct Payload {
value: String,
}
async fn handler() -> Form<Payload> {
Form(Payload { value: "foo".to_owned() })
}
Request #
The Request extractor hands the Handler the entire request, which gives it the most control:
async fn request(request: Request) {}
HeaderMap #
Contains all request headers:
async fn headers(headers: HeaderMap) {}
Extension #
Extension implements FromRequestParts, Layer<S>, and IntoResponse, so it can act as an extractor, as layer middleware, and as part of a response.
Layer middleware commonly uses the HTTP request extensions, via Extension, to pass data to handlers.
// As an extractor: commonly used to share state between handlers
use axum::{
Router,
Extension,
routing::get,
};
use std::sync::Arc;
// Some shared state used throughout our application
struct State {
// ...
}
// The state is shared by every Handler (so it must implement Clone) and cannot be taken as &mut; it is usually wrapped in Arc.
async fn handler(state: Extension<Arc<State>>) {
// ...
}
let state = Arc::new(State { /* ... */ });
let app = Router::new().route("/", get(handler))
.layer(Extension(state)); // Router level: applies to all of its Handlers
// As a response: attaches Extension data to the response
use axum::{
Extension,
response::IntoResponse,
};
async fn handler() -> (Extension<Foo>, &'static str) {
(
Extension(Foo("foo")),
"Hello, World!"
)
}
#[derive(Clone)]
struct Foo(&'static str);
// Passing state from middleware to handlers
// State can be passed from middleware to handlers using request extensions:
use axum::{
Router,
http::StatusCode,
routing::get,
response::{IntoResponse, Response},
middleware::{self, Next},
extract::{Request, Extension},
};
#[derive(Clone)]
struct CurrentUser { /* ... */ }
async fn auth(mut req: Request, next: Next) -> Result<Response, StatusCode> {
let auth_header = req.headers()
.get(axum::http::header::AUTHORIZATION)
.and_then(|header| header.to_str().ok());
let auth_header = if let Some(auth_header) = auth_header {
auth_header
} else {
return Err(StatusCode::UNAUTHORIZED);
};
if let Some(current_user) = authorize_current_user(auth_header).await {
// insert the current user into a request extension so the handler can extract it
req.extensions_mut().insert(current_user);
Ok(next.run(req).await)
} else {
Err(StatusCode::UNAUTHORIZED)
}
}
async fn authorize_current_user(auth_token: &str) -> Option<CurrentUser> {
todo!("validate the token and load the user")
}
async fn handler(
// extract the current user, set by the middleware
Extension(current_user): Extension<CurrentUser>,
) {
// ...
}
let app = Router::new()
.route("/", get(handler))
// Layers run before the Handler, in the reverse of the order in which they were added.
.route_layer(middleware::from_fn(auth));
ConnectInfo #
Extracts connection information about the client. It must be used together with Router::into_make_service_with_connect_info().
use axum::{
extract::connect_info::{ConnectInfo, Connected},
routing::get,
serve::IncomingStream,
Router,
};
let app = Router::new().route("/", get(handler));
async fn handler(
ConnectInfo(my_connect_info): ConnectInfo<MyConnectInfo>,
) -> String {
format!("Hello {my_connect_info:?}")
}
Implementing the Connected<IncomingStream<'_>> trait lets you use a custom type as the ConnectInfo type parameter:
#[derive(Clone, Debug)]
struct MyConnectInfo {
// ...
}
impl Connected<IncomingStream<'_>> for MyConnectInfo {
fn connect_info(target: IncomingStream<'_>) -> Self {
MyConnectInfo {
// ...
}
}
}
let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
axum::serve(listener, app.into_make_service_with_connect_info::<MyConnectInfo>()).await.unwrap();
DefaultBodyLimit #
DefaultBodyLimit implements Layer and limits the request body size (strictly speaking it is not an extractor).
use axum::{
Router,
routing::post,
body::Body,
extract::{Request, DefaultBodyLimit},
};
let app = Router::new()
// Router::layer only affects routes added before the call, so this limit applies to no route here
.layer(DefaultBodyLimit::max(1024))
// this route sets its own limit on the MethodRouter
.route("/", post(|request: Request| async {}).layer(DefaultBodyLimit::max(1024)))
// this route still has the default limit
.route("/foo", post(|request: Request| async {}));
Host #
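An extractor that resolves the host name of the request. It implements FromRequestParts.
The host name is resolved in the following order:
- Forwarded header
- X-Forwarded-Host header
- Host header
- request target / URI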
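The resolution order above can be sketched as a chain of fallbacks. This is a simplified std-only model and not axum's implementation: headers are a plain `HashMap` with lower-case keys, and `forwarded_host`, `resolve_host`, and `header_map` are hypothetical helper names. Only the first element of the `Forwarded` header is inspected.

```rust
use std::collections::HashMap;

// Pulls the `host=` parameter out of the first element of a Forwarded header value.
fn forwarded_host(value: &str) -> Option<String> {
    let first = value.split(',').next()?;
    first.split(';').find_map(|pair| {
        let (key, val) = pair.trim().split_once('=')?;
        if key.eq_ignore_ascii_case("host") {
            Some(val.trim_matches('"').to_string())
        } else {
            None
        }
    })
}

// Tries each source in priority order and returns the first host found.
fn resolve_host(headers: &HashMap<String, String>, uri_host: Option<&str>) -> Option<String> {
    headers
        .get("forwarded")
        .and_then(|value| forwarded_host(value))
        .or_else(|| headers.get("x-forwarded-host").cloned())
        .or_else(|| headers.get("host").cloned())
        .or_else(|| uri_host.map(str::to_string))
}

// Test helper: builds a header map with lower-cased names.
fn header_map(pairs: &[(&str, &str)]) -> HashMap<String, String> {
    pairs
        .iter()
        .map(|(name, value)| (name.to_ascii_lowercase(), value.to_string()))
        .collect()
}

fn main() {
    let headers = header_map(&[
        ("Forwarded", "for=192.0.2.60;proto=http;host=a.example"),
        ("Host", "c.example"),
    ]);
    println!("{:?}", resolve_host(&headers, None));
}
```

Because the forwarding headers are supplied by the client or by intermediaries, a host resolved this way should be treated as untrusted input.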
MatchedPath #
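An extractor that returns the route path the request matched. It implements FromRequestParts, and the returned path is the original route string registered on the Router.
use axum::{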
Router,
extract::MatchedPath,
routing::get,
};
let app = Router::new().route(
"/users/:id",
get(|path: MatchedPath| async move {
let path = path.as_str();
// `path` will be "/users/:id"
})
);
Multipart #
An extractor that parses multipart/form-data requests (typically file uploads). It implements FromRequest, so it must be the last argument of the handler function.
use axum::{
extract::Multipart,
routing::post,
Router,
};
use futures_util::stream::StreamExt;
async fn upload(mut multipart: Multipart) {
while let Some(mut field) = multipart.next_field().await.unwrap() {
let name = field.name().unwrap().to_string();
let data = field.bytes().await.unwrap();
println!("Length of `{}` is {} bytes", name, data.len());
}
}
let app = Router::new().route("/upload", post(upload));
A multipart form supports text and file fields; text fields are encoded like ordinary form fields.
- A multipart form containing both text and file fields:
<!doctype html>
<html>
<head>
<title>File Upload</title>
</head>
<body>
<form action="http://localhost:8000" method="post" enctype="multipart/form-data">
<p><input type="text" name="text" value="text default">
<p><input type="file" name="file1">
<p><input type="file" name="file2">
<p><button type="submit">Submit</button>
</form>
</body>
</html>
Example of the HTTP request sent by the POST upload:
- Content-Type: multipart/form-data; boundary=---------------------------9051914041544843365972754266
POST / HTTP/1.1
Host: localhost:8000
User-Agent: Mozilla/5.0 (X11; Ubuntu; Linux i686; rv:29.0) Gecko/20100101 Firefox/29.0
Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8
Accept-Language: en-US,en;q=0.5
Accept-Encoding: gzip, deflate
Cookie: __atuvc=34%7C7; permanent=0; _gitlab_session=226ad8a0be43681acf38c2fab9497240; __profilin=p%3Dt; request_method=GET
Connection: keep-alive
Content-Type: multipart/form-data; boundary=---------------------------9051914041544843365972754266
Content-Length: 554
-----------------------------9051914041544843365972754266
Content-Disposition: form-data; name="text"
text default
-----------------------------9051914041544843365972754266
Content-Disposition: form-data; name="file1"; filename="a.txt"
Content-Type: text/plain
Content of a.txt.
-----------------------------9051914041544843365972754266
Content-Disposition: form-data; name="file2"; filename="a.html"
Content-Type: text/html
<!DOCTYPE html><title>Content of a.html.</title>
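To make the wire format concrete, here is a minimal std-only sketch that splits such a body into `(field name, content)` pairs. It is not how axum's `Multipart` extractor works (that extractor streams bytes). The sketch assumes a UTF-8 body, CRLF line endings, and field content that never contains the delimiter. `parse_parts`, `field_name`, `SAMPLE_BODY`, and the boundary `XBOUNDARY` are hypothetical names chosen for the illustration.

```rust
// Hypothetical sample body with two parts and the closing delimiter.
const SAMPLE_BODY: &str = concat!(
    "--XBOUNDARY\r\n",
    "Content-Disposition: form-data; name=\"text\"\r\n",
    "\r\n",
    "text default\r\n",
    "--XBOUNDARY\r\n",
    "Content-Disposition: form-data; name=\"file1\"; filename=\"a.txt\"\r\n",
    "Content-Type: text/plain\r\n",
    "\r\n",
    "Content of a.txt.\r\n",
    "--XBOUNDARY--\r\n",
);

// Returns the `name` parameter if the line is a Content-Disposition header.
fn field_name(header_line: &str) -> Option<String> {
    let (key, value) = header_line.split_once(':')?;
    if !key.trim().eq_ignore_ascii_case("content-disposition") {
        return None;
    }
    value
        .split(';')
        .find_map(|param| param.trim().strip_prefix("name="))
        .map(|name| name.trim_matches('"').to_string())
}

// Splits the body on "--<boundary>" and returns (field name, content) pairs.
fn parse_parts(body: &str, boundary: &str) -> Vec<(String, String)> {
    let delimiter = format!("--{boundary}");
    let mut parts = Vec::new();
    for chunk in body.split(delimiter.as_str()) {
        let chunk = chunk.trim_start_matches("\r\n");
        // A part is "headers, blank line, content"; the preamble and the
        // closing "--" chunk have no blank line and are skipped.
        let Some((head, content)) = chunk.split_once("\r\n\r\n") else {
            continue;
        };
        let Some(name) = head.lines().find_map(field_name) else {
            continue;
        };
        parts.push((name, content.trim_end_matches("\r\n").to_string()));
    }
    parts
}

fn main() {
    for (name, content) in parse_parts(SAMPLE_BODY, "XBOUNDARY") {
        println!("{name}: {content}");
    }
}
```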
NestedPath #
Returns the path at which the matched router is nested. It implements FromRequestParts:
use axum::{
Router,
extract::NestedPath,
routing::get,
};
let api = Router::new().route(
"/users",
get(|path: NestedPath| async move {
// `path` will be "/api" because that's what this router is nested at when we build `app`
let path = path.as_str();
})
);
let app = Router::new().nest("/api", api);
OriginalUri #
Returns the original request URI, regardless of nesting:
use axum::{
routing::get,
Router,
extract::OriginalUri,
http::Uri
};
let api_routes = Router::new()
.route(
"/users",
get(|uri: Uri, OriginalUri(original_uri): OriginalUri| async {
// `uri` is `/users`
// `original_uri` is `/api/users`
}),
);
let app = Router::new().nest("/api", api_routes);
Path #
Deserializes the path parameters of the request with serde; use a tuple for several parameters:
use axum::{
extract::Path,
routing::get,
Router,
};
use uuid::Uuid;
async fn users_teams_show(
Path((user_id, team_id)): Path<(Uuid, Uuid)>,
) {
// ...
}
let app = Router::new().route("/users/:user_id/team/:team_id", get(users_teams_show));
Query #
Deserializes the query string into a struct type. (A parameter that may be absent must be declared as Option.)
use axum::{
extract::Query,
routing::get,
Router,
};
use serde::Deserialize;
#[derive(Deserialize)]
struct Pagination {
page: usize,
per_page: usize,
}
// This will parse query strings like `?page=2&per_page=30` into `Pagination` structs.
async fn list_things(pagination: Query<Pagination>) {
let pagination: Pagination = pagination.0;
// ...
}
let app = Router::new().route("/list_things", get(list_things));
RawForm #
Form requires its type parameter to implement serde::Deserialize, whereas RawForm does not deserialize the request body and yields the raw form data as Bytes.
use axum::{
extract::RawForm,
routing::get,
Router
};
async fn handler(RawForm(form): RawForm) {}
let app = Router::new().route("/", get(handler));
RawPathParams #
Extracts the path parameters from the URL without deserializing them:
pub struct RawPathParams(/* private fields */);
impl<'a> IntoIterator for &'a RawPathParams
type Item = (&'a str, &'a str)
use axum::{
extract::RawPathParams,
routing::get,
Router,
};
async fn users_teams_show(params: RawPathParams) {
for (key, value) in &params {
println!("{key:?} = {value:?}");
}
}
let app = Router::new().route("/users/:user_id/team/:team_id", get(users_teams_show));
RawQuery #
Extracts the raw query string from the request:
pub struct RawQuery(pub Option<String>);
// Extractor that extracts the raw query string, without parsing it.
// Example
use axum::{
extract::RawQuery,
routing::get,
Router,
};
use futures_util::StreamExt;
async fn handler(RawQuery(query): RawQuery) {
// ...
}
let app = Router::new().route("/users", get(handler));
State #
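State is a global object shared through the Router. Handlers cannot take it as &mut and it must implement Clone, so mutable state is usually wrapped in a type with interior mutability such as Arc<Mutex<_>>: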
use axum::{Router, routing::get, extract::State};
// the application state
//
// here you can put configuration, database connection pools, or whatever state you need
//
// see "When states need to implement `Clone`" for more details on why we need `#[derive(Clone)]` here.
#[derive(Clone)]
struct AppState {}
let state = AppState {};
// create a `Router` that holds our state
let app = Router::new()
.route("/", get(handler))
// provide the state so the router can access it
.with_state(state);
async fn handler(
// access the state via the `State` extractor
// extracting a state of the wrong type results in a compile error
State(state): State<AppState>,
) {
// use `state`...
}
substate #
A Router has a single State type, but FromRef lets handlers extract a "substate":
use axum::{Router, routing::get, extract::{State, FromRef}};
// the application state
#[derive(Clone)]
struct AppState {
// that holds some api specific state
api_state: ApiState,
}
// the api specific state
#[derive(Clone)]
struct ApiState {}
// support converting an `AppState` into an `ApiState`
impl FromRef<AppState> for ApiState {
fn from_ref(app_state: &AppState) -> ApiState {
app_state.api_state.clone()
}
}
let state = AppState {
api_state: ApiState {},
};
let app = Router::new()
.route("/", get(handler))
.route("/api/users", get(api_users))
.with_state(state);
async fn api_users(
// access the api specific state
State(api_state): State<ApiState>, // uses ApiState's FromRef impl to convert AppState into ApiState
) {
}
async fn handler(
// we can still access the top-level state
State(state): State<AppState>,
) {
}
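The substate mechanism can be modeled without axum. In the sketch below, the `FromRef` trait and its blanket impl mirror the shape of `axum::extract::FromRef`, while `extract_state`, `sample_state`, and the `base_url` field are hypothetical names added for illustration. It shows why `State<AppState>` and `State<ApiState>` can both be extracted from a router whose state is `AppState`.

```rust
// Simplified model of the FromRef trait: build Self from a reference to T.
trait FromRef<T> {
    fn from_ref(input: &T) -> Self;
}

// Any Clone type can be extracted from itself, so State<AppState> keeps working.
impl<T: Clone> FromRef<T> for T {
    fn from_ref(input: &T) -> Self {
        input.clone()
    }
}

#[derive(Clone, Debug, PartialEq)]
struct ApiState {
    base_url: String, // hypothetical field for the illustration
}

#[derive(Clone, Debug, PartialEq)]
struct AppState {
    api_state: ApiState,
}

// The substate conversion: pick the ApiState out of the AppState.
impl FromRef<AppState> for ApiState {
    fn from_ref(app_state: &AppState) -> Self {
        app_state.api_state.clone()
    }
}

// Hypothetical stand-in for what a State<Inner> extractor does with the router state.
fn extract_state<Inner, Outer>(outer: &Outer) -> Inner
where
    Inner: FromRef<Outer>,
{
    Inner::from_ref(outer)
}

fn sample_state() -> AppState {
    AppState {
        api_state: ApiState { base_url: "https://api.example".to_string() },
    }
}

fn main() {
    let app = sample_state();
    let api: ApiState = extract_state(&app);
    let whole: AppState = extract_state(&app);
    println!("{api:?} {whole:?}");
}
```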
WebSocketUpgrade #
Upgrades an established client connection to a WebSocket connection. It implements FromRequestParts:
use axum::{
extract::ws::{WebSocketUpgrade, WebSocket},
routing::get,
response::{IntoResponse, Response},
Router,
};
let app = Router::new().route("/ws", get(handler));
async fn handler(ws: WebSocketUpgrade) -> Response {
ws.on_upgrade(handle_socket)
}
async fn handle_socket(mut socket: WebSocket) {
while let Some(msg) = socket.recv().await {
let msg = if let Ok(msg) = msg {
msg
} else {
// client disconnected
return;
};
if socket.send(msg).await.is_err() {
// client disconnected
return;
}
}
}
// If you need to read and write concurrently from a WebSocket you can use StreamExt::split:
use axum::{Error, extract::ws::{WebSocket, Message}};
use futures_util::{sink::SinkExt, stream::{StreamExt, SplitSink, SplitStream}};
async fn handle_socket(mut socket: WebSocket) {
let (mut sender, mut receiver) = socket.split();
tokio::spawn(write(sender));
tokio::spawn(read(receiver));
}
async fn read(receiver: SplitStream<WebSocket>) {
// ...
}
async fn write(sender: SplitSink<WebSocket, Message>) {
// ...
}
A complete websocket example:
// https://docs.shuttle.rs/examples/axum-websockets
use std::{sync::Arc, time::Duration};
use axum::{
extract::{
ws::{Message, WebSocket},
WebSocketUpgrade,
},
response::IntoResponse,
routing::get,
Extension, Router,
};
use chrono::{DateTime, Utc};
use futures::{SinkExt, StreamExt};
use serde::Serialize;
use shuttle_axum::ShuttleAxum;
use tokio::{
sync::{watch, Mutex},
time::sleep,
};
use tower_http::services::ServeDir;
struct State {
clients_count: usize,
rx: watch::Receiver<Message>,
}
const PAUSE_SECS: u64 = 15;
const STATUS_URI: &str = "https://api.shuttle.rs";
#[derive(Serialize)]
struct Response {
clients_count: usize,
#[serde(rename = "dateTime")]
date_time: DateTime<Utc>,
is_up: bool,
}
#[shuttle_runtime::main]
async fn axum() -> ShuttleAxum {
let (tx, rx) = watch::channel(Message::Text("{}".to_string()));
let state = Arc::new(Mutex::new(State {
clients_count: 0,
rx,
}));
// Spawn a thread to continually check the status of the api
let state_send = state.clone();
tokio::spawn(async move {
let duration = Duration::from_secs(PAUSE_SECS);
loop {
let is_up = reqwest::get(STATUS_URI).await;
let is_up = is_up.is_ok();
let response = Response {
clients_count: state_send.lock().await.clients_count,
date_time: Utc::now(),
is_up,
};
let msg = serde_json::to_string(&response).unwrap();
if tx.send(Message::Text(msg)).is_err() {
break;
}
sleep(duration).await;
}
});
let router = Router::new()
.route("/websocket", get(websocket_handler))
.nest_service("/", ServeDir::new("static"))
.layer(Extension(state));
Ok(router.into())
}
async fn websocket_handler(ws: WebSocketUpgrade, Extension(state): Extension<Arc<Mutex<State>>>,) -> impl IntoResponse {
ws.on_upgrade(|socket| websocket(socket, state))
}
async fn websocket(stream: WebSocket, state: Arc<Mutex<State>>) {
// By splitting we can send and receive at the same time.
let (mut sender, mut receiver) = stream.split();
let mut rx = {
let mut state = state.lock().await;
state.clients_count += 1;
state.rx.clone()
};
// This task will receive watch messages and forward it to this connected client.
let mut send_task = tokio::spawn(async move {
while let Ok(()) = rx.changed().await {
let msg = rx.borrow().clone();
if sender.send(msg).await.is_err() {
break;
}
}
});
// This task will receive messages from this client.
let mut recv_task = tokio::spawn(async move {
while let Some(Ok(Message::Text(text))) = receiver.next().await {
println!("this example does not read any messages, but got: {text}");
}
});
// If any one of the tasks exit, abort the other.
tokio::select! {
_ = (&mut send_task) => recv_task.abort(),
_ = (&mut recv_task) => send_task.abort(),
};
// This client disconnected
state.lock().await.clients_count -= 1;
}
state #
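State lets Handlers share global data or state, such as a database connection pool or other clients.
There are three ways to share state:
- the State extractor;
- the Extension extractor (request extensions);
- closure captures;
State must satisfy Clone + Send + Sync + 'static because it is owned and used across threads, so mutable state is usually wrapped in Arc<Mutex<_>> to get interior mutability.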
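Why a cloneable handle around a lock meets these bounds can be shown with plain threads. This is a std-only sketch, with `std::thread` standing in for the async runtime. `AppState`, `handle_request`, and `run` are hypothetical names. Each worker gets its own `Arc` clone (cheap, and `Clone + Send + Sync + 'static`), and the `Mutex` provides the mutation that `&mut` cannot.

```rust
use std::sync::{Arc, Mutex};
use std::thread;

#[derive(Default)]
struct AppState {
    hits: u64,
}

// Stand-in for a handler: it owns a clone of the shared handle.
fn handle_request(state: Arc<Mutex<AppState>>) {
    let mut guard = state.lock().unwrap();
    guard.hits += 1;
}

// Simulates `requests` concurrent requests and returns the final counter.
fn run(requests: usize) -> u64 {
    let state = Arc::new(Mutex::new(AppState::default()));
    let workers: Vec<_> = (0..requests)
        .map(|_| {
            let state = Arc::clone(&state);
            thread::spawn(move || handle_request(state))
        })
        .collect();
    for worker in workers {
        worker.join().unwrap();
    }
    let hits = state.lock().unwrap().hits;
    hits
}

fn main() {
    println!("hits = {}", run(8));
}
```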
The S type parameter of Router<S> is the State type:
pub struct Router<S = ()> { /* private fields */ }
// Methods of Router<S>; S must satisfy Clone + Send + Sync + 'static
impl<S> Router<S> where S: Clone + Send + Sync + 'static
// with_state<S2> is a generic method; when S2 is not given explicitly, the compiler infers it from context.
// For example, calling into_make_service() on the returned Router makes S2 resolve to ().
pub fn with_state<S2>(self, state: S) -> Router<S2>
// Router is shorthand for Router<S = ()>; the two methods below are defined only on Router<()>
impl Router
pub fn into_make_service(self) -> IntoMakeService<Self>
pub fn into_make_service_with_connect_info<C>(self) -> IntoMakeServiceWithConnectInfo<Self, C>
State Extractor #
State values can be added with the with_state() method at three levels: Router, MethodRouter, and Handler:
// Router:
// pub fn with_state<S2>(self, state: S) -> Router<S2>
use axum::{
extract::State,
routing::get,
Router,
};
use std::sync::Arc;
struct AppState {
// ...
}
let shared_state = Arc::new(AppState { /* ... */ });
let app = Router::new()
.route("/", get(handler))
.with_state(shared_state);
async fn handler(
State(state): State<Arc<AppState>>,
) {
// ...
}
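When a function returns a Router, it is generally better not to set the state inside the function. Instead, do it in two steps:
- define the function to return Router<S>;
- call with_state() on the value the function returns.
The reason is that Router::into_make_service() is implemented only for Router<()>.
The following example fails to compile because Router<AppState> has no into_make_service() method:
// This won't work because we're returning a `Router<AppState>` i.e. we're saying we're still missing an `AppState`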
fn routes(state: AppState) -> Router<AppState> {
Router::new()
.route("/", get(|_: State<AppState>| async {}))
.with_state(state)
}
// app has type Router<AppState>
let app = routes(AppState {});
let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
// Error: we can only call `Router::into_make_service` on a `Router<()>` but `app` is a `Router<AppState>`
axum::serve(listener, app).await.unwrap();
The fix is to return Router<()>:
// We've provided all the state necessary so return `Router<()>`
fn routes(state: AppState) -> Router<()> {
Router::new()
.route("/", get(|_: State<AppState>| async {}))
.with_state(state)
}
// app has type Router<()>
let app = routes(AppState {});
// We can now call `Router::into_make_service`
let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
axum::serve(listener, app).await.unwrap();
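The Router<S2> returned by Router::with_state::<S2>() is not always Router<()>; S2 is whatever the compiler infers from how the returned router is used.
For example, calling into_make_service() on the result forces S2 to be (), because that method exists only on Router<()>. S2 can also be chosen explicitly, as with String in the following example: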
let router: Router<AppState> = Router::new()
.route("/", get(|_: State<AppState>| async {}));
// When we call `with_state` we're able to pick what the next missing state type is.
// Here we pick `String`.
let string_router: Router<String> = router.with_state(AppState {});
// That allows us to add new routes that uses `String` as the state type
let string_router = string_router
.route("/needs-string", get(|_: State<String>| async {}));
// Provide the `String` and choose `()` as the new missing state.
let final_router: Router<()> = string_router.with_state("foo".to_owned());
// Since we have a `Router<()>` we can run it.
let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
axum::serve(listener, final_router).await.unwrap();
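The inference of S2 can be reproduced with a small type-level model. The sketch below is not axum's implementation: `MiniRouter`, `build`, and the route list are hypothetical, and the only point is that a method defined on `MiniRouter<()>` alone forces the compiler to pick `()` for `S2`.

```rust
use std::marker::PhantomData;

// Toy model only: `MiniRouter<S>` is hypothetical and mimics how the type
// parameter records the state type a router is still missing.
struct MiniRouter<S = ()> {
    routes: Vec<&'static str>,
    _missing_state: PhantomData<S>,
}

impl<S> MiniRouter<S> {
    fn new() -> Self {
        MiniRouter { routes: Vec::new(), _missing_state: PhantomData }
    }

    fn route(mut self, path: &'static str) -> Self {
        self.routes.push(path);
        self
    }

    // Supplying the state `S` yields a router whose next missing state `S2`
    // is chosen by the caller or inferred from later use.
    fn with_state<S2>(self, _state: S) -> MiniRouter<S2> {
        MiniRouter { routes: self.routes, _missing_state: PhantomData }
    }
}

// Only the fully provisioned `MiniRouter<()>` can be turned into a service.
impl MiniRouter<()> {
    fn into_make_service(self) -> Vec<&'static str> {
        self.routes
    }
}

struct AppState;

fn build() -> Vec<&'static str> {
    let router: MiniRouter<AppState> = MiniRouter::new().route("/");
    // `S2` is inferred as `()` because `into_make_service` exists only there.
    router.with_state(AppState).into_make_service()
}
```

Calling `into_make_service()` directly on a `MiniRouter<AppState>` would be rejected at compile time, which mirrors the error described above.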
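If the state really must be set inside the function that returns the Router, declare the return type as Router without a generic parameter, or explicitly as Router<()>: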
// Don't return `Router<AppState>`
fn routes(state: AppState) -> Router {
Router::new()
.route("/", get(|_: State<AppState>| async {}))
.with_state(state)
}
let routes = routes(AppState {});
let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
axum::serve(listener, routes).await.unwrap();
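If the Router returned by a function is passed to nest(), the function should return a Router with an unconstrained generic parameter, which the compiler infers later (for example, as () when the result is passed to axum::serve()):
- Without the generic parameter S, the function returns Router<()>, which may not match the state type of the Router that nest() is called on.
// nest() requires the nested router to have the same S as the router it is called on.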
impl<S> Router<S> where S: Clone + Send + Sync + 'static,
pub fn nest(self, path: &str, router: Router<S>) -> Self
// Example:
fn routes<S>(state: AppState) -> Router<S> {
Router::new()
.route("/", get(|_: State<AppState>| async {}))
.with_state(state)
}
let routes = Router::new().nest("/api", routes(AppState {}));
let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
axum::serve(listener, routes).await.unwrap();
Another example:
// A router that _needs_ an `AppState` to handle requests
let router: Router<AppState> = Router::new().route("/", get(|_: State<AppState>| async {}));
// Once we call `Router::with_state` the router isn't missing the state anymore, because we just provided it
//
// Therefore the router type becomes `Router<()>`, i.e. a router that is not missing any state
let router: Router<()> = router.with_state(AppState {});
// Only `Router<()>` has the `into_make_service` method.
//
// You cannot call `into_make_service` on a `Router<AppState>` because it is still missing an `AppState`.
let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
axum::serve(listener, router).await.unwrap();
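Performance tip: if you need a Router that implements Service but needs no state, calling .with_state(()) is still recommended. It lets axum update its internal state and reduce memory allocations, which improves performance: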
use axum::{Router, routing::get};
let app = Router::new()
.route("/", get(|| async { /* ... */ }))
// even though we don't need any state, call `with_state(())` anyway
.with_state(());
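Per-request State: Extensions #
State added at the Router level is available to every request handled by that Router.
To build state that is specific to one request, such as authentication or authorization data produced by a middleware, use Extension instead.
Extension is a mechanism that passes state to a Handler through HTTP request extensions.
Extension implements FromRequestParts, Layer, and IntoResponse, so it can be used as an extractor, as a layer middleware, and as part of a response.
See the earlier Extension section for details.
Passing State with Closures #
State can also be captured directly by the Handler closure: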
use axum::{
Json,
extract::{Extension, Path},
routing::{get, post},
Router,
};
use std::sync::Arc;
use serde::Deserialize;
struct AppState {
// ...
}
let shared_state = Arc::new(AppState { /* ... */ });
let app = Router::new()
.route(
"/users",
post({
let shared_state = Arc::clone(&shared_state);
move |body| create_user(body, shared_state)
}),
)
.route(
"/users/:id",
get({
let shared_state = Arc::clone(&shared_state);
move |path| get_user(path, shared_state)
}),
);
async fn get_user(Path(user_id): Path<String>, state: Arc<AppState>) {
// ...
}
async fn create_user(Json(payload): Json<CreateUserPayload>, state: Arc<AppState>) {
// ...
}
#[derive(Deserialize)]
struct CreateUserPayload {
// ...
}
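The capture pattern in the example above (clone the `Arc` outside the closure, then `move` the clone in) does not depend on axum. A minimal sketch with the standard library only, where `Handler`, `AppState`, and `build_handlers` are hypothetical names:

```rust
use std::sync::Arc;

// Hypothetical state and handler type for illustration; not axum's types.
struct AppState {
    greeting: String,
}

type Handler = Box<dyn Fn(&str) -> String + Send + Sync>;

fn get_user(user_id: &str, state: Arc<AppState>) -> String {
    format!("{} {}", state.greeting, user_id)
}

fn build_handlers() -> Vec<Handler> {
    let shared_state = Arc::new(AppState { greeting: "hello".to_string() });
    let mut handlers: Vec<Handler> = Vec::new();
    for _ in 0..2 {
        // Clone the handle outside the closure, then move the clone in,
        // so every closure owns its own reference to the same state.
        let shared_state = Arc::clone(&shared_state);
        handlers.push(Box::new(move |id: &str| get_user(id, Arc::clone(&shared_state))));
    }
    handlers
}
```

Each closure clones the `Arc` again per call because `get_user` takes the state by value, just as the handlers in the example do.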
Layer Middleware #
Layer middleware can be added at three levels, Router, MethodRouter, and Handler, through the layer()/route_layer() methods:
- a whole router: Router::layer() and Router::route_layer()
- a single method router: MethodRouter::layer() and MethodRouter::route_layer()
- a single handler: Handler::layer()
A Layer wraps a Service and returns a new Service:
pub trait Layer<S> {
type Service;
// S is the Service being wrapped
fn layer(&self, inner: S) -> Self::Service;
}
A Layer runs before the Handler, so it is the usual place to implement middleware logic:
impl<S> Router<S> where S: Clone + Send + Sync + 'static
pub fn layer<L>(self, layer: L) -> Router<S>
where
// Route implements Service; its request type is http::request::Request and its response type is http::response::Response
L: Layer<Route> + Clone + Send + 'static,
L::Service: Service<Request> + Clone + Send + 'static,
<L::Service as Service<Request>>::Response: IntoResponse + 'static,
<L::Service as Service<Request>>::Error: Into<Infallible> + 'static,
<L::Service as Service<Request>>::Future: Send + 'static
pub fn route_layer<L>(self, layer: L) -> Self
where
L: Layer<Route> + Clone + Send + 'static,
L::Service: Service<Request> + Clone + Send + 'static,
<L::Service as Service<Request>>::Response: IntoResponse + 'static,
<L::Service as Service<Request>>::Error: Into<Infallible> + 'static,
<L::Service as Service<Request>>::Future: Send + 'static,
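Difference between route_layer() and layer(): the former runs its layer only when a request matches a route, while the latter runs its layer for every request, whether or not a route matches.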
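The practical effect shows up with middleware that returns early, such as authorization. The sketch below is a toy model, not axum code: status codes are plain integers, and `authorized`, `route_matches`, `with_layer`, and `with_route_layer` are hypothetical helpers.

```rust
// Toy model: status codes as plain integers; not axum's types.
fn handler(_path: &str) -> u16 {
    200
}

fn route_matches(path: &str) -> bool {
    path == "/foo"
}

// Hypothetical auth check standing in for a layer's logic.
fn authorized(token: Option<&str>) -> bool {
    token == Some("secret")
}

// `layer()` style: the middleware wraps routing, so it runs for every
// request, including requests that match no route.
fn with_layer(path: &str, token: Option<&str>) -> u16 {
    if !authorized(token) {
        return 401;
    }
    if route_matches(path) { handler(path) } else { 404 }
}

// `route_layer()` style: routing happens first and the middleware only
// wraps routes that matched.
fn with_route_layer(path: &str, token: Option<&str>) -> u16 {
    if !route_matches(path) {
        return 404;
    }
    if authorized(token) { handler(path) } else { 401 }
}
```

With `layer()` an unauthenticated request to an unknown path is answered with 401, while with `route_layer()` it stays a 404.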
Route implements the Service trait; its request type is http::request::Request and its response type is http::response::Response:
// Route
impl<B, E> Service<Request<B>> for Route<E>
where
// B implements the http_body::Body trait, with Data = bytes::Bytes
B: HttpBody<Data = Bytes> + Send + 'static,
B::Error: Into<BoxError>
// Responses given by the service: the struct http::response::Response
type Response = Response<Body>
type Error = E
type Future = RouteFuture<E>
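Layers added with Router::layer() are called in the reverse of the order in which they were added, and the Handler is called last. A layer can therefore act both before and after the Handler runs: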
use axum::{routing::get, Router};
async fn handler() {}
let app = Router::new()
.route("/", get(handler))
.layer(layer_one)
.layer(layer_two)
.layer(layer_three); // called first
// layers added with layer() are called one after another in reverse order of addition; the handler is called last
requests
|
v
+----- layer_three -----+
| +---- layer_two ----+ |
| | +-- layer_one --+ | |
| | | | | |
| | | handler | | |
| | | | | |
| | +-- layer_one --+ | |
| +---- layer_two ----+ |
+----- layer_three -----+
|
v
responses
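The onion shape in the diagram follows from how wrapping works: every layer() call wraps everything added so far. A minimal sketch with closures standing in for services (`wrap`, `call_order`, and the log are hypothetical, not axum's API):

```rust
use std::cell::RefCell;
use std::rc::Rc;

type Log = Rc<RefCell<Vec<String>>>;
type Service = Box<dyn Fn()>;

// Wraps `inner` so that `name` is logged before and after the inner call.
fn wrap(name: &'static str, inner: Service, log: Log) -> Service {
    Box::new(move || {
        log.borrow_mut().push(format!("{name} before"));
        inner();
        log.borrow_mut().push(format!("{name} after"));
    })
}

fn call_order() -> Vec<String> {
    let log: Log = Rc::new(RefCell::new(Vec::new()));
    let handler_log = Rc::clone(&log);
    let handler: Service = Box::new(move || handler_log.borrow_mut().push("handler".into()));
    // Same order as `.layer(layer_one).layer(layer_two).layer(layer_three)`:
    // each call wraps everything added so far, so the last one is outermost.
    let svc = wrap("layer_one", handler, Rc::clone(&log));
    let svc = wrap("layer_two", svc, Rc::clone(&log));
    let svc = wrap("layer_three", svc, Rc::clone(&log));
    svc();
    let entries = log.borrow().clone();
    entries
}
```

The recorded sequence enters layer_three first and leaves it last, matching the request and response arrows in the diagram.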
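Middleware added through the layer() method of tower::ServiceBuilder runs in the order in which it was added.
- ServiceBuilder is the recommended way to combine several Layers into a single Layer: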
use tower::ServiceBuilder;
use axum::{routing::get, Router};
async fn handler() {}
let app = Router::new()
.route("/", get(handler))
.layer(
ServiceBuilder::new()
.layer(layer_one)
.layer(layer_two)
.layer(layer_three), // layer() returns a Layer
//.service(handler) // calling service() would return a Service instead
);
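A builder can give the opposite, top-to-bottom order by collecting the layers first and applying them in reverse when the service is built. The sketch below only models that idea; `MiniBuilder`, `named_layer`, and `builder_order` are hypothetical and are not tower's implementation.

```rust
use std::cell::RefCell;
use std::rc::Rc;

type Log = Rc<RefCell<Vec<&'static str>>>;
type Service = Box<dyn Fn()>;
type Layer = Box<dyn Fn(Service) -> Service>;

// A layer that records its name before calling the wrapped service.
fn named_layer(name: &'static str, log: Log) -> Layer {
    Box::new(move |inner: Service| -> Service {
        let log = Rc::clone(&log);
        Box::new(move || {
            log.borrow_mut().push(name);
            inner();
        })
    })
}

// Hypothetical builder: stores layers in the order they were added.
struct MiniBuilder {
    layers: Vec<Layer>,
}

impl MiniBuilder {
    fn new() -> Self {
        MiniBuilder { layers: Vec::new() }
    }

    fn layer(mut self, layer: Layer) -> Self {
        self.layers.push(layer);
        self
    }

    // Applies the layers last-to-first, so the first one added ends up
    // outermost and sees the request first.
    fn service(self, inner: Service) -> Service {
        self.layers.into_iter().rev().fold(inner, |svc, layer| layer(svc))
    }
}

fn builder_order() -> Vec<&'static str> {
    let log: Log = Rc::new(RefCell::new(Vec::new()));
    let handler_log = Rc::clone(&log);
    let handler: Service = Box::new(move || handler_log.borrow_mut().push("handler"));
    let svc = MiniBuilder::new()
        .layer(named_layer("layer_one", Rc::clone(&log)))
        .layer(named_layer("layer_two", Rc::clone(&log)))
        .layer(named_layer("layer_three", Rc::clone(&log)))
        .service(handler);
    svc();
    let entries = log.borrow().clone();
    entries
}
```

Reading the builder chain from top to bottom now gives the order in which a request passes through the layers.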
Layers Provided by the tower_http Crate #
- TraceLayer: tracing/logging;
- CorsLayer: CORS handling;
- CompressionLayer: automatic response compression;
- SetRequestIdLayer and PropagateRequestIdLayer: set and propagate request IDs;
- TimeoutLayer: timeout control.
use axum::{
routing::get,
Router,
};
use tower_http::validate_request::ValidateRequestHeaderLayer;
let app = Router::new().route(
"/foo",
get(|| async {})
.route_layer(ValidateRequestHeaderLayer::bearer("password"))
);
// SetRequestIdLayer / PropagateRequestIdLayer example
use http::{Request, Response, header::HeaderName};
use tower::{Service, ServiceExt, ServiceBuilder};
use tower_http::request_id::{
SetRequestIdLayer, PropagateRequestIdLayer, MakeRequestId, RequestId,
};
use http_body_util::Full;
use bytes::Bytes;
use std::sync::{Arc, atomic::{AtomicU64, Ordering}};
// A `MakeRequestId` that increments an atomic counter
#[derive(Clone, Default)]
struct MyMakeRequestId {
counter: Arc<AtomicU64>,
}
impl MakeRequestId for MyMakeRequestId {
fn make_request_id<B>(&mut self, request: &Request<B>) -> Option<RequestId> {
let request_id = self.counter
.fetch_add(1, Ordering::SeqCst)
.to_string()
.parse()
.unwrap();
Some(RequestId::new(request_id))
}
}
let x_request_id = HeaderName::from_static("x-request-id");
let mut svc = ServiceBuilder::new()
// set `x-request-id` header on all requests
.layer(SetRequestIdLayer::new(
x_request_id.clone(),
MyMakeRequestId::default(),
))
// propagate `x-request-id` headers from request to response
.layer(PropagateRequestIdLayer::new(x_request_id))
.service(handler);
let request = Request::new(Full::default());
let response = svc.ready().await?.call(request).await?;
assert_eq!(response.headers()["x-request-id"], "0");
Another example: recording the request ID in the tracing log:
use tower_http::{
ServiceBuilderExt,
trace::{TraceLayer, DefaultMakeSpan, DefaultOnResponse},
};
let svc = ServiceBuilder::new()
// make sure to set request ids before the request reaches `TraceLayer`
.set_x_request_id(MyMakeRequestId::default())
// log requests and responses
.layer(
TraceLayer::new_for_http()
.make_span_with(DefaultMakeSpan::new().include_headers(true))
.on_response(DefaultOnResponse::new().include_headers(true))
)
// propagate the header to the response before the response reaches `TraceLayer`
.propagate_x_request_id()
.service(handler);
The methods provided by tower_http::ServiceBuilderExt make these middleware Layers much simpler to create and use:
use tower_http::ServiceBuilderExt;
let mut svc = ServiceBuilder::new()
.set_x_request_id(MyMakeRequestId::default())
.propagate_x_request_id() // these helper methods add a Layer to the builder
.service(handler); // calling service() returns a Service
let request = Request::new(Full::default());
let response = svc.ready().await?.call(request).await?;
assert_eq!(response.headers()["x-request-id"], "0");
Creating a Custom Layer #
There are four ways:
- axum::middleware::from_fn/from_fn_with_state: from an async function or closure;
- axum::middleware::from_extractor: from an existing Extractor;
- tower combinators, such as:
  - ServiceBuilder::map_request
  - ServiceBuilder::map_response
  - ServiceBuilder::then
  - ServiceBuilder::and_then
- tower::Service with Pin<Box<dyn Future>>: the most flexible (and most complex) way.
Using from_fn() #
axum::middleware::from_fn()/from_fn_with_state() create a Layer from a function f that must:
- be an async function or closure;
- take zero or more FromRequestParts extractors, followed by exactly one FromRequest extractor (such as Request) as the second-to-last argument;
- take Next as the last argument;
- return a type that implements IntoResponse.
pub fn from_fn<F, T>(f: F) -> FromFnLayer<F, (), T>
// Example:
use axum::{
Router,
http,
routing::get,
response::Response,
middleware::{self, Next},
extract::Request,
};
async fn my_middleware(request: Request, next: Next) -> Response {
// do something with `request`...
let response = next.run(request).await;
// do something with `response`...
response
}
let app = Router::new()
.route("/", get(|| async { /* ... */ }))
.layer(middleware::from_fn(my_middleware));
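The shape of such a middleware function (request in, `next` runs the rest of the stack, response out) can be sketched without axum. Everything below is a simplified stand-in: `Request`, `Response`, and the synchronous `next` closure are hypothetical, whereas the real middleware is async and uses axum's `Next`.

```rust
// Toy request/response types; hypothetical stand-ins for the HTTP types.
struct Request {
    headers: Vec<(String, String)>,
}

struct Response {
    status: u16,
    headers: Vec<(String, String)>,
}

// Mirrors the shape of a `from_fn` middleware: the request comes in, `next`
// runs the rest of the stack, and the response can be adjusted on the way out.
fn my_middleware(mut request: Request, next: impl FnOnce(Request) -> Response) -> Response {
    // Do something with the request before the handler sees it.
    request.headers.push(("x-seen-by".into(), "my_middleware".into()));
    let mut response = next(request);
    // Do something with the response after the handler has produced it.
    response.headers.push(("x-handled".into(), "true".into()));
    response
}

fn handler(request: Request) -> Response {
    // Answers 200 only if the header added by the middleware arrived.
    let seen = request.headers.iter().any(|(name, _)| name == "x-seen-by");
    Response { status: if seen { 200 } else { 500 }, headers: Vec::new() }
}
```

Because the function owns the request before calling `next` and the response afterwards, one function covers both the "before" and "after" halves of a layer.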
If the middleware function needs State, create it with axum::middleware::from_fn_with_state():
pub fn from_fn_with_state<F, S, T>(state: S, f: F) -> FromFnLayer<F, S, T>
use axum::{
Router,
http::StatusCode,
routing::get,
response::{IntoResponse, Response},
middleware::{self, Next},
extract::{Request, State},
};
#[derive(Clone)]
struct AppState { /* ... */ }
async fn my_middleware(
State(state): State<AppState>,
// you can add more extractors here but the last extractor must implement `FromRequest` which `Request` does
request: Request,
next: Next,
) -> Response {
// do something with `request`...
let response = next.run(request).await;
// do something with `response`...
response
}
let state = AppState { /* ... */ };
let app = Router::new()
.route("/", get(|| async { /* ... */ }))
.route_layer(middleware::from_fn_with_state(state.clone(), my_middleware))
.with_state(state);
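Using from_extractor #
axum::middleware::from_extractor()/from_extractor_with_state() create a middleware from an Extractor type:
- it is typically used to validate requests and can reuse an existing extractor type;
- if the extractor succeeds, processing continues; otherwise its rejection is returned;
- if the extractor consumes the body, the downstream route service receives an empty body.
// E is the Extractor type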
pub fn from_extractor<E>() -> FromExtractorLayer<E, ()>
// Example: a RequireAuth extractor that validates the Authorization header:
use axum::{
extract::FromRequestParts,
middleware::from_extractor,
routing::{get, post},
Router,
http::{header, StatusCode, request::Parts},
};
use async_trait::async_trait;
// An extractor that performs authorization.
struct RequireAuth;
#[async_trait]
impl<S> FromRequestParts<S> for RequireAuth where S: Send + Sync,
{
type Rejection = StatusCode;
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
let auth_header = parts
.headers
.get(header::AUTHORIZATION)
.and_then(|value| value.to_str().ok());
match auth_header {
Some(auth_header) if token_is_valid(auth_header) => {
Ok(Self)
}
_ => Err(StatusCode::UNAUTHORIZED),
}
}
}
fn token_is_valid(token: &str) -> bool {
// ...
}
async fn handler() {}
async fn other_handler() {}
let app = Router::new()
.route("/", get(handler))
.route("/foo", post(other_handler))
// The extractor will run before all routes
.route_layer(from_extractor::<RequireAuth>());
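Using map_request/map_response #
axum::middleware provides helper functions that create Layers for simple request or response transformations; each returns a type that implements Layer:
- map_request()
- map_request_with_state()
- map_response()
- map_response_with_state()
map_request() transforms the request. Its argument is a function F that takes up to 16 extractors (T1..T16) and returns the transformed request as one of the following types (with Result, an Err value rejects the request early):
- Request
- Result<Request, E> where E: IntoResponse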
pub fn map_request<F, T>(f: F) -> MapRequestLayer<F, (), T>
pub struct MapRequestLayer<F, S, T> { /* private fields */ }
impl<S, I, F, T> Layer<I> for MapRequestLayer<F, S, T> where F: Clone, S: Clone,
type Service = MapRequest<F, S, I, T>
fn layer(&self, inner: I) -> Self::Service
pub struct MapRequest<F, S, I, T> { /* private fields */ }
// axum implements the Service trait for MapRequest<F, S, I, (T1, ...)>:
// F takes between 1 and 16 extractor arguments, T1 through T16.
impl<F, Fut, S, I, B, T1> Service<Request<B>> for MapRequest<F, S, I, (T1,)>
where
F: FnMut(T1) -> Fut + Clone + Send + 'static,
T1: FromRequest<S> + Send,
Fut: Future + Send + 'static,
Fut::Output: IntoMapRequestResult<B> + Send + 'static,
I: Service<Request<B>, Error = Infallible> + Clone + Send + 'static,
I::Response: IntoResponse,
I::Future: Send + 'static,
B: HttpBody<Data = Bytes> + Send + 'static,
B::Error: Into<BoxError>,
S: Clone + Send + Sync + 'static,
type Response = Response<Body>
type Error = Infallible
type Future = ResponseFuture
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>
fn call(&mut self, req: Request<B>) -> Self::Future
// The return value of F must implement the IntoMapRequestResult trait
pub trait IntoMapRequestResult<B>: Sealed<B> {
// Required method
fn into_map_request_result(self) -> Result<Request<B>, Response>;
}
// axum implements IntoMapRequestResult for Request<B> and Result<Request<B>, E>
impl<B> IntoMapRequestResult<B> for Request<B>
impl<B, E> IntoMapRequestResult<B> for Result<Request<B>, E> where E: IntoResponse,
// Example:
use axum::{
Router,
http::{Request, StatusCode},
routing::get,
middleware::map_request,
};
async fn set_header<B>(mut request: Request<B>) -> Request<B> {
request.headers_mut().insert("x-foo", "foo".parse().unwrap());
request
}
async fn handler<B>(request: Request<B>) {
// `request` will have an `x-foo` header
}
let app = Router::new()
.route("/", get(handler))
.layer(map_request(set_header));
// Returning a Result
async fn auth<B>(request: Request<B>) -> Result<Request<B>, StatusCode> {
let auth_header = request.headers()
.get(http::header::AUTHORIZATION)
.and_then(|header| header.to_str().ok());
match auth_header {
Some(auth_header) if token_is_valid(auth_header) => Ok(request),
_ => Err(StatusCode::UNAUTHORIZED),
}
}
fn token_is_valid(token: &str) -> bool {
// ...
}
let app = Router::new()
.route("/", get(|| async { /* ... */ }))
.route_layer(map_request(auth));
// Besides Request, the function passed to map_request() can also take other extractors
use axum::{
Router,
routing::get,
middleware::map_request,
extract::Path,
http::Request,
};
use std::collections::HashMap;
async fn log_path_params<B>(
Path(path_params): Path<HashMap<String, String>>,
request: Request<B>,
) -> Request<B> {
tracing::debug!(?path_params);
request
}
let app = Router::new()
.route("/", get(|| async { /* ... */ }))
.layer(map_request(log_path_params));
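The early-rejection behaviour of a request mapper that returns Result can be modelled in a few lines. This is a sketch under simplified assumptions: `Request`, `Response`, `auth`, and `map_request_then_call` are hypothetical, synchronous stand-ins, and the hard-coded token is for illustration only.

```rust
// Toy types; hypothetical stand-ins for the HTTP request and response.
struct Request {
    authorization: Option<String>,
}

#[derive(Debug, PartialEq)]
struct Response {
    status: u16,
}

// Request mapper: `Ok` passes the request on, `Err` rejects it early.
fn auth(request: Request) -> Result<Request, Response> {
    let valid = request.authorization.as_deref() == Some("secret-token");
    if valid { Ok(request) } else { Err(Response { status: 401 }) }
}

fn handler(_request: Request) -> Response {
    Response { status: 200 }
}

// Applies the mapper before the inner service and short-circuits on `Err`,
// so the inner service never sees a rejected request.
fn map_request_then_call(
    mapper: impl Fn(Request) -> Result<Request, Response>,
    inner: impl Fn(Request) -> Response,
    request: Request,
) -> Response {
    match mapper(request) {
        Ok(request) => inner(request),
        Err(rejection) => rejection,
    }
}
```

A mapper that returns a plain request corresponds to the always-`Ok` case of the same adapter.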
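map_response() transforms the response. Its argument is a function F that takes up to 16 extractors (T1..T16) followed by the response, and returns the transformed response, which can be any type that implements IntoResponse: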
pub fn map_response<F, T>(f: F) -> MapResponseLayer<F, (), T>
pub struct MapResponseLayer<F, S, T> { /* private fields */ }
impl<S, I, F, T> Layer<I> for MapResponseLayer<F, S, T> where F: Clone, S: Clone,
type Service = MapResponse<F, S, I, T>
fn layer(&self, inner: I) -> Self::Service
pub struct MapResponse<F, S, I, T> { /* private fields */ }
impl<F, Fut, S, I, B, ResBody> Service<Request<B>> for MapResponse<F, S, I, ()>
where
F: FnMut(Response<ResBody>) -> Fut + Clone + Send + 'static,
Fut: Future + Send + 'static,
Fut::Output: IntoResponse + Send + 'static,
I: Service<Request<B>, Response = Response<ResBody>, Error = Infallible> + Clone + Send + 'static,
I::Future: Send + 'static,
B: Send + 'static,
ResBody: Send + 'static,
S: Clone + Send + Sync + 'static,
// F can take between 1 and 16 extractor arguments (T1 through T16) before the response.
impl<F, Fut, S, I, B, ResBody, T1, T2> Service<Request<B>> for MapResponse<F, S, I, (T1, T2)>
where
F: FnMut(T1, T2, Response<ResBody>) -> Fut + Clone + Send + 'static,
T1: FromRequestParts<S> + Send,
T2: FromRequestParts<S> + Send,
Fut: Future + Send + 'static,
Fut::Output: IntoResponse + Send + 'static,
I: Service<Request<B>, Response = Response<ResBody>, Error = Infallible> + Clone + Send + 'static,
I::Future: Send + 'static,
B: Send + 'static,
ResBody: Send + 'static,
S: Clone + Send + Sync + 'static,
// Example:
use axum::{
Router,
routing::get,
middleware::map_response,
response::Response,
};
async fn set_header<B>(mut response: Response<B>) -> Response<B> {
response.headers_mut().insert("x-foo", "foo".parse().unwrap());
response
}
let app = Router::new()
.route("/", get(|| async { /* ... */ }))
.layer(map_response(set_header));
// As with map_request(), extractor arguments can be used too
use axum::{
Router,
routing::get,
middleware::map_response,
extract::Path,
response::Response,
};
use std::collections::HashMap;
async fn log_path_params<B>(
Path(path_params): Path<HashMap<String, String>>,
response: Response<B>,
) -> Response<B> {
tracing::debug!(?path_params);
response
}
let app = Router::new()
.route("/", get(|| async { /* ... */ }))
.layer(map_response(log_path_params));
// Any type that implements the IntoResponse trait can be returned
use axum::{
Router,
routing::get,
middleware::map_response,
response::{Response, IntoResponse},
};
use std::collections::HashMap;
async fn set_header(response: Response) -> impl IntoResponse {
(
[("x-foo", "foo")],
response,
)
}
let app = Router::new()
.route("/", get(|| async { /* ... */ }))
.layer(map_response(set_header));
Using tower::Service and Pin<Box<dyn Future>> #
Implementing a Layer and its associated Service by hand, with tower::Service and Pin<Box<dyn Future>>, gives the most flexibility:
use axum::{
response::Response,
body::Body,
extract::Request,
};
use futures_util::future::BoxFuture;
use tower::{Service, Layer};
use std::task::{Context, Poll};
// A Layer is usually defined together with the Service it produces
#[derive(Clone)]
struct MyLayer;
impl<S> Layer<S> for MyLayer {
type Service = MyMiddleware<S>; // the Service type this Layer produces
fn layer(&self, inner: S) -> Self::Service {
MyMiddleware { inner }
}
}
#[derive(Clone)]
struct MyMiddleware<S> {
inner: S,
}
// To satisfy Router::layer(), the Service's request type must be http::request::Request and its response type must be http::response::Response.
impl<S> Service<Request> for MyMiddleware<S>
where
S: Service<Request, Response = Response> + Send + 'static,
S::Future: Send + 'static,
{
type Response = S::Response;
type Error = S::Error;
// `BoxFuture` is a type alias for `Pin<Box<dyn Future + Send + 'a>>`
type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}
fn call(&mut self, request: Request) -> Self::Future {
let future = self.inner.call(request);
Box::pin(async move {
let response: Response = future.await?;
Ok(response)
})
}
}
error_handling #
For Result<T, E>, when both T and E implement IntoResponse, the Result implements IntoResponse as well:
// Result implements IntoResponse when both the Ok and the Err type implement it
impl<T, E> IntoResponse for Result<T, E> where T: IntoResponse, E: IntoResponse
// axum::response::ErrorResponse implements From<T> for any T: IntoResponse, so any such value converts into it
impl<T> IntoResponse for Result<T, ErrorResponse> where T: IntoResponse
impl<T, U> IntoResponse for Chain<T, U> where
T: Buf + Unpin + Send + 'static,
U: Buf + Unpin + Send + 'static,
For example, axum::http::StatusCode implements the IntoResponse trait, so it can be used as the error type of the Result a handler returns:
use axum::http::StatusCode;
async fn handler() -> Result<String, StatusCode> {
// ...
}
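The blanket rule "a Result converts when both sides convert" is easy to model with a local trait. The sketch below is not axum's trait: `IntoResponse`, `Response`, `NotFound`, and `find_user` are defined here purely for illustration.

```rust
// Toy model of the idea behind `IntoResponse`; all names are hypothetical.
#[derive(Debug, PartialEq)]
struct Response {
    status: u16,
    body: String,
}

trait IntoResponse {
    fn into_response(self) -> Response;
}

impl IntoResponse for String {
    fn into_response(self) -> Response {
        Response { status: 200, body: self }
    }
}

struct NotFound;

impl IntoResponse for NotFound {
    fn into_response(self) -> Response {
        Response { status: 404, body: String::new() }
    }
}

// Both variants convert, so the whole `Result` converts.
impl<T: IntoResponse, E: IntoResponse> IntoResponse for Result<T, E> {
    fn into_response(self) -> Response {
        match self {
            Ok(value) => value.into_response(),
            Err(error) => error.into_response(),
        }
    }
}

// A handler-like function that may fail.
fn find_user(id: u32) -> Result<String, NotFound> {
    if id == 1 { Ok("alice".to_string()) } else { Err(NotFound) }
}
```

The caller never inspects the Result itself; it calls `into_response()` once and receives a response in both cases.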
To return an error message, define a custom error type and implement IntoResponse for it:
// https://github.com/tokio-rs/axum/blob/main/examples/anyhow-error-response/src/main.rs
use axum::{
http::StatusCode,
response::{IntoResponse, Response},
routing::get,
Router,
};
#[tokio::main]
async fn main() {
let app = Router::new().route("/", get(handler));
let listener = tokio::net::TcpListener::bind("127.0.0.1:3000") .await .unwrap();
println!("listening on {}", listener.local_addr().unwrap());
axum::serve(listener, app).await.unwrap();
}
// The Result returned by the Handler is converted into a Response through IntoResponse and sent to the client
async fn handler() -> Result<(), AppError> {
try_thing()?;
Ok(())
}
fn try_thing() -> Result<(), anyhow::Error> {
anyhow::bail!("it failed!")
}
// Make our own error that wraps `anyhow::Error`.
struct AppError(anyhow::Error);
// Tell axum how to convert `AppError` into a response.
impl IntoResponse for AppError {
fn into_response(self) -> Response {
// axum implements IntoResponse for (StatusCode, String)
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Something went wrong: {}", self.0),
)
.into_response()
}
}
// This enables using `?` on functions that return `Result<_, anyhow::Error>` to turn them into
// `Result<_, AppError>`. That way you don't need to do that manually.
impl<E> From<E> for AppError
where
E: Into<anyhow::Error>,
{
fn from(err: E) -> Self {
Self(err.into())
}
}
When axum uses a Service, it requires the Service's associated Error type to be Infallible, meaning the Service can never return an error.
For example, Router::route_service() uses a Service as the request handler, with the bound Service<Request, Error = Infallible> + Clone + Send + 'static:
// Bounds on the service argument of Router::route_service():
pub fn route_service<T>(self, path: &str, service: T) -> Self
where
// Error is Infallible: no error can be returned
T: Service<Request, Error = Infallible> + Clone + Send + 'static,
// Response must implement IntoResponse
T::Response: IntoResponse,
T::Future: Send + 'static,
// Example:
use axum::{
Router,
body::Body,
routing::{any_service, get_service},
extract::Request,
http::StatusCode,
error_handling::HandleErrorLayer,
};
use tower_http::services::ServeFile;
use http::Response;
use std::{convert::Infallible, io};
use tower::service_fn;
let app = Router::new()
.route(
"/",
// service_fn() creates a Service from a closure that takes a Request and returns a Result.
// To satisfy any_service, the error type of that Result must be Infallible.
any_service(service_fn(|_: Request| async {
let res = Response::new(Body::from("Hi from `GET /`"));
Ok::<_, Infallible>(res)
}))
)
.route_service(
"/foo",
service_fn(|req: Request| async move {
let body = Body::from(format!("Hi from `{} /foo`", req.method()));
let res = Response::new(body);
Ok::<_, Infallible>(res)
})
)
.route_service(
"/static/Cargo.toml",
ServeFile::new("Cargo.toml"),
);
A Service created with tower::service_fn(f), however, only requires f to be FnMut(Request) -> Future<Output = Result<R, E>>, where E need not be Infallible, so the types may not match.
To use such a fallible Service with Router::route_service(), axum provides axum::error_handling::HandleError, which converts its errors into responses.
HandleError converts a Service that can fail into a Service<Request, Error = Infallible>; it takes a closure that maps the Err value to a type implementing IntoResponse:
use axum::{
Router,
body::Body,
http::{Request, Response, StatusCode},
error_handling::HandleError,
};
async fn thing_that_might_fail() -> Result<(), anyhow::Error> {
// ...
}
// service_fn() turns the async closure into a ServiceFn that implements Service. ServiceFn only requires the
// closure to be FnMut(Request) -> Future<Output = Result<R, E>>, so the closure may return a Result with a real error type
let some_fallible_service = tower::service_fn(|_req| async {
thing_that_might_fail().await?;
Ok::<_, anyhow::Error>(Response::new(Body::empty()))
});
// route_service() requires Service<Request, Error = Infallible> + Clone + Send + 'static.
// That Error = Infallible does not match the error type of the Service built by tower::service_fn(),
// so HandleError::new() is needed to convert it.
let app = Router::new().route_service(
"/",
HandleError::new(some_fallible_service, handle_anyhow_error),
);
// Converts err into a type that implements IntoResponse
async fn handle_anyhow_error(err: anyhow::Error) -> (StatusCode, String) {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Something went wrong: {err}"),
)
}
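What HandleError achieves can be reduced to one adapter function: map every error to a response so that the resulting error type is `Infallible`. The following is a simplified, synchronous sketch; `fallible_service`, `handle_error`, and `handle_errors` are hypothetical and do not reproduce axum's types.

```rust
use std::convert::Infallible;

#[derive(Debug, PartialEq)]
struct Response {
    status: u16,
    body: String,
}

// A fallible "service": it may return a domain error instead of a response.
fn fallible_service(fail: bool) -> Result<Response, String> {
    if fail {
        Err("it failed!".to_string())
    } else {
        Ok(Response { status: 200, body: String::new() })
    }
}

// Turns the error into a response, like the function given to `HandleError::new`.
fn handle_error(err: String) -> Response {
    Response { status: 500, body: format!("Something went wrong: {err}") }
}

// Adapter: the error type becomes `Infallible`, so callers can never observe `Err`.
fn handle_errors<E>(
    result: Result<Response, E>,
    f: impl FnOnce(E) -> Response,
) -> Result<Response, Infallible> {
    Ok(result.unwrap_or_else(f))
}
```

After the adapter, unwrapping the result cannot panic, because a value of type `Infallible` can never be constructed.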
HandleError<S, F, T> implements Service. The function F takes a variable number of arguments: the last one must be the error type, and every argument before it must implement FromRequestParts<()> + Send, so F can also take extractors:
pub struct HandleError<S, F, T> { /* private fields */ }
impl<S, F, T> HandleError<S, F, T>
pub fn new(inner: S, f: F) -> Self
impl<S, F, B, Fut, Res> Service<Request<B>> for HandleError<S, F, ()>
where
S: Service<Request<B>> + Clone + Send + 'static,
S::Response: IntoResponse + Send,
S::Error: Send,
S::Future: Send,
F: FnOnce(S::Error) -> Fut + Clone + Send + 'static,
Fut: Future<Output = Res> + Send,
Res: IntoResponse,
B: Send + 'static,
// F can take a variable number of arguments; this impl covers one extractor T1 before the error
impl<S, F, B, Res, Fut, T1> Service<Request<B>> for HandleError<S, F, (T1,)>
where
S: Service<Request<B>> + Clone + Send + 'static,
S::Response: IntoResponse + Send,
S::Error: Send,
S::Future: Send,
F: FnOnce(T1, S::Error) -> Fut + Clone + Send + 'static,
Fut: Future<Output = Res> + Send,
Res: IntoResponse,
T1: FromRequestParts<()> + Send,
B: Send + 'static
axum::error_handling::HandleErrorLayer does the same conversion for errors returned by middleware.
HandleErrorLayer implements Layer<S>. Internally it uses HandleError to convert the wrapped fallible Service into the Service<Request, Error = Infallible> that axum requires:
- new(f) takes a function FnOnce($($ty),*, S::Error) -> Fut + Clone + Send + 'static, whose arguments are zero or more extractors followed by the error, and which returns a type implementing IntoResponse.
pub struct HandleErrorLayer<F, T> { /* private fields */ }
impl<F, T> HandleErrorLayer<F, T>
// f takes the error as input and returns a type implementing IntoResponse
pub fn new(f: F) -> Self
impl<S, F, T> Layer<S> for HandleErrorLayer<F, T>
where
F: Clone,
{
type Service = HandleError<S, F, T>;
fn layer(&self, inner: S) -> Self::Service {
// self.f is passed to HandleError::new(), so it must satisfy that function's bounds on F
HandleError::new(inner, self.f.clone())
}
}
impl<S, F, B, Fut, Res> Service<Request<B>> for HandleError<S, F, ()>
where
S: Service<Request<B>> + Clone + Send + 'static,
S::Response: IntoResponse + Send,
S::Error: Send,
S::Future: Send,
F: FnOnce(S::Error) -> Fut + Clone + Send + 'static,
Fut: Future<Output = Res> + Send,
Res: IntoResponse,
B: Send + 'static
// Example:
use axum::{
Router,
BoxError,
routing::get,
http::StatusCode,
error_handling::HandleErrorLayer,
};
use std::time::Duration;
use tower::ServiceBuilder;
let app = Router::new()
.route("/", get(|| async {}))
.layer(
ServiceBuilder::new()
// `timeout` will produce an error if the handler takes too long, so we must handle those errors.
// The function passed to new() takes the error and returns a type implementing IntoResponse.
.layer(HandleErrorLayer::new(handle_timeout_error)) // converts the Err into a response
.timeout(Duration::from_secs(30)) // a layer that may return Err
//.layer(TimeoutLayer::new(Duration::from_secs(10)))
);
// The function takes the error as input and returns a type implementing IntoResponse
async fn handle_timeout_error(err: BoxError) -> (StatusCode, String) {
if err.is::<tower::timeout::error::Elapsed>() {
(
StatusCode::REQUEST_TIMEOUT,
"Request took too long".to_string(),
)
} else {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Unhandled internal error: {err}"),
)
}
}
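The `err.is::<T>()` check used by the handler above is a standard-library feature of boxed errors. A minimal sketch: the `Elapsed` type below is a hypothetical stand-in for the timeout error, and the local `BoxError` alias only mirrors the shape of the alias axum exports.

```rust
use std::error::Error;
use std::fmt;

// Local alias for illustration; it mirrors the shape of axum's `BoxError`.
type BoxError = Box<dyn Error + Send + Sync>;

// Hypothetical error type standing in for a timeout error.
#[derive(Debug)]
struct Elapsed;

impl fmt::Display for Elapsed {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "request timed out")
    }
}

impl Error for Elapsed {}

// Picks a status code by checking the concrete type behind the boxed error.
fn status_for(err: &BoxError) -> u16 {
    if err.is::<Elapsed>() { 408 } else { 500 }
}
```

`is::<T>()` only tests the concrete type; `downcast_ref::<T>()` can be used instead when the handler also needs the error's contents.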
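validator and garde Crates #
Note: validator is reportedly no longer actively maintained; garde (https://github.com/jprochazk/garde) is the suggested replacement.
The validator crate provides declarative validation for structs: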
#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize, Validate, JsonSchema)]
#[serde(rename_all = "camelCase")]
pub struct ApplicationIn {
#[validate(
length(min = 1, message = "Application names must be at least one character"),
custom = "validate_no_control_characters"
)]
#[schemars(example = "application_name_example")]
pub name: String,
#[validate(range(min = 1, message = "Application rate limits must be at least 1 if set"))]
#[serde(skip_serializing_if = "Option::is_none")]
pub rate_limit: Option<u16>,
/// Optional unique identifier for the application
#[validate]
#[serde(skip_serializing_if = "Option::is_none")]
pub uid: Option<ApplicationUid>,
#[serde(default)]
pub metadata: Metadata,
}
openapi #
Related projects:
- oasgen: https://crates.io/crates/oasgen
- utoipa and utoipa-swagger-ui: crates that provide OpenAPI support for axum.
- aide: https://docs.rs/aide/latest/aide/
  - aide example: https://github.com/svix/svix-webhooks/blob/main/server/svix-server/src/v1/endpoints/admin.rs
- poem_openapi: https://docs.rs/poem-openapi/latest/poem_openapi/index.html
utoipa example: https://github.com/tokio-rs/axum/issues/50#issuecomment-2592149658
#[utoipa::path(
get,
path = "/",
responses(
(status = 200, description = "Send a salute from Axum")
)
)]
pub async fn hello_axum() -> impl IntoResponse {
(StatusCode::OK, "Hello, Axum")
}
// add use utoipa::OpenApi;
#[derive(OpenApi)]
#[openapi(paths(hello_axum))]
pub struct ApiDoc;
// add use utoipa_swagger_ui::SwaggerUi;
let app = Router::new()
.route("/", get(hello_axum))
.merge(SwaggerUi::new("/swagger-ui").url("/api-doc/openapi.json", ApiDoc::openapi()));
axum tracing log #
In the project's Cargo.toml, enable the tracing feature of axum and the trace feature of tower-http; a tracing middleware can then be added to the axum Router:
[dependencies]
axum = { path = "../../axum", features = ["tracing"] }
tokio = { version = "1.0", features = ["full"] }
tower-http = { version = "0.5.0", features = ["trace"] }
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
Example:
// https://github.com/tokio-rs/axum/blob/main/examples/tracing-aka-logging/src/main.rs
use axum::{
body::Bytes,
extract::MatchedPath,
http::{HeaderMap, Request},
response::{Html, Response},
routing::get,
Router,
};
use std::time::Duration;
use tokio::net::TcpListener;
use tower_http::{classify::ServerErrorsFailureClass, trace::TraceLayer};
use tracing::{info_span, Span};
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
#[tokio::main]
async fn main() {
// Install the global tracing subscriber
tracing_subscriber::registry()
.with(
// Default filter levels for each crate, used when the RUST_LOG environment variable is not set at runtime
tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| {
// axum logs rejections from built-in extractors with the `axum::rejection`
// target, at `TRACE` level. `axum::rejection=trace` enables showing those events
"example_tracing_aka_logging=debug,tower_http=debug,axum::rejection=trace".into()
}),
)
.with(tracing_subscriber::fmt::layer())
.init();
let app = Router::new()
.route("/", get(handler))
.layer(
// Key point: HTTP logs are emitted only after TraceLayer has been added.
TraceLayer::new_for_http()
.make_span_with(|request: &Request<_>| {
let matched_path = request
.extensions()
.get::<MatchedPath>()
.map(MatchedPath::as_str);
info_span!(
"http_request",
method = ?request.method(),
matched_path,
some_other_field = tracing::field::Empty,
)
})
.on_request(|_request: &Request<_>, _span: &Span| {
// You can use `_span.record("some_other_field", value)` in one of these
// closures to attach a value to the initially empty field in the info_span
// created above.
})
.on_response(|_response: &Response, _latency: Duration, _span: &Span| {
// ...
})
.on_body_chunk(|_chunk: &Bytes, _latency: Duration, _span: &Span| {
// ...
})
.on_eos(
|_trailers: Option<&HeaderMap>, _stream_duration: Duration, _span: &Span| {
// ...
},
)
.on_failure(
|_error: ServerErrorsFailureClass, _latency: Duration, _span: &Span| {
// ...
},
),
);
let listener = TcpListener::bind("127.0.0.1:3000").await.unwrap();
tracing::debug!("listening on {}", listener.local_addr().unwrap());
axum::serve(listener, app).await.unwrap();
}
async fn handler() -> Html<&'static str> {
Html("<h1>Hello, World!</h1>")
}
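The fallback filter string in the example is a comma-separated list of `target=level` directives. As a rough illustration of that structure, the std-only sketch below splits such a string into its parts. `Directive` and `parse_directives` are hypothetical names introduced here, not part of tracing-subscriber, and the real `EnvFilter` grammar is richer (it also accepts span and field filters, and a bare entry may name a target rather than a level).

```rust
/// One entry of a filter string, in a simplified model.
#[derive(Debug, PartialEq)]
struct Directive {
    /// Crate or module path; `None` means the level applies globally.
    target: Option<String>,
    /// Level name such as "debug" or "trace", lowercased.
    level: String,
}

/// Splits a comma-separated filter string into directives.
/// Assumption: only `target=level` and bare `level` entries are handled.
fn parse_directives(filter: &str) -> Vec<Directive> {
    filter
        .split(',')
        .map(str::trim)
        .filter(|part| !part.is_empty())
        .map(|part| match part.split_once('=') {
            // "tower_http=debug" becomes target "tower_http", level "debug"
            Some((target, level)) => Directive {
                target: Some(target.trim().to_string()),
                level: level.trim().to_lowercase(),
            },
            // A bare entry such as "info" is treated as a global level here
            None => Directive {
                target: None,
                level: part.to_lowercase(),
            },
        })
        .collect()
}
```

Calling `parse_directives` on the example's fallback string yields three directives, the last one targeting `axum::rejection` at level `trace`. Setting `RUST_LOG` at runtime replaces the fallback string entirely rather than merging with it, because the fallback is only used when `try_from_default_env()` returns an error.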
JWT Authentication #
// https://docs.shuttle.rs/examples/axum-jwt-authentication
use axum::{
async_trait,
extract::FromRequestParts,
http::{request::Parts, StatusCode},
response::{IntoResponse, Response},
routing::{get, post},
Json, RequestPartsExt, Router,
};
use axum_extra::{
headers::{authorization::Bearer, Authorization},
TypedHeader,
};
use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header, Validation};
use once_cell::sync::Lazy;
use serde::{Deserialize, Serialize};
use serde_json::json;
use std::fmt::Display;
use std::time::SystemTime;
static KEYS: Lazy<Keys> = Lazy::new(|| {
// note that in production, you will probably want to use a random SHA-256 hash or similar
let secret = "JWT_SECRET".to_string();
Keys::new(secret.as_bytes())
});
#[shuttle_runtime::main]
async fn main() -> shuttle_axum::ShuttleAxum {
let app = Router::new()
.route("/public", get(public))
.route("/private", get(private))
.route("/login", post(login));
Ok(app.into())
}
async fn public() -> &'static str {
// A public endpoint that anyone can access
"Welcome to the public area :)"
}
// Claims implements FromRequestParts, so it can be used as an extractor that pulls the authentication info out of the request
async fn private(claims: Claims) -> Result<String, AuthError> {
// Send the protected data to the user
Ok(format!(
"Welcome to the protected area :)\nYour data:\n{claims}",
))
}
async fn login(Json(payload): Json<AuthPayload>) -> Result<Json<AuthBody>, AuthError> {
// Check if the user sent the credentials
if payload.client_id.is_empty() || payload.client_secret.is_empty() {
return Err(AuthError::MissingCredentials);
}
// Here you can check the user credentials from a database
if payload.client_id != "foo" || payload.client_secret != "bar" {
return Err(AuthError::WrongCredentials);
}
// add 5 minutes to current unix epoch time as expiry date/time
let exp = SystemTime::now()
.duration_since(SystemTime::UNIX_EPOCH)
.unwrap()
.as_secs()
+ 300;
let claims = Claims {
// Placeholder address: the original value was lost to email obfuscation on the source page
sub: "user@example.com".to_owned(),
company: "ACME".to_owned(),
// Mandatory expiry time as UTC timestamp - takes unix epoch
exp: usize::try_from(exp).unwrap(),
};
// Create the authorization token
let token = encode(&Header::default(), &claims, &KEYS.encoding)
.map_err(|_| AuthError::TokenCreation)?;
// Send the authorized token
Ok(Json(AuthBody::new(token)))
}
// allow us to print the claim details for the private route
impl Display for Claims {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "Email: {}\nCompany: {}", self.sub, self.company)
}
}
// implement a method to create a response type containing the JWT
impl AuthBody {
fn new(access_token: String) -> Self {
Self {
access_token,
token_type: "Bearer".to_string(),
}
}
}
// implement FromRequestParts for Claims (the JWT struct)
// FromRequestParts allows us to use Claims without consuming the request
#[async_trait]
impl<S> FromRequestParts<S> for Claims
where
S: Send + Sync,
{
type Rejection = AuthError;
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
// Extract the token from the authorization header
let TypedHeader(Authorization(bearer)) = parts
.extract::<TypedHeader<Authorization<Bearer>>>()
.await
.map_err(|_| AuthError::InvalidToken)?;
// Decode the user data
let token_data = decode::<Claims>(bearer.token(), &KEYS.decoding, &Validation::default())
.map_err(|_| AuthError::InvalidToken)?;
Ok(token_data.claims)
}
}
// implement IntoResponse for AuthError so we can use it as an Axum response type
impl IntoResponse for AuthError {
fn into_response(self) -> Response {
let (status, error_message) = match self {
AuthError::WrongCredentials => (StatusCode::UNAUTHORIZED, "Wrong credentials"),
AuthError::MissingCredentials => (StatusCode::BAD_REQUEST, "Missing credentials"),
AuthError::TokenCreation => (StatusCode::INTERNAL_SERVER_ERROR, "Token creation error"),
AuthError::InvalidToken => (StatusCode::BAD_REQUEST, "Invalid token"),
};
let body = Json(json!({
"error": error_message,
}));
(status, body).into_response()
}
}
// encoding/decoding keys - set in the static `once_cell` above
struct Keys {
encoding: EncodingKey,
decoding: DecodingKey,
}
impl Keys {
fn new(secret: &[u8]) -> Self {
Self {
encoding: EncodingKey::from_secret(secret),
decoding: DecodingKey::from_secret(secret),
}
}
}
// the JWT claim
#[derive(Debug, Serialize, Deserialize)]
struct Claims {
sub: String,
company: String,
exp: usize,
}
// the response that we pass back to HTTP client once successfully authorised
#[derive(Debug, Serialize)]
struct AuthBody {
access_token: String,
token_type: String,
}
// the request type - "client_id" is analogous to a username, client_secret can also be interpreted as a password
#[derive(Debug, Deserialize)]
struct AuthPayload {
client_id: String,
client_secret: String,
}
// error types for auth errors
#[derive(Debug)]
enum AuthError {
WrongCredentials,
MissingCredentials,
TokenCreation,
InvalidToken,
}
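The `Claims` extractor above relies on two checks that the libraries perform internally: reading the token out of an `Authorization: Bearer <token>` header, and rejecting a token whose `exp` timestamp has passed. The std-only sketch below models both steps to make the flow easier to follow. All function names are hypothetical, the header parsing is a simplification of what `TypedHeader<Authorization<Bearer>>` does, and the expiry rule (including the leeway parameter) is an assumption about how a validator typically compares `exp` with the current time, not jsonwebtoken's exact implementation.

```rust
/// Returns the token from an `Authorization` header value of the form
/// `Bearer <token>`, or `None` for any other scheme or an empty token.
/// Assumption: the scheme is matched case-sensitively with a single space.
fn bearer_token(header_value: &str) -> Option<&str> {
    let token = header_value.strip_prefix("Bearer ")?.trim();
    if token.is_empty() {
        None
    } else {
        Some(token)
    }
}

/// Expiry timestamp in seconds since the Unix epoch, `ttl_secs` after `now_secs`.
/// The login handler above uses a TTL of 300 seconds (5 minutes).
fn expiry_from(now_secs: u64, ttl_secs: u64) -> u64 {
    now_secs.saturating_add(ttl_secs)
}

/// True when `now_secs` is past `exp`, tolerating `leeway_secs` of clock skew.
fn is_expired(exp: u64, now_secs: u64, leeway_secs: u64) -> bool {
    now_secs > exp.saturating_add(leeway_secs)
}
```

In terms of the example, a client would first POST its credentials to `/login`, then send the returned `access_token` to `/private` in an `Authorization: Bearer ...` header. A missing header, a non-Bearer scheme, or an expired token all surface as `AuthError::InvalidToken`.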
References #
- Links to axum ecosystem projects: https://github.com/tokio-rs/axum/blob/main/ECOSYSTEM.md
- Writing a Rest HTTP Service with Axum: https://docs.shuttle.rs/templates/tutorials/rest-http-service-with-axum
- axum project template: a clear axum + sqlx + utoipa::OpenApi docs example
- https://kerkour.com/rust-web-services-axum-sqlx-postgresql
  - axum project template