diff --git a/Cargo.toml b/Cargo.toml index d3c3655..8866016 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -32,6 +32,9 @@ serde_json = { version = "1", optional = true } [dev-dependencies] axum = { version = "0.7", default-features = false } +axum-test = "14" +tokio = { version = "1", features = ["full"] } +tokio-test = "0.4" [package.metadata.docs.rs] all-features = true diff --git a/README.md b/README.md index ba911b6..a7a3ee3 100644 --- a/README.md +++ b/README.md @@ -219,6 +219,12 @@ Contributions are always welcome! If you have an idea for a feature or find a bug, let me know. PR's are appreciated, but if it's not a small change, please open an issue first so we're all on the same page! +### Testing + +```sh +cargo +nightly test --all-features +``` + ## License `axum-htmx` is dual-licensed under either diff --git a/src/error.rs b/src/error.rs index d89d096..1c2f782 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,11 +1,15 @@ use std::{error, fmt}; use axum_core::response::IntoResponse; -use http::{header::InvalidHeaderValue, StatusCode}; +use http::{ + header::{InvalidHeaderValue, MaxSizeReached}, + StatusCode, +}; #[derive(Debug)] pub enum HxError { InvalidHeaderValue(InvalidHeaderValue), + TooManyResponseHeaders(MaxSizeReached), #[cfg(feature = "serde")] #[cfg_attr(feature = "unstable", doc(cfg(feature = "serde")))] @@ -18,6 +22,12 @@ impl From for HxError { } } +impl From for HxError { + fn from(value: MaxSizeReached) -> Self { + Self::TooManyResponseHeaders(value) + } +} + #[cfg(feature = "serde")] #[cfg_attr(feature = "unstable", doc(cfg(feature = "serde")))] impl From for HxError { @@ -30,6 +40,7 @@ impl fmt::Display for HxError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { HxError::InvalidHeaderValue(err) => write!(f, "Invalid header value: {err}"), + HxError::TooManyResponseHeaders(err) => write!(f, "Too many response headers: {err}"), #[cfg(feature = "serde")] HxError::Json(err) => write!(f, "Json: {err}"), } diff --git a/src/responders/vary.rs b/src/responders/vary.rs index 1ee7312..03866da 100644 --- a/src/responders/vary.rs +++ b/src/responders/vary.rs @@ -1,8 +1,13 @@ use axum_core::response::{IntoResponseParts, ResponseParts}; -use http::header::VARY; +use http::header::{HeaderValue, VARY}; use crate::{extractors, headers, HxError}; +const HX_REQUEST: HeaderValue = HeaderValue::from_static(headers::HX_REQUEST); +const HX_TARGET: HeaderValue = HeaderValue::from_static(headers::HX_TARGET); +const HX_TRIGGER: HeaderValue = HeaderValue::from_static(headers::HX_TRIGGER); +const HX_TRIGGER_NAME: HeaderValue = HeaderValue::from_static(headers::HX_TRIGGER_NAME); + /// The `Vary: HX-Request` header. /// /// You may want to add this header to the response if your handler responds differently based on @@ -21,8 +26,7 @@ impl IntoResponseParts for VaryHxRequest { type Error = HxError; fn into_response_parts(self, mut res: ResponseParts) -> Result { - res.headers_mut() - .insert(VARY, headers::HX_REQUEST.try_into()?); + res.headers_mut().try_append(VARY, HX_REQUEST)?; Ok(res) } @@ -50,8 +54,7 @@ impl IntoResponseParts for VaryHxTarget { type Error = HxError; fn into_response_parts(self, mut res: ResponseParts) -> Result { - res.headers_mut() - .insert(VARY, headers::HX_TARGET.try_into()?); + res.headers_mut().try_append(VARY, HX_TARGET)?; Ok(res) } @@ -79,8 +82,7 @@ impl IntoResponseParts for VaryHxTrigger { type Error = HxError; fn into_response_parts(self, mut res: ResponseParts) -> Result { - res.headers_mut() - .insert(VARY, headers::HX_TRIGGER.try_into()?); + res.headers_mut().try_append(VARY, HX_TRIGGER)?; Ok(res) } @@ -108,8 +110,7 @@ impl IntoResponseParts for VaryHxTriggerName { type Error = HxError; fn into_response_parts(self, mut res: ResponseParts) -> Result { - res.headers_mut() - .insert(VARY, headers::HX_TRIGGER_NAME.try_into()?); + res.headers_mut().try_append(VARY, HX_TRIGGER_NAME)?; Ok(res) } @@ -121,3 +122,20 @@ impl extractors::HxTriggerName { VaryHxTriggerName } } + +#[cfg(test)] +mod tests { + use super::*; + use axum::{routing::get, Router}; + use std::collections::hash_set::HashSet; + + #[tokio::test] + async fn multiple_headers() { + let app = Router::new().route("/", get(|| async { (VaryHxRequest, VaryHxTarget, "foo") })); + let server = axum_test::TestServer::new(app).unwrap(); + + let resp = server.get("/").await; + let values: HashSet = resp.iter_headers_by_name("vary").cloned().collect(); + assert_eq!(values, HashSet::from([HX_REQUEST, HX_TARGET])); + } +}