mirror of
https://github.com/robertwayne/axum-htmx
synced 2024-11-29 06:34:33 +01:00
Middleware as tower layer
This commit is contained in:
parent
57e3e067b1
commit
386662b0a8
3 changed files with 84 additions and 44 deletions
11
Cargo.toml
11
Cargo.toml
|
@ -15,18 +15,13 @@ default = []
|
|||
unstable = []
|
||||
guards = ["tower", "futures-core", "pin-project-lite"]
|
||||
serde = ["dep:serde", "dep:serde_json"]
|
||||
auto-vary = ["axum", "futures", "tokio"]
|
||||
auto-vary = ["futures", "tokio", "tower"]
|
||||
|
||||
[dependencies]
|
||||
axum-core = "0.4"
|
||||
http = { version = "1.0", default-features = false }
|
||||
async-trait = "0.1"
|
||||
|
||||
# Optional dependencies required for the `auto-vary` feature.
|
||||
axum = { version = "0.7", default-features = false, optional = true }
|
||||
tokio = { version = "1", features = ["sync"], optional = true }
|
||||
futures = { version = "0.3", default-features = false, optional = true }
|
||||
|
||||
# Optional dependencies required for the `guards` feature.
|
||||
tower = { version = "0.4", default-features = false, optional = true }
|
||||
futures-core = { version = "0.3", optional = true }
|
||||
|
@ -36,6 +31,10 @@ pin-project-lite = { version = "0.2", optional = true }
|
|||
serde = { version = "1", features = ["derive"], optional = true }
|
||||
serde_json = { version = "1", optional = true }
|
||||
|
||||
# Optional dependencies required for the `auto-vary` feature.
|
||||
tokio = { version = "1", features = ["sync"], optional = true }
|
||||
futures = { version = "0.3", default-features = false, optional = true }
|
||||
|
||||
[dev-dependencies]
|
||||
axum = { version = "0.7", default-features = false }
|
||||
axum-test = "15"
|
||||
|
|
107
src/auto_vary.rs
107
src/auto_vary.rs
|
@ -1,13 +1,19 @@
|
|||
use std::sync::Arc;
|
||||
use std::{
|
||||
sync::Arc,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
|
||||
use axum::{extract::Request, middleware::Next, response::Response};
|
||||
use axum_core::response::IntoResponse;
|
||||
use futures::future::join_all;
|
||||
use axum_core::{
|
||||
extract::Request,
|
||||
response::{IntoResponse, Response},
|
||||
};
|
||||
use futures::future::{join_all, BoxFuture};
|
||||
use http::{
|
||||
header::{HeaderValue, VARY},
|
||||
Extensions,
|
||||
};
|
||||
use tokio::sync::oneshot::{self, Receiver, Sender};
|
||||
use tower::{Layer, Service};
|
||||
|
||||
use crate::{
|
||||
headers::{HX_REQUEST_STR, HX_TARGET_STR, HX_TRIGGER_NAME_STR, HX_TRIGGER_STR},
|
||||
|
@ -59,41 +65,72 @@ define_notifiers!(
|
|||
HxTriggerNameExtracted
|
||||
);
|
||||
|
||||
pub async fn middleware(mut request: Request, next: Next) -> Response {
|
||||
let exts = request.extensions_mut();
|
||||
let rx_header = [
|
||||
(HxRequestExtracted::insert(exts), HX_REQUEST_STR),
|
||||
(HxTargetExtracted::insert(exts), HX_TARGET_STR),
|
||||
(HxTriggerExtracted::insert(exts), HX_TRIGGER_STR),
|
||||
(HxTriggerNameExtracted::insert(exts), HX_TRIGGER_NAME_STR),
|
||||
];
|
||||
#[derive(Default, Clone)]
|
||||
pub struct AutoVaryLayer;
|
||||
|
||||
let mut response = next.run(request).await;
|
||||
impl<S> Layer<S> for AutoVaryLayer {
|
||||
type Service = AutoVaryMiddleware<S>;
|
||||
|
||||
let used_headers: Vec<_> = join_all(
|
||||
rx_header
|
||||
fn layer(&self, inner: S) -> Self::Service {
|
||||
AutoVaryMiddleware { inner }
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct AutoVaryMiddleware<S> {
|
||||
inner: S,
|
||||
}
|
||||
|
||||
impl<S> Service<Request> for AutoVaryMiddleware<S>
|
||||
where
|
||||
S: Service<Request, Response = Response> + Send + 'static,
|
||||
S::Future: Send + 'static,
|
||||
{
|
||||
type Response = S::Response;
|
||||
type Error = S::Error;
|
||||
type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
|
||||
|
||||
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
self.inner.poll_ready(cx)
|
||||
}
|
||||
|
||||
fn call(&mut self, mut request: Request) -> Self::Future {
|
||||
let exts = request.extensions_mut();
|
||||
let rx_header = [
|
||||
(HxRequestExtracted::insert(exts), HX_REQUEST_STR),
|
||||
(HxTargetExtracted::insert(exts), HX_TARGET_STR),
|
||||
(HxTriggerExtracted::insert(exts), HX_TRIGGER_STR),
|
||||
(HxTriggerNameExtracted::insert(exts), HX_TRIGGER_NAME_STR),
|
||||
];
|
||||
let future = self.inner.call(request);
|
||||
Box::pin(async move {
|
||||
let mut response: Response = future.await?;
|
||||
let used_headers: Vec<_> = join_all(
|
||||
rx_header
|
||||
.into_iter()
|
||||
.map(|(rx, header)| async move { rx.await.ok().map(|_| header) }),
|
||||
)
|
||||
.await
|
||||
.into_iter()
|
||||
.map(|(rx, header)| async move { rx.await.ok().map(|_| header) }),
|
||||
)
|
||||
.await
|
||||
.into_iter()
|
||||
.flatten()
|
||||
.collect();
|
||||
.flatten()
|
||||
.collect();
|
||||
|
||||
if used_headers.is_empty() {
|
||||
return response;
|
||||
if used_headers.is_empty() {
|
||||
return Ok(response);
|
||||
}
|
||||
|
||||
let value = match HeaderValue::from_str(&used_headers.join(", ")) {
|
||||
Ok(x) => x,
|
||||
Err(e) => return Ok(HxError::from(e).into_response()),
|
||||
};
|
||||
|
||||
if let Err(e) = response.headers_mut().try_append(VARY, value) {
|
||||
return Ok(HxError::from(e).into_response());
|
||||
}
|
||||
|
||||
Ok(response)
|
||||
})
|
||||
}
|
||||
|
||||
let value = match HeaderValue::from_str(&used_headers.join(", ")) {
|
||||
Ok(x) => x,
|
||||
Err(e) => return HxError::from(e).into_response(),
|
||||
};
|
||||
|
||||
if let Err(e) = response.headers_mut().try_append(VARY, value) {
|
||||
return HxError::from(e).into_response();
|
||||
}
|
||||
|
||||
response
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
@ -122,7 +159,7 @@ mod tests {
|
|||
"/multiple-extractors",
|
||||
get(|_: HxRequest, _: HxTarget, _: HxTrigger, _: HxTriggerName| async { () }),
|
||||
)
|
||||
.layer(axum::middleware::from_fn(middleware));
|
||||
.layer(AutoVaryLayer);
|
||||
axum_test::TestServer::new(app).unwrap()
|
||||
}
|
||||
|
||||
|
|
10
src/lib.rs
10
src/lib.rs
|
@ -5,6 +5,9 @@
|
|||
mod error;
|
||||
pub use error::*;
|
||||
|
||||
#[cfg(feature = "auto-vary")]
|
||||
#[cfg_attr(feature = "unstable", doc(cfg(feature = "auto-vary")))]
|
||||
mod auto_vary;
|
||||
pub mod extractors;
|
||||
#[cfg(feature = "guards")]
|
||||
#[cfg_attr(feature = "unstable", doc(cfg(feature = "guards")))]
|
||||
|
@ -12,6 +15,10 @@ pub mod guard;
|
|||
pub mod headers;
|
||||
pub mod responders;
|
||||
|
||||
#[cfg(feature = "auto-vary")]
|
||||
#[cfg_attr(feature = "unstable", doc(cfg(feature = "auto-vary")))]
|
||||
#[doc(inline)]
|
||||
pub use auto_vary::AutoVaryLayer;
|
||||
#[doc(inline)]
|
||||
pub use extractors::*;
|
||||
#[cfg(feature = "guards")]
|
||||
|
@ -22,6 +29,3 @@ pub use guard::*;
|
|||
pub use headers::*;
|
||||
#[doc(inline)]
|
||||
pub use responders::*;
|
||||
|
||||
#[cfg(feature = "auto-vary")]
|
||||
pub mod auto_vary;
|
||||
|
|
Loading…
Reference in a new issue