diff --git a/Cargo.toml b/Cargo.toml index f6c53bd..482eeec 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,13 +13,12 @@ keywords = [ "axum", "oidc", "openidconnect", "authentication" ] [dependencies] thiserror = "1.0" -axum-core = "0.4" -axum = { version = "0.7", default-features = false, features = [ "query" ] } +axum-core = "0.5" +axum = { version = "0.8", default-features = false, features = [ "query" ] } tower-service = "0.3" tower-layer = "0.3" tower-sessions = { version = "0.13", default-features = false, features = [ "axum-core" ] } http = "1.1" -async-trait = "0.1" openidconnect = "3.5" serde = "1.0" futures-util = "0.3" diff --git a/examples/basic/Cargo.toml b/examples/basic/Cargo.toml index a1b712e..6d862c7 100644 --- a/examples/basic/Cargo.toml +++ b/examples/basic/Cargo.toml @@ -7,7 +7,7 @@ edition = "2021" [dependencies] tokio = { version = "1.37", features = ["net", "macros", "rt-multi-thread"] } -axum = "0.7" +axum = { version = "0.8", features = ["macros"] } axum-oidc = { path = "./../.." } tower = "0.4" tower-sessions = "0.13" diff --git a/src/extractor.rs b/src/extractor.rs index 9cd41ed..01635a8 100644 --- a/src/extractor.rs +++ b/src/extractor.rs @@ -1,9 +1,11 @@ -use std::{borrow::Cow, ops::Deref}; +use std::{borrow::Cow, convert::Infallible, ops::Deref}; use crate::{error::ExtractorError, AdditionalClaims, ClearSessionFlag}; -use async_trait::async_trait; use axum::response::Redirect; -use axum_core::{extract::FromRequestParts, response::IntoResponse}; +use axum_core::{ + extract::{FromRequestParts, OptionalFromRequestParts}, + response::IntoResponse, +}; use http::{request::Parts, uri::PathAndQuery, Uri}; use openidconnect::{core::CoreGenderClaim, IdTokenClaims}; @@ -13,7 +15,6 @@ use openidconnect::{core::CoreGenderClaim, IdTokenClaims}; #[derive(Clone)] pub struct OidcClaims(pub IdTokenClaims); -#[async_trait] impl FromRequestParts for OidcClaims where S: Send + Sync, @@ -30,6 +31,18 @@ where } } +impl OptionalFromRequestParts for OidcClaims +where + S: Send + Sync, + AC: AdditionalClaims, +{ + type Rejection = Infallible; + + async fn from_request_parts(parts: &mut Parts, _: &S) -> Result, Self::Rejection> { + Ok(parts.extensions.get::().cloned()) + } +} + impl Deref for OidcClaims { type Target = IdTokenClaims; @@ -53,7 +66,6 @@ where #[derive(Clone)] pub struct OidcAccessToken(pub String); -#[async_trait] impl FromRequestParts for OidcAccessToken where S: Send + Sync, @@ -69,6 +81,17 @@ where } } +impl OptionalFromRequestParts for OidcAccessToken +where + S: Send + Sync, +{ + type Rejection = Infallible; + + async fn from_request_parts(parts: &mut Parts, _: &S) -> Result, Self::Rejection> { + Ok(parts.extensions.get::().cloned()) + } +} + impl Deref for OidcAccessToken { type Target = str; @@ -147,7 +170,6 @@ impl OidcRpInitiatedLogout { } } -#[async_trait] impl FromRequestParts for OidcRpInitiatedLogout where S: Send + Sync, @@ -159,10 +181,22 @@ where .extensions .get::>() .cloned() - .ok_or(ExtractorError::Unauthorized)?{ + .ok_or(ExtractorError::Unauthorized)? + { Some(this) => Ok(this), None => Err(ExtractorError::RpInitiatedLogoutNotSupported), - } + } + } +} + +impl OptionalFromRequestParts for OidcRpInitiatedLogout +where + S: Send + Sync, +{ + type Rejection = Infallible; + + async fn from_request_parts(parts: &mut Parts, _: &S) -> Result, Self::Rejection> { + Ok(parts.extensions.get::>().cloned().flatten()) } }