diff --git a/Cargo.toml b/Cargo.toml index c2c3c59..01fb0eb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,18 +15,13 @@ default = [] unstable = [] guards = ["tower", "futures-core", "pin-project-lite"] serde = ["dep:serde", "dep:serde_json"] -auto-vary = ["axum", "futures", "tokio"] +auto-vary = ["futures", "tokio", "tower"] [dependencies] axum-core = "0.4" http = { version = "1.0", default-features = false } async-trait = "0.1" -# Optional dependencies required for the `auto-vary` feature. -axum = { version = "0.7", default-features = false, optional = true } -tokio = { version = "1", features = ["sync"], optional = true } -futures = { version = "0.3", default-features = false, optional = true } - # Optional dependencies required for the `guards` feature. tower = { version = "0.4", default-features = false, optional = true } futures-core = { version = "0.3", optional = true } @@ -36,6 +31,10 @@ pin-project-lite = { version = "0.2", optional = true } serde = { version = "1", features = ["derive"], optional = true } serde_json = { version = "1", optional = true } +# Optional dependencies required for the `auto-vary` feature. +tokio = { version = "1", features = ["sync"], optional = true } +futures = { version = "0.3", default-features = false, optional = true } + [dev-dependencies] axum = { version = "0.7", default-features = false } axum-test = "15" diff --git a/src/auto_vary.rs b/src/auto_vary.rs index fd3b609..c48bc36 100644 --- a/src/auto_vary.rs +++ b/src/auto_vary.rs @@ -1,13 +1,19 @@ -use std::sync::Arc; +use std::{ + sync::Arc, + task::{Context, Poll}, +}; -use axum::{extract::Request, middleware::Next, response::Response}; -use axum_core::response::IntoResponse; -use futures::future::join_all; +use axum_core::{ + extract::Request, + response::{IntoResponse, Response}, +}; +use futures::future::{join_all, BoxFuture}; use http::{ header::{HeaderValue, VARY}, Extensions, }; use tokio::sync::oneshot::{self, Receiver, Sender}; +use tower::{Layer, Service}; use crate::{ headers::{HX_REQUEST_STR, HX_TARGET_STR, HX_TRIGGER_NAME_STR, HX_TRIGGER_STR}, @@ -59,41 +65,72 @@ define_notifiers!( HxTriggerNameExtracted ); -pub async fn middleware(mut request: Request, next: Next) -> Response { - let exts = request.extensions_mut(); - let rx_header = [ - (HxRequestExtracted::insert(exts), HX_REQUEST_STR), - (HxTargetExtracted::insert(exts), HX_TARGET_STR), - (HxTriggerExtracted::insert(exts), HX_TRIGGER_STR), - (HxTriggerNameExtracted::insert(exts), HX_TRIGGER_NAME_STR), - ]; +#[derive(Default, Clone)] +pub struct AutoVaryLayer; - let mut response = next.run(request).await; +impl Layer for AutoVaryLayer { + type Service = AutoVaryMiddleware; - let used_headers: Vec<_> = join_all( - rx_header + fn layer(&self, inner: S) -> Self::Service { + AutoVaryMiddleware { inner } + } +} + +#[derive(Clone)] +pub struct AutoVaryMiddleware { + inner: S, +} + +impl Service for AutoVaryMiddleware +where + S: Service + Send + 'static, + S::Future: Send + 'static, +{ + type Response = S::Response; + type Error = S::Error; + type Future = BoxFuture<'static, Result>; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, mut request: Request) -> Self::Future { + let exts = request.extensions_mut(); + let rx_header = [ + (HxRequestExtracted::insert(exts), HX_REQUEST_STR), + (HxTargetExtracted::insert(exts), HX_TARGET_STR), + (HxTriggerExtracted::insert(exts), HX_TRIGGER_STR), + (HxTriggerNameExtracted::insert(exts), HX_TRIGGER_NAME_STR), + ]; + let future = self.inner.call(request); + Box::pin(async move { + let mut response: Response = future.await?; + let used_headers: Vec<_> = join_all( + rx_header + .into_iter() + .map(|(rx, header)| async move { rx.await.ok().map(|_| header) }), + ) + .await .into_iter() - .map(|(rx, header)| async move { rx.await.ok().map(|_| header) }), - ) - .await - .into_iter() - .flatten() - .collect(); + .flatten() + .collect(); - if used_headers.is_empty() { - return response; + if used_headers.is_empty() { + return Ok(response); + } + + let value = match HeaderValue::from_str(&used_headers.join(", ")) { + Ok(x) => x, + Err(e) => return Ok(HxError::from(e).into_response()), + }; + + if let Err(e) = response.headers_mut().try_append(VARY, value) { + return Ok(HxError::from(e).into_response()); + } + + Ok(response) + }) } - - let value = match HeaderValue::from_str(&used_headers.join(", ")) { - Ok(x) => x, - Err(e) => return HxError::from(e).into_response(), - }; - - if let Err(e) = response.headers_mut().try_append(VARY, value) { - return HxError::from(e).into_response(); - } - - response } #[cfg(test)] @@ -122,7 +159,7 @@ mod tests { "/multiple-extractors", get(|_: HxRequest, _: HxTarget, _: HxTrigger, _: HxTriggerName| async { () }), ) - .layer(axum::middleware::from_fn(middleware)); + .layer(AutoVaryLayer); axum_test::TestServer::new(app).unwrap() } diff --git a/src/lib.rs b/src/lib.rs index c0dc625..7560074 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,6 +5,9 @@ mod error; pub use error::*; +#[cfg(feature = "auto-vary")] +#[cfg_attr(feature = "unstable", doc(cfg(feature = "auto-vary")))] +mod auto_vary; pub mod extractors; #[cfg(feature = "guards")] #[cfg_attr(feature = "unstable", doc(cfg(feature = "guards")))] @@ -12,6 +15,10 @@ pub mod guard; pub mod headers; pub mod responders; +#[cfg(feature = "auto-vary")] +#[cfg_attr(feature = "unstable", doc(cfg(feature = "auto-vary")))] +#[doc(inline)] +pub use auto_vary::AutoVaryLayer; #[doc(inline)] pub use extractors::*; #[cfg(feature = "guards")] @@ -22,6 +29,3 @@ pub use guard::*; pub use headers::*; #[doc(inline)] pub use responders::*; - -#[cfg(feature = "auto-vary")] -pub mod auto_vary;