diff --git a/.gitignore b/.gitignore index a9d37c5..e08f5fc 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ target Cargo.lock +.env diff --git a/Cargo.toml b/Cargo.toml index 0e19f01..fbcbb81 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,3 +24,4 @@ openidconnect = "3.5" serde = "1.0" futures-util = "0.3" reqwest = { version = "0.11", default-features = false } +urlencoding = "2.1.3" diff --git a/examples/basic/Cargo.toml b/examples/basic/Cargo.toml index bb54e8d..fcd75e1 100644 --- a/examples/basic/Cargo.toml +++ b/examples/basic/Cargo.toml @@ -11,3 +11,5 @@ axum = "0.7.4" axum-oidc = { path = "./../.." } tower = "0.4.13" tower-sessions = "0.11.0" + +dotenvy = "0.15.7" diff --git a/examples/basic/src/main.rs b/examples/basic/src/main.rs index 4da1acf..1ece2be 100644 --- a/examples/basic/src/main.rs +++ b/examples/basic/src/main.rs @@ -3,6 +3,7 @@ use axum::{ }; use axum_oidc::{ error::MiddlewareError, EmptyAdditionalClaims, OidcAuthLayer, OidcClaims, OidcLoginLayer, + OidcRpInitiatedLogout, }; use tokio::net::TcpListener; use tower::ServiceBuilder; @@ -13,6 +14,12 @@ use tower_sessions::{ #[tokio::main] async fn main() { + dotenvy::dotenv().ok(); + let app_url = std::env::var("APP_URL").expect("APP_URL env variable"); + let issuer = std::env::var("ISSUER").expect("ISSUER env variable"); + let client_id = std::env::var("CLIENT_ID").expect("CLIENT_ID env variable"); + let client_secret = std::env::var("CLIENT_SECRET").ok(); + let session_store = MemoryStore::default(); let session_layer = SessionManagerLayer::new(session_store) .with_secure(false) @@ -31,10 +38,10 @@ async fn main() { })) .layer( OidcAuthLayer::::discover_client( - Uri::from_static("https://app.example.com"), - "https://auth.example.com/auth/realms/example".to_string(), - "my-client".to_string(), - Some("123456".to_owned()), + Uri::from_maybe_shared(app_url).expect("valid APP_URL"), + issuer, + client_id, + client_secret, vec![], ) .await @@ -43,6 +50,7 @@ async fn main() { let app = Router::new() .route("/foo", get(authenticated)) + .route("/logout", get(logout)) .layer(oidc_login_service) .route("/bar", get(maybe_authenticated)) .layer(oidc_auth_service) @@ -70,3 +78,7 @@ async fn maybe_authenticated( "Hello anon!".to_string() } } + +async fn logout(logout: OidcRpInitiatedLogout) -> impl IntoResponse { + logout.with_post_logout_redirect(Uri::from_static("https://google.de")) +} diff --git a/src/error.rs b/src/error.rs index c91ea66..6d8997e 100644 --- a/src/error.rs +++ b/src/error.rs @@ -10,6 +10,9 @@ use thiserror::Error; pub enum ExtractorError { #[error("unauthorized")] Unauthorized, + + #[error("rp initiated logout information not found")] + RpInitiatedLogoutInformationNotFound, } #[derive(Debug, Error)] @@ -65,6 +68,9 @@ pub enum Error { #[error("url parsing: {0:?}")] UrlParsing(#[from] openidconnect::url::ParseError), + #[error("invalid end_session_endpoint uri: {0:?}")] + InvalidEndSessionEndpoint(http::uri::InvalidUri), + #[error("discovery: {0:?}")] Discovery(#[from] openidconnect::DiscoveryError>), @@ -77,7 +83,12 @@ pub enum Error { impl IntoResponse for ExtractorError { fn into_response(self) -> axum_core::response::Response { - (StatusCode::UNAUTHORIZED, "unauthorized").into_response() + match self { + Self::Unauthorized => (StatusCode::UNAUTHORIZED, "unauthorized").into_response(), + Self::RpInitiatedLogoutInformationNotFound => { + (StatusCode::INTERNAL_SERVER_ERROR, "intenal server error").into_response() + } + } } } diff --git a/src/extractor.rs b/src/extractor.rs index 5c18d78..2161ee3 100644 --- a/src/extractor.rs +++ b/src/extractor.rs @@ -1,9 +1,13 @@ -use std::ops::Deref; +use std::{borrow::Cow, ops::Deref}; use crate::{error::ExtractorError, AdditionalClaims}; use async_trait::async_trait; -use axum_core::extract::FromRequestParts; -use http::request::Parts; +use axum::response::Redirect; +use axum_core::{ + extract::FromRequestParts, + response::{IntoResponse, Response}, +}; +use http::{request::Parts, uri::PathAndQuery, Uri}; use openidconnect::{core::CoreGenderClaim, IdTokenClaims}; /// Extractor for the OpenID Connect Claims. @@ -81,3 +85,80 @@ impl AsRef for OidcAccessToken { self.0.as_str() } } + +#[derive(Clone)] +pub struct OidcRpInitiatedLogout { + pub(crate) end_session_endpoint: Uri, + pub(crate) id_token_hint: String, + pub(crate) client_id: String, + pub(crate) post_logout_redirect_uri: Option, + pub(crate) state: Option, +} + +impl OidcRpInitiatedLogout { + pub fn with_post_logout_redirect(mut self, uri: Uri) -> Self { + self.post_logout_redirect_uri = Some(uri); + self + } + pub fn with_state(mut self, state: String) -> Self { + self.state = Some(state); + self + } + pub fn uri(self) -> Uri { + let mut parts = self.end_session_endpoint.into_parts(); + + let query = { + let mut query = Vec::with_capacity(4); + query.push(("id_token_hint", Cow::Borrowed(&self.id_token_hint))); + query.push(("client_id", Cow::Borrowed(&self.client_id))); + + if let Some(post_logout_redirect_uri) = &self.post_logout_redirect_uri { + query.push(( + "post_logout_redirect_uri", + Cow::Owned(post_logout_redirect_uri.to_string()), + )); + } + if let Some(state) = &self.state { + query.push(("state", Cow::Borrowed(state))); + } + + query + .into_iter() + .map(|(k, v)| format!("{}={}", k, urlencoding::encode(v.as_str()))) + .collect::>() + .join("&") + }; + + let path_and_query = match parts.path_and_query { + Some(path_and_query) => { + PathAndQuery::from_maybe_shared(format!("{}?{}", path_and_query.path(), query)) + } + None => PathAndQuery::from_maybe_shared(format!("?{}", query)), + }; + parts.path_and_query = Some(path_and_query.unwrap()); + + Uri::from_parts(parts).unwrap() + } +} +#[async_trait] +impl FromRequestParts for OidcRpInitiatedLogout +where + S: Send + Sync, +{ + type Rejection = ExtractorError; + + async fn from_request_parts(parts: &mut Parts, _: &S) -> Result { + parts + .extensions + .get::() + .cloned() + .ok_or(ExtractorError::Unauthorized) + } +} + +#[async_trait] +impl IntoResponse for OidcRpInitiatedLogout { + fn into_response(self) -> Response { + Redirect::temporary(&self.uri().to_string()).into_response() + } +} diff --git a/src/lib.rs b/src/lib.rs index 8458314..29fa9a4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -6,10 +6,12 @@ use crate::error::Error; use http::Uri; use openidconnect::{ core::{ - CoreAuthDisplay, CoreAuthPrompt, CoreErrorResponseType, CoreGenderClaim, CoreJsonWebKey, - CoreJsonWebKeyType, CoreJsonWebKeyUse, CoreJweContentEncryptionAlgorithm, - CoreJwsSigningAlgorithm, CoreProviderMetadata, CoreRevocableToken, - CoreRevocationErrorResponse, CoreTokenIntrospectionResponse, CoreTokenType, + CoreAuthDisplay, CoreAuthPrompt, CoreClaimName, CoreClaimType, CoreClientAuthMethod, + CoreErrorResponseType, CoreGenderClaim, CoreGrantType, CoreJsonWebKey, CoreJsonWebKeyType, + CoreJsonWebKeyUse, CoreJweContentEncryptionAlgorithm, CoreJweKeyManagementAlgorithm, + CoreJwsSigningAlgorithm, CoreProviderMetadata, CoreResponseMode, CoreResponseType, + CoreRevocableToken, CoreRevocationErrorResponse, CoreSubjectIdentifierType, + CoreTokenIntrospectionResponse, CoreTokenType, }, reqwest::async_http_client, ClientId, ClientSecret, CsrfToken, EmptyExtraTokenFields, IdTokenFields, IssuerUrl, Nonce, @@ -21,7 +23,7 @@ pub mod error; mod extractor; mod middleware; -pub use extractor::{OidcAccessToken, OidcClaims}; +pub use extractor::{OidcAccessToken, OidcClaims, OidcRpInitiatedLogout}; pub use middleware::{OidcAuthLayer, OidcAuthMiddleware, OidcLoginLayer, OidcLoginMiddleware}; const SESSION_KEY: &str = "axum-oidc"; @@ -66,14 +68,34 @@ type Client = openidconnect::Client< CoreRevocationErrorResponse, >; +type ProviderMetadata = openidconnect::ProviderMetadata< + AdditionalProviderMetadata, + CoreAuthDisplay, + CoreClientAuthMethod, + CoreClaimName, + CoreClaimType, + CoreGrantType, + CoreJweContentEncryptionAlgorithm, + CoreJweKeyManagementAlgorithm, + CoreJwsSigningAlgorithm, + CoreJsonWebKeyType, + CoreJsonWebKeyUse, + CoreJsonWebKey, + CoreResponseMode, + CoreResponseType, + CoreSubjectIdentifierType, +>; + pub type BoxError = Box; /// OpenID Connect Client #[derive(Clone)] pub struct OidcClient { scopes: Vec, + client_id: String, client: Client, application_base_url: Uri, + end_session_endpoint: Option, } impl OidcClient { @@ -85,17 +107,25 @@ impl OidcClient { scopes: Vec, ) -> Result { let provider_metadata = - CoreProviderMetadata::discover_async(IssuerUrl::new(issuer)?, async_http_client) - .await?; + ProviderMetadata::discover_async(IssuerUrl::new(issuer)?, async_http_client).await?; + let end_session_endpoint = provider_metadata + .additional_metadata() + .end_session_endpoint + .clone() + .map(Uri::from_maybe_shared) + .transpose() + .map_err(Error::InvalidEndSessionEndpoint)?; let client = Client::from_provider_metadata( provider_metadata, - ClientId::new(client_id), + ClientId::new(client_id.clone()), client_secret.map(ClientSecret::new), ); Ok(Self { scopes, client, + client_id, application_base_url, + end_session_endpoint, }) } } @@ -138,3 +168,9 @@ impl OidcSession { .map(|x| RefreshToken::new(x.to_string())) } } + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct AdditionalProviderMetadata { + end_session_endpoint: Option, +} +impl openidconnect::AdditionalProviderMetadata for AdditionalProviderMetadata {} diff --git a/src/middleware.rs b/src/middleware.rs index 4be9446..280aae0 100644 --- a/src/middleware.rs +++ b/src/middleware.rs @@ -25,7 +25,7 @@ use openidconnect::{ use crate::{ error::{Error, MiddlewareError}, - extractor::{OidcAccessToken, OidcClaims}, + extractor::{OidcAccessToken, OidcClaims, OidcRpInitiatedLogout}, AdditionalClaims, BoxError, OidcClient, OidcQuery, OidcSession, SESSION_KEY, }; @@ -334,6 +334,16 @@ where parts.extensions.insert(OidcAccessToken( login_session.access_token.clone().unwrap_or_default(), )); + if let Some(end_session_endpoint) = oidcclient.end_session_endpoint.clone() + { + parts.extensions.insert(OidcRpInitiatedLogout { + end_session_endpoint, + id_token_hint: login_session.id_token.clone().unwrap(), + client_id: oidcclient.client_id.clone(), + post_logout_redirect_uri: None, + state: None, + }); + } } // stored id token is invalid and can't be uses, but we have a refresh token // and can use it and try to get another id token.