Optimize middleware

This commit is contained in:
imbolc 2024-06-15 14:08:13 +06:00
parent 535b19fff8
commit 693707bdfa
2 changed files with 30 additions and 27 deletions

View file

@ -22,6 +22,7 @@ http = { version = "1.0", default-features = false }
async-trait = "0.1" async-trait = "0.1"
axum = "0.7" # TODO: remove axum = "0.7" # TODO: remove
tokio = { version = "1", features = ["sync"] } # TODO: hide behind a feature? tokio = { version = "1", features = ["sync"] } # TODO: hide behind a feature?
futures = "0.3" # TODO
# Optional dependencies required for the `guards` feature. # Optional dependencies required for the `guards` feature.
tower = { version = "0.4", default-features = false, optional = true } tower = { version = "0.4", default-features = false, optional = true }

View file

@ -2,6 +2,7 @@ use std::sync::Arc;
use axum::{extract::Request, middleware::Next, response::Response}; use axum::{extract::Request, middleware::Next, response::Response};
use axum_core::response::IntoResponse; use axum_core::response::IntoResponse;
use futures::future::join_all;
use http::{ use http::{
header::{HeaderValue, VARY}, header::{HeaderValue, VARY},
Extensions, Extensions,
@ -25,7 +26,7 @@ pub trait Notifier {
} }
} }
fn insert_into_extensions(extensions: &mut Extensions) -> Receiver<()>; fn insert(extensions: &mut Extensions) -> Receiver<()>;
} }
macro_rules! define_notifiers { macro_rules! define_notifiers {
@ -39,7 +40,7 @@ macro_rules! define_notifiers {
self.0.take().and_then(Arc::into_inner) self.0.take().and_then(Arc::into_inner)
} }
fn insert_into_extensions(extensions: &mut Extensions) -> Receiver<()> { fn insert(extensions: &mut Extensions) -> Receiver<()> {
let (tx, rx) = oneshot::channel(); let (tx, rx) = oneshot::channel();
if extensions.insert(Self(Some(Arc::new(tx)))).is_some() { if extensions.insert(Self(Some(Arc::new(tx)))).is_some() {
panic!("{}", MIDDLEWARE_DOUBLE_USE); panic!("{}", MIDDLEWARE_DOUBLE_USE);
@ -59,36 +60,37 @@ define_notifiers!(
); );
pub async fn vary_middleware(mut request: Request, next: Next) -> Response { pub async fn vary_middleware(mut request: Request, next: Next) -> Response {
let hx_request_rx = HxRequestExtracted::insert_into_extensions(request.extensions_mut()); let exts = request.extensions_mut();
let hx_target_rx = HxTargetExtracted::insert_into_extensions(request.extensions_mut()); let rx_header = [
let hx_trigger_rx = HxTriggerExtracted::insert_into_extensions(request.extensions_mut()); (HxRequestExtracted::insert(exts), HX_REQUEST_STR),
let hx_trigger_name_rx = (HxTargetExtracted::insert(exts), HX_TARGET_STR),
HxTriggerNameExtracted::insert_into_extensions(request.extensions_mut()); (HxTriggerExtracted::insert(exts), HX_TRIGGER_STR),
(HxTriggerNameExtracted::insert(exts), HX_TRIGGER_NAME_STR),
];
let mut response = next.run(request).await; let mut response = next.run(request).await;
let mut used = Vec::with_capacity(4); let used_headers: Vec<_> = join_all(
if hx_request_rx.await.is_ok() { rx_header
used.push(HX_REQUEST_STR) .into_iter()
} .map(|(rx, header)| async move { rx.await.ok().map(|_| header) }),
if hx_target_rx.await.is_ok() { )
used.push(HX_TARGET_STR) .await
} .into_iter()
if hx_trigger_rx.await.is_ok() { .flatten()
used.push(HX_TRIGGER_STR) .collect();
}
if hx_trigger_name_rx.await.is_ok() { if used_headers.is_empty() {
used.push(HX_TRIGGER_NAME_STR) return response;
} }
if !used.is_empty() { let value = match HeaderValue::from_str(&used_headers.join(", ")) {
let value = match HeaderValue::from_str(&used.join(", ")) { Ok(x) => x,
Ok(x) => x, Err(e) => return HxError::from(e).into_response(),
Err(e) => return HxError::from(e).into_response(), };
};
if let Err(e) = response.headers_mut().try_append(VARY, value) { if let Err(e) = response.headers_mut().try_append(VARY, value) {
return HxError::from(e).into_response(); return HxError::from(e).into_response();
}
} }
response response