diff --git a/.gitignore b/.gitignore index a9d37c5..e08f5fc 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ target Cargo.lock +.env diff --git a/Cargo.toml b/Cargo.toml index 0e19f01..fbcbb81 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,3 +24,4 @@ openidconnect = "3.5" serde = "1.0" futures-util = "0.3" reqwest = { version = "0.11", default-features = false } +urlencoding = "2.1.3" diff --git a/README.md b/README.md index 11a3fe7..9dd061a 100644 --- a/README.md +++ b/README.md @@ -13,6 +13,8 @@ The extractors will always return a value. The `OidcClaims`-extractor can be used to get the OpenId Conenct Claims. The `OidcAccessToken`-extractor can be used to get the OpenId Connect Access Token. +The `OidcRpInitializedLogout`-extractor can be used to get the rp initialized logout uri. + Your OIDC-Client must be allowed to redirect to **every** subpath of your application base url. # Examples diff --git a/examples/basic/Cargo.toml b/examples/basic/Cargo.toml index bb54e8d..fcd75e1 100644 --- a/examples/basic/Cargo.toml +++ b/examples/basic/Cargo.toml @@ -11,3 +11,5 @@ axum = "0.7.4" axum-oidc = { path = "./../.." } tower = "0.4.13" tower-sessions = "0.11.0" + +dotenvy = "0.15.7" diff --git a/examples/basic/src/main.rs b/examples/basic/src/main.rs index 4da1acf..da5165f 100644 --- a/examples/basic/src/main.rs +++ b/examples/basic/src/main.rs @@ -1,8 +1,13 @@ use axum::{ - error_handling::HandleErrorLayer, http::Uri, response::IntoResponse, routing::get, Router, + error_handling::HandleErrorLayer, + http::Uri, + response::{IntoResponse, Redirect}, + routing::get, + Router, }; use axum_oidc::{ error::MiddlewareError, EmptyAdditionalClaims, OidcAuthLayer, OidcClaims, OidcLoginLayer, + OidcRpInitiatedLogout, }; use tokio::net::TcpListener; use tower::ServiceBuilder; @@ -13,6 +18,12 @@ use tower_sessions::{ #[tokio::main] async fn main() { + dotenvy::dotenv().ok(); + let app_url = std::env::var("APP_URL").expect("APP_URL env variable"); + let issuer = std::env::var("ISSUER").expect("ISSUER env variable"); + let client_id = std::env::var("CLIENT_ID").expect("CLIENT_ID env variable"); + let client_secret = std::env::var("CLIENT_SECRET").ok(); + let session_store = MemoryStore::default(); let session_layer = SessionManagerLayer::new(session_store) .with_secure(false) @@ -31,10 +42,10 @@ async fn main() { })) .layer( OidcAuthLayer::::discover_client( - Uri::from_static("https://app.example.com"), - "https://auth.example.com/auth/realms/example".to_string(), - "my-client".to_string(), - Some("123456".to_owned()), + Uri::from_maybe_shared(app_url).expect("valid APP_URL"), + issuer, + client_id, + client_secret, vec![], ) .await @@ -43,6 +54,7 @@ async fn main() { let app = Router::new() .route("/foo", get(authenticated)) + .route("/logout", get(logout)) .layer(oidc_login_service) .route("/bar", get(maybe_authenticated)) .layer(oidc_auth_service) @@ -70,3 +82,11 @@ async fn maybe_authenticated( "Hello anon!".to_string() } } + +async fn logout(logout: OidcRpInitiatedLogout) -> impl IntoResponse { + let logout_uri = logout + .with_post_logout_redirect(Uri::from_static("https://pfzetto.de")) + .uri() + .unwrap(); + Redirect::temporary(&logout_uri.to_string()) +} diff --git a/src/error.rs b/src/error.rs index c91ea66..6d8997e 100644 --- a/src/error.rs +++ b/src/error.rs @@ -10,6 +10,9 @@ use thiserror::Error; pub enum ExtractorError { #[error("unauthorized")] Unauthorized, + + #[error("rp initiated logout information not found")] + RpInitiatedLogoutInformationNotFound, } #[derive(Debug, Error)] @@ -65,6 +68,9 @@ pub enum Error { #[error("url parsing: {0:?}")] UrlParsing(#[from] openidconnect::url::ParseError), + #[error("invalid end_session_endpoint uri: {0:?}")] + InvalidEndSessionEndpoint(http::uri::InvalidUri), + #[error("discovery: {0:?}")] Discovery(#[from] openidconnect::DiscoveryError>), @@ -77,7 +83,12 @@ pub enum Error { impl IntoResponse for ExtractorError { fn into_response(self) -> axum_core::response::Response { - (StatusCode::UNAUTHORIZED, "unauthorized").into_response() + match self { + Self::Unauthorized => (StatusCode::UNAUTHORIZED, "unauthorized").into_response(), + Self::RpInitiatedLogoutInformationNotFound => { + (StatusCode::INTERNAL_SERVER_ERROR, "intenal server error").into_response() + } + } } } diff --git a/src/extractor.rs b/src/extractor.rs index 5c18d78..bf477c8 100644 --- a/src/extractor.rs +++ b/src/extractor.rs @@ -1,9 +1,9 @@ -use std::ops::Deref; +use std::{borrow::Cow, ops::Deref}; use crate::{error::ExtractorError, AdditionalClaims}; use async_trait::async_trait; use axum_core::extract::FromRequestParts; -use http::request::Parts; +use http::{request::Parts, uri::PathAndQuery, Uri}; use openidconnect::{core::CoreGenderClaim, IdTokenClaims}; /// Extractor for the OpenID Connect Claims. @@ -78,6 +78,84 @@ impl Deref for OidcAccessToken { impl AsRef for OidcAccessToken { fn as_ref(&self) -> &str { - self.0.as_str() + &self.0 + } +} + +/// Extractor for the [OpenID Connect RP-Initialized Logout](https://openid.net/specs/openid-connect-rpinitiated-1_0.html) URL +/// +/// This Extractor will only succed when the cached session is valid, [crate::middleware::OidcAuthMiddleware] is loaded and the issuer supports RP-Initialized Logout. +#[derive(Clone)] +pub struct OidcRpInitiatedLogout { + pub(crate) end_session_endpoint: Uri, + pub(crate) id_token_hint: String, + pub(crate) client_id: String, + pub(crate) post_logout_redirect_uri: Option, + pub(crate) state: Option, +} + +impl OidcRpInitiatedLogout { + /// set uri that the user is redirected to after logout. + /// This uri must be in the allowed by issuer. + pub fn with_post_logout_redirect(mut self, uri: Uri) -> Self { + self.post_logout_redirect_uri = Some(uri); + self + } + /// set the state parameter that is appended as a query to the post logout redirect uri. + pub fn with_state(mut self, state: String) -> Self { + self.state = Some(state); + self + } + /// get the uri that the client needs to access for logout + pub fn uri(&self) -> Result { + let mut parts = self.end_session_endpoint.clone().into_parts(); + + let query = { + let mut query: Vec<(&str, Cow<'_, str>)> = Vec::with_capacity(4); + query.push(("id_token_hint", Cow::Borrowed(&self.id_token_hint))); + query.push(("client_id", Cow::Borrowed(&self.client_id))); + + if let Some(post_logout_redirect_uri) = &self.post_logout_redirect_uri { + query.push(( + "post_logout_redirect_uri", + Cow::Owned(post_logout_redirect_uri.to_string()), + )); + } + if let Some(state) = &self.state { + query.push(("state", Cow::Borrowed(state))); + } + + query + .into_iter() + .map(|(k, v)| format!("{}={}", k, urlencoding::encode(&v))) + .collect::>() + .join("&") + }; + + let path_and_query = match parts.path_and_query { + Some(path_and_query) => { + PathAndQuery::from_maybe_shared(format!("{}?{}", path_and_query.path(), query)) + } + None => PathAndQuery::from_maybe_shared(format!("?{}", query)), + }; + parts.path_and_query = Some(path_and_query?); + + Ok(Uri::from_parts(parts)?) + } +} + +#[async_trait] +impl FromRequestParts for OidcRpInitiatedLogout +where + S: Send + Sync, +{ + type Rejection = ExtractorError; + + async fn from_request_parts(parts: &mut Parts, _: &S) -> Result { + parts + .extensions + .get::() + .cloned() + .ok_or(ExtractorError::Unauthorized) } } diff --git a/src/lib.rs b/src/lib.rs index 8458314..9eb2551 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,32 +1,38 @@ +#![deny(unsafe_code)] +#![deny(clippy::unwrap_used)] +#![deny(warnings)] #![doc = include_str!("../README.md")] -use std::str::FromStr; - use crate::error::Error; use http::Uri; use openidconnect::{ core::{ - CoreAuthDisplay, CoreAuthPrompt, CoreErrorResponseType, CoreGenderClaim, CoreJsonWebKey, - CoreJsonWebKeyType, CoreJsonWebKeyUse, CoreJweContentEncryptionAlgorithm, - CoreJwsSigningAlgorithm, CoreProviderMetadata, CoreRevocableToken, - CoreRevocationErrorResponse, CoreTokenIntrospectionResponse, CoreTokenType, + CoreAuthDisplay, CoreAuthPrompt, CoreClaimName, CoreClaimType, CoreClientAuthMethod, + CoreErrorResponseType, CoreGenderClaim, CoreGrantType, CoreJsonWebKey, CoreJsonWebKeyType, + CoreJsonWebKeyUse, CoreJweContentEncryptionAlgorithm, CoreJweKeyManagementAlgorithm, + CoreJwsSigningAlgorithm, CoreResponseMode, CoreResponseType, CoreRevocableToken, + CoreRevocationErrorResponse, CoreSubjectIdentifierType, CoreTokenIntrospectionResponse, + CoreTokenType, }, reqwest::async_http_client, - ClientId, ClientSecret, CsrfToken, EmptyExtraTokenFields, IdTokenFields, IssuerUrl, Nonce, - PkceCodeVerifier, RefreshToken, StandardErrorResponse, StandardTokenResponse, + AccessToken, ClientId, ClientSecret, CsrfToken, EmptyExtraTokenFields, IdTokenFields, + IssuerUrl, Nonce, PkceCodeVerifier, RefreshToken, StandardErrorResponse, StandardTokenResponse, }; -use serde::{Deserialize, Serialize}; +use serde::{de::DeserializeOwned, Deserialize, Serialize}; pub mod error; mod extractor; mod middleware; -pub use extractor::{OidcAccessToken, OidcClaims}; +pub use extractor::{OidcAccessToken, OidcClaims, OidcRpInitiatedLogout}; pub use middleware::{OidcAuthLayer, OidcAuthMiddleware, OidcLoginLayer, OidcLoginMiddleware}; const SESSION_KEY: &str = "axum-oidc"; -pub trait AdditionalClaims: openidconnect::AdditionalClaims + Clone + Sync + Send {} +pub trait AdditionalClaims: + openidconnect::AdditionalClaims + Clone + Sync + Send + Serialize + DeserializeOwned +{ +} type OidcTokenResponse = StandardTokenResponse< IdTokenFields< @@ -66,14 +72,34 @@ type Client = openidconnect::Client< CoreRevocationErrorResponse, >; +type ProviderMetadata = openidconnect::ProviderMetadata< + AdditionalProviderMetadata, + CoreAuthDisplay, + CoreClientAuthMethod, + CoreClaimName, + CoreClaimType, + CoreGrantType, + CoreJweContentEncryptionAlgorithm, + CoreJweKeyManagementAlgorithm, + CoreJwsSigningAlgorithm, + CoreJsonWebKeyType, + CoreJsonWebKeyUse, + CoreJsonWebKey, + CoreResponseMode, + CoreResponseType, + CoreSubjectIdentifierType, +>; + pub type BoxError = Box; /// OpenID Connect Client #[derive(Clone)] pub struct OidcClient { scopes: Vec, + client_id: String, client: Client, application_base_url: Uri, + end_session_endpoint: Option, } impl OidcClient { @@ -85,17 +111,25 @@ impl OidcClient { scopes: Vec, ) -> Result { let provider_metadata = - CoreProviderMetadata::discover_async(IssuerUrl::new(issuer)?, async_http_client) - .await?; + ProviderMetadata::discover_async(IssuerUrl::new(issuer)?, async_http_client).await?; + let end_session_endpoint = provider_metadata + .additional_metadata() + .end_session_endpoint + .clone() + .map(Uri::from_maybe_shared) + .transpose() + .map_err(Error::InvalidEndSessionEndpoint)?; let client = Client::from_provider_metadata( provider_metadata, - ClientId::new(client_id), + ClientId::new(client_id.clone()), client_secret.map(ClientSecret::new), ); Ok(Self { scopes, client, + client_id, application_base_url, + end_session_endpoint, }) } } @@ -117,24 +151,26 @@ struct OidcQuery { /// oidc session #[derive(Serialize, Deserialize, Debug)] -struct OidcSession { +#[serde(bound = "AC: Serialize + DeserializeOwned")] +struct OidcSession { nonce: Nonce, csrf_token: CsrfToken, pkce_verifier: PkceCodeVerifier, - id_token: Option, - access_token: Option, - refresh_token: Option, + authenticated: Option>, } -impl OidcSession { - pub(crate) fn id_token(&self) -> Option> { - self.id_token - .as_ref() - .map(|x| IdToken::::from_str(x).unwrap()) - } - pub(crate) fn refresh_token(&self) -> Option { - self.refresh_token - .as_ref() - .map(|x| RefreshToken::new(x.to_string())) - } +#[derive(Serialize, Deserialize, Debug)] +#[serde(bound = "AC: Serialize + DeserializeOwned")] +struct AuthenticatedSession { + id_token: IdToken, + access_token: AccessToken, + refresh_token: Option, } + +/// additional metadata that is discovered on client creation via the +/// `.well-knwon/openid-configuration` endpoint. +#[derive(Debug, Clone, Serialize, Deserialize)] +struct AdditionalProviderMetadata { + end_session_endpoint: Option, +} +impl openidconnect::AdditionalProviderMetadata for AdditionalProviderMetadata {} diff --git a/src/middleware.rs b/src/middleware.rs index 4be9446..b823e66 100644 --- a/src/middleware.rs +++ b/src/middleware.rs @@ -9,24 +9,25 @@ use axum::{ }; use axum_core::{extract::FromRequestParts, response::Response}; use futures_util::future::BoxFuture; -use http::{uri::PathAndQuery, Request, Uri}; +use http::{request::Parts, uri::PathAndQuery, Request, Uri}; use tower_layer::Layer; use tower_service::Service; use tower_sessions::Session; use openidconnect::{ - core::{CoreAuthenticationFlow, CoreErrorResponseType}, + core::{CoreAuthenticationFlow, CoreErrorResponseType, CoreGenderClaim}, reqwest::async_http_client, - AccessTokenHash, AuthorizationCode, CsrfToken, Nonce, OAuth2TokenResponse, PkceCodeChallenge, - PkceCodeVerifier, RedirectUrl, + AccessToken, AccessTokenHash, AuthorizationCode, CsrfToken, IdTokenClaims, Nonce, + OAuth2TokenResponse, PkceCodeChallenge, PkceCodeVerifier, RedirectUrl, RefreshToken, RequestTokenError::ServerResponse, Scope, TokenResponse, }; use crate::{ error::{Error, MiddlewareError}, - extractor::{OidcAccessToken, OidcClaims}, - AdditionalClaims, BoxError, OidcClient, OidcQuery, OidcSession, SESSION_KEY, + extractor::{OidcAccessToken, OidcClaims, OidcRpInitiatedLogout}, + AdditionalClaims, AuthenticatedSession, BoxError, IdToken, OidcClient, OidcQuery, OidcSession, + SESSION_KEY, }; /// Layer for the [OidcLoginMiddleware]. @@ -121,7 +122,7 @@ where .extensions .get::() .ok_or(MiddlewareError::SessionNotFound)?; - let login_session: Option = session + let login_session: Option> = session .get(SESSION_KEY) .await .map_err(MiddlewareError::from)?; @@ -158,26 +159,15 @@ where let claims = id_token .claims(&oidcclient.client.id_token_verifier(), &login_session.nonce)?; - // Verify the access token hash to ensure that the access token hasn't been substituted for - // another user's. - if let Some(expected_access_token_hash) = claims.access_token_hash() { - let actual_access_token_hash = AccessTokenHash::from_token( - token_response.access_token(), - &id_token.signing_alg()?, - )?; - if actual_access_token_hash != *expected_access_token_hash { - return Err(MiddlewareError::AccessTokenHashInvalid); - } - } + validate_access_token_hash(id_token, token_response.access_token(), claims)?; - login_session.id_token = Some(id_token.to_string()); - login_session.access_token = - Some(token_response.access_token().secret().to_string()); - login_session.refresh_token = token_response - .refresh_token() - .map(|x| x.secret().to_string()); + login_session.authenticated = Some(AuthenticatedSession { + id_token: id_token.clone(), + access_token: token_response.access_token().clone(), + refresh_token: token_response.refresh_token().cloned(), + }); - session.insert(SESSION_KEY, login_session).await.unwrap(); + session.insert(SESSION_KEY, login_session).await?; Ok(Redirect::temporary(&handler_uri.to_string()).into_response()) } else { @@ -198,16 +188,14 @@ where auth.set_pkce_challenge(pkce_challenge).url() }; - let oidc_session = OidcSession { + let oidc_session = OidcSession:: { nonce, csrf_token, pkce_verifier, - id_token: None, - access_token: None, - refresh_token: None, + authenticated: None, }; - session.insert(SESSION_KEY, oidc_session).await.unwrap(); + session.insert(SESSION_KEY, oidc_session).await?; Ok(Redirect::temporary(auth_url.as_str()).into_response()) } @@ -307,7 +295,7 @@ where .extensions .get::() .ok_or(MiddlewareError::SessionNotFound)?; - let mut login_session: Option = session + let mut login_session: Option> = session .get(SESSION_KEY) .await .map_err(MiddlewareError::from)?; @@ -320,88 +308,37 @@ where .set_redirect_uri(RedirectUrl::new(handler_uri.to_string())?); if let Some(login_session) = &mut login_session { - let id_token_claims = login_session.id_token::().and_then(|id_token| { - id_token + let id_token_claims = login_session.authenticated.as_ref().and_then(|session| { + session + .id_token .claims(&oidcclient.client.id_token_verifier(), &login_session.nonce) .ok() .cloned() + .map(|claims| (session, claims)) }); - match (id_token_claims, login_session.refresh_token()) { + if let Some((session, claims)) = id_token_claims { // stored id token is valid and can be used - (Some(claims), _) => { - parts.extensions.insert(OidcClaims(claims)); - parts.extensions.insert(OidcAccessToken( - login_session.access_token.clone().unwrap_or_default(), - )); - } - // stored id token is invalid and can't be uses, but we have a refresh token - // and can use it and try to get another id token. - (_, Some(refresh_token)) => { - let mut refresh_request = - oidcclient.client.exchange_refresh_token(&refresh_token); + insert_extensions(&mut parts, claims.clone(), &oidcclient, session); + } else if let Some(refresh_token) = login_session + .authenticated + .as_ref() + .and_then(|x| x.refresh_token.as_ref()) + { + if let Some((claims, authenticated_session)) = + try_refresh_token(&oidcclient, refresh_token, &login_session.nonce).await? + { + insert_extensions(&mut parts, claims, &oidcclient, &authenticated_session); + login_session.authenticated = Some(authenticated_session); + }; - for scope in oidcclient.scopes.iter() { - refresh_request = - refresh_request.add_scope(Scope::new(scope.to_string())); - } + // save refreshed session or delete it when the token couldn't be refreshed + let session = parts + .extensions + .get::() + .ok_or(MiddlewareError::SessionNotFound)?; - match refresh_request.request_async(async_http_client).await { - Ok(token_response) => { - // Extract the ID token claims after verifying its authenticity and nonce. - let id_token = token_response - .id_token() - .ok_or(MiddlewareError::IdTokenMissing)?; - let claims = id_token.claims( - &oidcclient.client.id_token_verifier(), - &login_session.nonce, - )?; - - // Verify the access token hash to ensure that the access token hasn't been substituted for - // another user's. - if let Some(expected_access_token_hash) = claims.access_token_hash() - { - let actual_access_token_hash = AccessTokenHash::from_token( - token_response.access_token(), - &id_token.signing_alg()?, - )?; - if actual_access_token_hash != *expected_access_token_hash { - return Err(MiddlewareError::AccessTokenHashInvalid); - } - } - - login_session.id_token = Some(id_token.to_string()); - login_session.access_token = - Some(token_response.access_token().secret().to_string()); - login_session.refresh_token = token_response - .refresh_token() - .map(|x| x.secret().to_string()); - - parts.extensions.insert(OidcClaims(claims.clone())); - parts.extensions.insert(OidcAccessToken( - login_session.access_token.clone().unwrap_or_default(), - )); - } - Err(ServerResponse(e)) - if *e.error() == CoreErrorResponseType::InvalidGrant => - { - // Refresh failed, refresh_token most likely expired or - // invalid, the session can be considered lost - login_session.refresh_token = None; - } - Err(err) => { - return Err(err.into()); - } - }; - - let session = parts - .extensions - .get::() - .ok_or(MiddlewareError::SessionNotFound)?; - - session.insert(SESSION_KEY, login_session).await.unwrap(); - } - (None, None) => {} + session.insert(SESSION_KEY, login_session).await?; } } @@ -448,3 +385,85 @@ pub fn strip_oidc_from_path(base_url: Uri, uri: &Uri) -> Result( + parts: &mut Parts, + claims: IdTokenClaims, + client: &OidcClient, + authenticated_session: &AuthenticatedSession, +) { + parts.extensions.insert(OidcClaims(claims)); + parts.extensions.insert(OidcAccessToken( + authenticated_session.access_token.secret().to_string(), + )); + if let Some(end_session_endpoint) = &client.end_session_endpoint { + parts.extensions.insert(OidcRpInitiatedLogout { + end_session_endpoint: end_session_endpoint.clone(), + id_token_hint: authenticated_session.id_token.to_string(), + client_id: client.client_id.clone(), + post_logout_redirect_uri: None, + state: None, + }); + } +} + +/// Verify the access token hash to ensure that the access token hasn't been substituted for +/// another user's. +/// Returns `Ok` when access token is valid +fn validate_access_token_hash( + id_token: &IdToken, + access_token: &AccessToken, + claims: &IdTokenClaims, +) -> Result<(), MiddlewareError> { + if let Some(expected_access_token_hash) = claims.access_token_hash() { + let actual_access_token_hash = + AccessTokenHash::from_token(access_token, &id_token.signing_alg()?)?; + if actual_access_token_hash == *expected_access_token_hash { + Ok(()) + } else { + Err(MiddlewareError::AccessTokenHashInvalid) + } + } else { + Ok(()) + } +} + +async fn try_refresh_token( + client: &OidcClient, + refresh_token: &RefreshToken, + nonce: &Nonce, +) -> Result, AuthenticatedSession)>, MiddlewareError> +{ + let mut refresh_request = client.client.exchange_refresh_token(refresh_token); + + for scope in client.scopes.iter() { + refresh_request = refresh_request.add_scope(Scope::new(scope.to_string())); + } + + match refresh_request.request_async(async_http_client).await { + Ok(token_response) => { + // Extract the ID token claims after verifying its authenticity and nonce. + let id_token = token_response + .id_token() + .ok_or(MiddlewareError::IdTokenMissing)?; + let claims = id_token.claims(&client.client.id_token_verifier(), nonce)?; + + validate_access_token_hash(id_token, token_response.access_token(), claims)?; + + let authenticated_session = AuthenticatedSession { + id_token: id_token.clone(), + access_token: token_response.access_token().clone(), + refresh_token: token_response.refresh_token().cloned(), + }; + + Ok(Some((claims.clone(), authenticated_session))) + } + Err(ServerResponse(e)) if *e.error() == CoreErrorResponseType::InvalidGrant => { + // Refresh failed, refresh_token most likely expired or + // invalid, the session can be considered lost + Ok(None) + } + Err(err) => Err(err.into()), + } +}