Optimize middleware

This commit is contained in:
imbolc 2024-06-15 14:08:13 +06:00
parent 535b19fff8
commit 693707bdfa
2 changed files with 30 additions and 27 deletions

View file

@ -22,6 +22,7 @@ http = { version = "1.0", default-features = false }
async-trait = "0.1"
axum = "0.7" # TODO: remove
tokio = { version = "1", features = ["sync"] } # TODO: hide behind a feature?
futures = "0.3" # TODO
# Optional dependencies required for the `guards` feature.
tower = { version = "0.4", default-features = false, optional = true }

View file

@ -2,6 +2,7 @@ use std::sync::Arc;
use axum::{extract::Request, middleware::Next, response::Response};
use axum_core::response::IntoResponse;
use futures::future::join_all;
use http::{
header::{HeaderValue, VARY},
Extensions,
@ -25,7 +26,7 @@ pub trait Notifier {
}
}
fn insert_into_extensions(extensions: &mut Extensions) -> Receiver<()>;
fn insert(extensions: &mut Extensions) -> Receiver<()>;
}
macro_rules! define_notifiers {
@ -39,7 +40,7 @@ macro_rules! define_notifiers {
self.0.take().and_then(Arc::into_inner)
}
fn insert_into_extensions(extensions: &mut Extensions) -> Receiver<()> {
fn insert(extensions: &mut Extensions) -> Receiver<()> {
let (tx, rx) = oneshot::channel();
if extensions.insert(Self(Some(Arc::new(tx)))).is_some() {
panic!("{}", MIDDLEWARE_DOUBLE_USE);
@ -59,36 +60,37 @@ define_notifiers!(
);
pub async fn vary_middleware(mut request: Request, next: Next) -> Response {
let hx_request_rx = HxRequestExtracted::insert_into_extensions(request.extensions_mut());
let hx_target_rx = HxTargetExtracted::insert_into_extensions(request.extensions_mut());
let hx_trigger_rx = HxTriggerExtracted::insert_into_extensions(request.extensions_mut());
let hx_trigger_name_rx =
HxTriggerNameExtracted::insert_into_extensions(request.extensions_mut());
let exts = request.extensions_mut();
let rx_header = [
(HxRequestExtracted::insert(exts), HX_REQUEST_STR),
(HxTargetExtracted::insert(exts), HX_TARGET_STR),
(HxTriggerExtracted::insert(exts), HX_TRIGGER_STR),
(HxTriggerNameExtracted::insert(exts), HX_TRIGGER_NAME_STR),
];
let mut response = next.run(request).await;
let mut used = Vec::with_capacity(4);
if hx_request_rx.await.is_ok() {
used.push(HX_REQUEST_STR)
}
if hx_target_rx.await.is_ok() {
used.push(HX_TARGET_STR)
}
if hx_trigger_rx.await.is_ok() {
used.push(HX_TRIGGER_STR)
}
if hx_trigger_name_rx.await.is_ok() {
used.push(HX_TRIGGER_NAME_STR)
let used_headers: Vec<_> = join_all(
rx_header
.into_iter()
.map(|(rx, header)| async move { rx.await.ok().map(|_| header) }),
)
.await
.into_iter()
.flatten()
.collect();
if used_headers.is_empty() {
return response;
}
if !used.is_empty() {
let value = match HeaderValue::from_str(&used.join(", ")) {
Ok(x) => x,
Err(e) => return HxError::from(e).into_response(),
};
if let Err(e) = response.headers_mut().try_append(VARY, value) {
return HxError::from(e).into_response();
}
let value = match HeaderValue::from_str(&used_headers.join(", ")) {
Ok(x) => x,
Err(e) => return HxError::from(e).into_response(),
};
if let Err(e) = response.headers_mut().try_append(VARY, value) {
return HxError::from(e).into_response();
}
response