Middleware as tower layer

This commit is contained in:
imbolc 2024-06-15 17:09:52 +06:00
parent 57e3e067b1
commit 386662b0a8
3 changed files with 84 additions and 44 deletions

View file

@ -15,18 +15,13 @@ default = []
unstable = [] unstable = []
guards = ["tower", "futures-core", "pin-project-lite"] guards = ["tower", "futures-core", "pin-project-lite"]
serde = ["dep:serde", "dep:serde_json"] serde = ["dep:serde", "dep:serde_json"]
auto-vary = ["axum", "futures", "tokio"] auto-vary = ["futures", "tokio", "tower"]
[dependencies] [dependencies]
axum-core = "0.4" axum-core = "0.4"
http = { version = "1.0", default-features = false } http = { version = "1.0", default-features = false }
async-trait = "0.1" async-trait = "0.1"
# Optional dependencies required for the `auto-vary` feature.
axum = { version = "0.7", default-features = false, optional = true }
tokio = { version = "1", features = ["sync"], optional = true }
futures = { version = "0.3", default-features = false, optional = true }
# Optional dependencies required for the `guards` feature. # Optional dependencies required for the `guards` feature.
tower = { version = "0.4", default-features = false, optional = true } tower = { version = "0.4", default-features = false, optional = true }
futures-core = { version = "0.3", optional = true } futures-core = { version = "0.3", optional = true }
@ -36,6 +31,10 @@ pin-project-lite = { version = "0.2", optional = true }
serde = { version = "1", features = ["derive"], optional = true } serde = { version = "1", features = ["derive"], optional = true }
serde_json = { version = "1", optional = true } serde_json = { version = "1", optional = true }
# Optional dependencies required for the `auto-vary` feature.
tokio = { version = "1", features = ["sync"], optional = true }
futures = { version = "0.3", default-features = false, optional = true }
[dev-dependencies] [dev-dependencies]
axum = { version = "0.7", default-features = false } axum = { version = "0.7", default-features = false }
axum-test = "15" axum-test = "15"

View file

