diff --git a/Cargo.toml b/Cargo.toml index b59509a..34cfd32 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,6 +22,7 @@ http = { version = "1.0", default-features = false } async-trait = "0.1" axum = "0.7" # TODO: remove tokio = { version = "1", features = ["sync"] } # TODO: hide behind a feature? +futures = "0.3" # TODO # Optional dependencies required for the `guards` feature. tower = { version = "0.4", default-features = false, optional = true } diff --git a/src/vary_middleware.rs b/src/vary_middleware.rs index 2146643..84c614b 100644 --- a/src/vary_middleware.rs +++ b/src/vary_middleware.rs @@ -2,6 +2,7 @@ use std::sync::Arc; use axum::{extract::Request, middleware::Next, response::Response}; use axum_core::response::IntoResponse; +use futures::future::join_all; use http::{ header::{HeaderValue, VARY}, Extensions, @@ -25,7 +26,7 @@ pub trait Notifier { } } - fn insert_into_extensions(extensions: &mut Extensions) -> Receiver<()>; + fn insert(extensions: &mut Extensions) -> Receiver<()>; } macro_rules! define_notifiers { @@ -39,7 +40,7 @@ macro_rules! define_notifiers { self.0.take().and_then(Arc::into_inner) } - fn insert_into_extensions(extensions: &mut Extensions) -> Receiver<()> { + fn insert(extensions: &mut Extensions) -> Receiver<()> { let (tx, rx) = oneshot::channel(); if extensions.insert(Self(Some(Arc::new(tx)))).is_some() { panic!("{}", MIDDLEWARE_DOUBLE_USE); @@ -59,36 +60,37 @@ define_notifiers!( ); pub async fn vary_middleware(mut request: Request, next: Next) -> Response { - let hx_request_rx = HxRequestExtracted::insert_into_extensions(request.extensions_mut()); - let hx_target_rx = HxTargetExtracted::insert_into_extensions(request.extensions_mut()); - let hx_trigger_rx = HxTriggerExtracted::insert_into_extensions(request.extensions_mut()); - let hx_trigger_name_rx = - HxTriggerNameExtracted::insert_into_extensions(request.extensions_mut()); + let exts = request.extensions_mut(); + let rx_header = [ + (HxRequestExtracted::insert(exts), HX_REQUEST_STR), + (HxTargetExtracted::insert(exts), HX_TARGET_STR), + (HxTriggerExtracted::insert(exts), HX_TRIGGER_STR), + (HxTriggerNameExtracted::insert(exts), HX_TRIGGER_NAME_STR), + ]; let mut response = next.run(request).await; - let mut used = Vec::with_capacity(4); - if hx_request_rx.await.is_ok() { - used.push(HX_REQUEST_STR) - } - if hx_target_rx.await.is_ok() { - used.push(HX_TARGET_STR) - } - if hx_trigger_rx.await.is_ok() { - used.push(HX_TRIGGER_STR) - } - if hx_trigger_name_rx.await.is_ok() { - used.push(HX_TRIGGER_NAME_STR) + let used_headers: Vec<_> = join_all( + rx_header + .into_iter() + .map(|(rx, header)| async move { rx.await.ok().map(|_| header) }), + ) + .await + .into_iter() + .flatten() + .collect(); + + if used_headers.is_empty() { + return response; } - if !used.is_empty() { - let value = match HeaderValue::from_str(&used.join(", ")) { - Ok(x) => x, - Err(e) => return HxError::from(e).into_response(), - }; - if let Err(e) = response.headers_mut().try_append(VARY, value) { - return HxError::from(e).into_response(); - } + let value = match HeaderValue::from_str(&used_headers.join(", ")) { + Ok(x) => x, + Err(e) => return HxError::from(e).into_response(), + }; + + if let Err(e) = response.headers_mut().try_append(VARY, value) { + return HxError::from(e).into_response(); } response