@ -1,13 +1,19 @@
use std::sync::Arc; use std::{
sync::Arc,
task::{Context, Poll},
};
use axum::{extract::Request, middleware::Next, response::Response}; use axum_core::{
use axum_core::response::IntoResponse; extract::Request,
use futures::future::join_all; response::{IntoResponse, Response},
};
use futures::future::{join_all, BoxFuture};
use http::{ use http::{
header::{HeaderValue, VARY}, header::{HeaderValue, VARY},
Extensions, Extensions,
}; };
use tokio::sync::oneshot::{self, Receiver, Sender}; use tokio::sync::oneshot::{self, Receiver, Sender};
use tower::{Layer, Service};
use crate::{ use crate::{
headers::{HX_REQUEST_STR, HX_TARGET_STR, HX_TRIGGER_NAME_STR, HX_TRIGGER_STR}, headers::{HX_REQUEST_STR, HX_TARGET_STR, HX_TRIGGER_NAME_STR, HX_TRIGGER_STR},
@ -59,41 +65,72 @@ define_notifiers!(
HxTriggerNameExtracted HxTriggerNameExtracted
); );
pub async fn middleware(mut request: Request, next: Next) -> Response { #[derive(Default, Clone)]
let exts = request.extensions_mut(); pub struct AutoVaryLayer;
let rx_header = [
(HxRequestExtracted::insert(exts), HX_REQUEST_STR),
(HxTargetExtracted::insert(exts), HX_TARGET_STR),
(HxTriggerExtracted::insert(exts), HX_TRIGGER_STR),
(HxTriggerNameExtracted::insert(exts), HX_TRIGGER_NAME_STR),
];
let mut response = next.run(request).await; impl<S> Layer<S> for AutoVaryLayer {
type Service = AutoVaryMiddleware<S>;
let used_headers: Vec<_> = join_all( fn layer(&self, inner: S) -> Self::Service {
rx_header AutoVaryMiddleware { inner }
}
}
#[derive(Clone)]
pub struct AutoVaryMiddleware<S> {
inner: S,
}
impl<S> Service<Request> for AutoVaryMiddleware<S>
where
S: Service<Request, Response = Response> + Send + 'static,
S::Future: Send + 'static,
{
type Response = S::Response;
type Error = S::Error;
type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}
fn call(&mut self, mut request: Request) -> Self::Future {
let exts = request.extensions_mut();
let rx_header = [
(HxRequestExtracted::insert(exts), HX_REQUEST_STR),
(HxTargetExtracted::insert(exts), HX_TARGET_STR),
(HxTriggerExtracted::insert(exts), HX_TRIGGER_STR),
(HxTriggerNameExtracted::insert(exts), HX_TRIGGER_NAME_STR),
];
let future = self.inner.call(request);
Box::pin(async move {
let mut response: Response = future.await?;
let used_headers: Vec<_> = join_all(
rx_header
.into_iter()
.map(|(rx, header)| async move { rx.await.ok().map(|_| header) }),
)
.await
.into_iter() .into_iter()
.map(|(rx, header)| async move { rx.await.ok().map(|_| header) }), .flatten()
) .collect();
.await
.into_iter()
.flatten()
.collect();
if used_headers.is_empty() { if used_headers.is_empty() {
return response; return Ok(response);
}
let value = match HeaderValue::from_str(&used_headers.join(", ")) {
Ok(x) => x,
Err(e) => return Ok(HxError::from(e).into_response()),
};
if let Err(e) = response.headers_mut().try_append(VARY, value) {
return Ok(HxError::from(e).into_response());
}
Ok(response)
})
} }
let value = match HeaderValue::from_str(&used_headers.join(", ")) {
Ok(x) => x,
Err(e) => return HxError::from(e).into_response(),
};
if let Err(e) = response.headers_mut().try_append(VARY, value) {
return HxError::from(e).into_response();
}
response
} }
#[cfg(test)] #[cfg(test)]
@ -122,7 +159,7 @@ mod tests {
"/multiple-extractors", "/multiple-extractors",
get(|_: HxRequest, _: HxTarget, _: HxTrigger, _: HxTriggerName| async { () }), get(|_: HxRequest, _: HxTarget, _: HxTrigger, _: HxTriggerName| async { () }),
) )
.layer(axum::middleware::from_fn(middleware)); .layer(AutoVaryLayer);
axum_test::TestServer::new(app).unwrap() axum_test::TestServer::new(app).unwrap()
} }

View file

@ -5,6 +5,9 @@
mod error; mod error;
pub use error::*; pub use error::*;
#[cfg(feature = "auto-vary")]
#[cfg_attr(feature = "unstable", doc(cfg(feature = "auto-vary")))]
mod auto_vary;
pub mod extractors; pub mod extractors;
#[cfg(feature = "guards")] #[cfg(feature = "guards")]
#[cfg_attr(feature = "unstable", doc(cfg(feature = "guards")))] #[cfg_attr(feature = "unstable", doc(cfg(feature = "guards")))]
@ -12,6 +15,10 @@ pub mod guard;
pub mod headers; pub mod headers;
pub mod responders; pub mod responders;
#[cfg(feature = "auto-vary")]
#[cfg_attr(feature = "unstable", doc(cfg(feature = "auto-vary")))]
#[doc(inline)]
pub use auto_vary::AutoVaryLayer;
#[doc(inline)] #[doc(inline)]
pub use extractors::*; pub use extractors::*;
#[cfg(feature = "guards")] #[cfg(feature = "guards")]
@ -22,6 +29,3 @@ pub use guard::*;
pub use headers::*; pub use headers::*;
#[doc(inline)] #[doc(inline)]
pub use responders::*; pub use responders::*;
#[cfg(feature = "auto-vary")]
pub mod auto_vary;