From 6528a6f247658b4b76d2046893bc50b11055c083 Mon Sep 17 00:00:00 2001 From: Paul Zinselmeyer Date: Tue, 26 Mar 2024 21:06:50 +0100 Subject: [PATCH] Cleanup of RP-Initiated Logout Added comments Removed unwraps Reworked Session container and middlewares --- README.md | 2 + examples/basic/src/main.rs | 12 +- src/extractor.rs | 35 +++--- src/lib.rs | 48 ++++---- src/middleware.rs | 237 +++++++++++++++++++------------------ 5 files changed, 175 insertions(+), 159 deletions(-) diff --git a/README.md b/README.md index 11a3fe7..9dd061a 100644 --- a/README.md +++ b/README.md @@ -13,6 +13,8 @@ The extractors will always return a value. The `OidcClaims`-extractor can be used to get the OpenId Conenct Claims. The `OidcAccessToken`-extractor can be used to get the OpenId Connect Access Token. +The `OidcRpInitializedLogout`-extractor can be used to get the rp initialized logout uri. + Your OIDC-Client must be allowed to redirect to **every** subpath of your application base url. # Examples diff --git a/examples/basic/src/main.rs b/examples/basic/src/main.rs index 1ece2be..da5165f 100644 --- a/examples/basic/src/main.rs +++ b/examples/basic/src/main.rs @@ -1,5 +1,9 @@ use axum::{ - error_handling::HandleErrorLayer, http::Uri, response::IntoResponse, routing::get, Router, + error_handling::HandleErrorLayer, + http::Uri, + response::{IntoResponse, Redirect}, + routing::get, + Router, }; use axum_oidc::{ error::MiddlewareError, EmptyAdditionalClaims, OidcAuthLayer, OidcClaims, OidcLoginLayer, @@ -80,5 +84,9 @@ async fn maybe_authenticated( } async fn logout(logout: OidcRpInitiatedLogout) -> impl IntoResponse { - logout.with_post_logout_redirect(Uri::from_static("https://google.de")) + let logout_uri = logout + .with_post_logout_redirect(Uri::from_static("https://pfzetto.de")) + .uri() + .unwrap(); + Redirect::temporary(&logout_uri.to_string()) } diff --git a/src/extractor.rs b/src/extractor.rs index 2161ee3..bf477c8 100644 --- a/src/extractor.rs +++ b/src/extractor.rs @@ -2,11 +2,7 @@ use std::{borrow::Cow, ops::Deref}; use crate::{error::ExtractorError, AdditionalClaims}; use async_trait::async_trait; -use axum::response::Redirect; -use axum_core::{ - extract::FromRequestParts, - response::{IntoResponse, Response}, -}; +use axum_core::extract::FromRequestParts; use http::{request::Parts, uri::PathAndQuery, Uri}; use openidconnect::{core::CoreGenderClaim, IdTokenClaims}; @@ -82,10 +78,13 @@ impl Deref for OidcAccessToken { impl AsRef for OidcAccessToken { fn as_ref(&self) -> &str { - self.0.as_str() + &self.0 } } +/// Extractor for the [OpenID Connect RP-Initialized Logout](https://openid.net/specs/openid-connect-rpinitiated-1_0.html) URL +/// +/// This Extractor will only succed when the cached session is valid, [crate::middleware::OidcAuthMiddleware] is loaded and the issuer supports RP-Initialized Logout. #[derive(Clone)] pub struct OidcRpInitiatedLogout { pub(crate) end_session_endpoint: Uri, @@ -96,19 +95,23 @@ pub struct OidcRpInitiatedLogout { } impl OidcRpInitiatedLogout { + /// set uri that the user is redirected to after logout. + /// This uri must be in the allowed by issuer. pub fn with_post_logout_redirect(mut self, uri: Uri) -> Self { self.post_logout_redirect_uri = Some(uri); self } + /// set the state parameter that is appended as a query to the post logout redirect uri. pub fn with_state(mut self, state: String) -> Self { self.state = Some(state); self } - pub fn uri(self) -> Uri { - let mut parts = self.end_session_endpoint.into_parts(); + /// get the uri that the client needs to access for logout + pub fn uri(&self) -> Result { + let mut parts = self.end_session_endpoint.clone().into_parts(); let query = { - let mut query = Vec::with_capacity(4); + let mut query: Vec<(&str, Cow<'_, str>)> = Vec::with_capacity(4); query.push(("id_token_hint", Cow::Borrowed(&self.id_token_hint))); query.push(("client_id", Cow::Borrowed(&self.client_id))); @@ -124,7 +127,7 @@ impl OidcRpInitiatedLogout { query .into_iter() - .map(|(k, v)| format!("{}={}", k, urlencoding::encode(v.as_str()))) + .map(|(k, v)| format!("{}={}", k, urlencoding::encode(&v))) .collect::>() .join("&") }; @@ -135,11 +138,12 @@ impl OidcRpInitiatedLogout { } None => PathAndQuery::from_maybe_shared(format!("?{}", query)), }; - parts.path_and_query = Some(path_and_query.unwrap()); + parts.path_and_query = Some(path_and_query?); - Uri::from_parts(parts).unwrap() + Ok(Uri::from_parts(parts)?) } } + #[async_trait] impl FromRequestParts for OidcRpInitiatedLogout where @@ -155,10 +159,3 @@ where .ok_or(ExtractorError::Unauthorized) } } - -#[async_trait] -impl IntoResponse for OidcRpInitiatedLogout { - fn into_response(self) -> Response { - Redirect::temporary(&self.uri().to_string()).into_response() - } -} diff --git a/src/lib.rs b/src/lib.rs index 29fa9a4..9eb2551 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,7 +1,8 @@ +#![deny(unsafe_code)] +#![deny(clippy::unwrap_used)] +#![deny(warnings)] #![doc = include_str!("../README.md")] -use std::str::FromStr; - use crate::error::Error; use http::Uri; use openidconnect::{ @@ -9,15 +10,15 @@ use openidconnect::{ CoreAuthDisplay, CoreAuthPrompt, CoreClaimName, CoreClaimType, CoreClientAuthMethod, CoreErrorResponseType, CoreGenderClaim, CoreGrantType, CoreJsonWebKey, CoreJsonWebKeyType, CoreJsonWebKeyUse, CoreJweContentEncryptionAlgorithm, CoreJweKeyManagementAlgorithm, - CoreJwsSigningAlgorithm, CoreProviderMetadata, CoreResponseMode, CoreResponseType, - CoreRevocableToken, CoreRevocationErrorResponse, CoreSubjectIdentifierType, - CoreTokenIntrospectionResponse, CoreTokenType, + CoreJwsSigningAlgorithm, CoreResponseMode, CoreResponseType, CoreRevocableToken, + CoreRevocationErrorResponse, CoreSubjectIdentifierType, CoreTokenIntrospectionResponse, + CoreTokenType, }, reqwest::async_http_client, - ClientId, ClientSecret, CsrfToken, EmptyExtraTokenFields, IdTokenFields, IssuerUrl, Nonce, - PkceCodeVerifier, RefreshToken, StandardErrorResponse, StandardTokenResponse, + AccessToken, ClientId, ClientSecret, CsrfToken, EmptyExtraTokenFields, IdTokenFields, + IssuerUrl, Nonce, PkceCodeVerifier, RefreshToken, StandardErrorResponse, StandardTokenResponse, }; -use serde::{Deserialize, Serialize}; +use serde::{de::DeserializeOwned, Deserialize, Serialize}; pub mod error; mod extractor; @@ -28,7 +29,10 @@ pub use middleware::{OidcAuthLayer, OidcAuthMiddleware, OidcLoginLayer, OidcLogi const SESSION_KEY: &str = "axum-oidc"; -pub trait AdditionalClaims: openidconnect::AdditionalClaims + Clone + Sync + Send {} +pub trait AdditionalClaims: + openidconnect::AdditionalClaims + Clone + Sync + Send + Serialize + DeserializeOwned +{ +} type OidcTokenResponse = StandardTokenResponse< IdTokenFields< @@ -147,28 +151,24 @@ struct OidcQuery { /// oidc session #[derive(Serialize, Deserialize, Debug)] -struct OidcSession { +#[serde(bound = "AC: Serialize + DeserializeOwned")] +struct OidcSession { nonce: Nonce, csrf_token: CsrfToken, pkce_verifier: PkceCodeVerifier, - id_token: Option, - access_token: Option, - refresh_token: Option, + authenticated: Option>, } -impl OidcSession { - pub(crate) fn id_token(&self) -> Option> { - self.id_token - .as_ref() - .map(|x| IdToken::::from_str(x).unwrap()) - } - pub(crate) fn refresh_token(&self) -> Option { - self.refresh_token - .as_ref() - .map(|x| RefreshToken::new(x.to_string())) - } +#[derive(Serialize, Deserialize, Debug)] +#[serde(bound = "AC: Serialize + DeserializeOwned")] +struct AuthenticatedSession { + id_token: IdToken, + access_token: AccessToken, + refresh_token: Option, } +/// additional metadata that is discovered on client creation via the +/// `.well-knwon/openid-configuration` endpoint. #[derive(Debug, Clone, Serialize, Deserialize)] struct AdditionalProviderMetadata { end_session_endpoint: Option, diff --git a/src/middleware.rs b/src/middleware.rs index 280aae0..b823e66 100644 --- a/src/middleware.rs +++ b/src/middleware.rs @@ -9,16 +9,16 @@ use axum::{ }; use axum_core::{extract::FromRequestParts, response::Response}; use futures_util::future::BoxFuture; -use http::{uri::PathAndQuery, Request, Uri}; +use http::{request::Parts, uri::PathAndQuery, Request, Uri}; use tower_layer::Layer; use tower_service::Service; use tower_sessions::Session; use openidconnect::{ - core::{CoreAuthenticationFlow, CoreErrorResponseType}, + core::{CoreAuthenticationFlow, CoreErrorResponseType, CoreGenderClaim}, reqwest::async_http_client, - AccessTokenHash, AuthorizationCode, CsrfToken, Nonce, OAuth2TokenResponse, PkceCodeChallenge, - PkceCodeVerifier, RedirectUrl, + AccessToken, AccessTokenHash, AuthorizationCode, CsrfToken, IdTokenClaims, Nonce, + OAuth2TokenResponse, PkceCodeChallenge, PkceCodeVerifier, RedirectUrl, RefreshToken, RequestTokenError::ServerResponse, Scope, TokenResponse, }; @@ -26,7 +26,8 @@ use openidconnect::{ use crate::{ error::{Error, MiddlewareError}, extractor::{OidcAccessToken, OidcClaims, OidcRpInitiatedLogout}, - AdditionalClaims, BoxError, OidcClient, OidcQuery, OidcSession, SESSION_KEY, + AdditionalClaims, AuthenticatedSession, BoxError, IdToken, OidcClient, OidcQuery, OidcSession, + SESSION_KEY, }; /// Layer for the [OidcLoginMiddleware]. @@ -121,7 +122,7 @@ where .extensions .get::() .ok_or(MiddlewareError::SessionNotFound)?; - let login_session: Option = session + let login_session: Option> = session .get(SESSION_KEY) .await .map_err(MiddlewareError::from)?; @@ -158,26 +159,15 @@ where let claims = id_token .claims(&oidcclient.client.id_token_verifier(), &login_session.nonce)?; - // Verify the access token hash to ensure that the access token hasn't been substituted for - // another user's. - if let Some(expected_access_token_hash) = claims.access_token_hash() { - let actual_access_token_hash = AccessTokenHash::from_token( - token_response.access_token(), - &id_token.signing_alg()?, - )?; - if actual_access_token_hash != *expected_access_token_hash { - return Err(MiddlewareError::AccessTokenHashInvalid); - } - } + validate_access_token_hash(id_token, token_response.access_token(), claims)?; - login_session.id_token = Some(id_token.to_string()); - login_session.access_token = - Some(token_response.access_token().secret().to_string()); - login_session.refresh_token = token_response - .refresh_token() - .map(|x| x.secret().to_string()); + login_session.authenticated = Some(AuthenticatedSession { + id_token: id_token.clone(), + access_token: token_response.access_token().clone(), + refresh_token: token_response.refresh_token().cloned(), + }); - session.insert(SESSION_KEY, login_session).await.unwrap(); + session.insert(SESSION_KEY, login_session).await?; Ok(Redirect::temporary(&handler_uri.to_string()).into_response()) } else { @@ -198,16 +188,14 @@ where auth.set_pkce_challenge(pkce_challenge).url() }; - let oidc_session = OidcSession { + let oidc_session = OidcSession:: { nonce, csrf_token, pkce_verifier, - id_token: None, - access_token: None, - refresh_token: None, + authenticated: None, }; - session.insert(SESSION_KEY, oidc_session).await.unwrap(); + session.insert(SESSION_KEY, oidc_session).await?; Ok(Redirect::temporary(auth_url.as_str()).into_response()) } @@ -307,7 +295,7 @@ where .extensions .get::() .ok_or(MiddlewareError::SessionNotFound)?; - let mut login_session: Option = session + let mut login_session: Option> = session .get(SESSION_KEY) .await .map_err(MiddlewareError::from)?; @@ -320,98 +308,37 @@ where .set_redirect_uri(RedirectUrl::new(handler_uri.to_string())?); if let Some(login_session) = &mut login_session { - let id_token_claims = login_session.id_token::().and_then(|id_token| { - id_token + let id_token_claims = login_session.authenticated.as_ref().and_then(|session| { + session + .id_token .claims(&oidcclient.client.id_token_verifier(), &login_session.nonce) .ok() .cloned() + .map(|claims| (session, claims)) }); - match (id_token_claims, login_session.refresh_token()) { + if let Some((session, claims)) = id_token_claims { // stored id token is valid and can be used - (Some(claims), _) => { - parts.extensions.insert(OidcClaims(claims)); - parts.extensions.insert(OidcAccessToken( - login_session.access_token.clone().unwrap_or_default(), - )); - if let Some(end_session_endpoint) = oidcclient.end_session_endpoint.clone() - { - parts.extensions.insert(OidcRpInitiatedLogout { - end_session_endpoint, - id_token_hint: login_session.id_token.clone().unwrap(), - client_id: oidcclient.client_id.clone(), - post_logout_redirect_uri: None, - state: None, - }); - } - } - // stored id token is invalid and can't be uses, but we have a refresh token - // and can use it and try to get another id token. - (_, Some(refresh_token)) => { - let mut refresh_request = - oidcclient.client.exchange_refresh_token(&refresh_token); + insert_extensions(&mut parts, claims.clone(), &oidcclient, session); + } else if let Some(refresh_token) = login_session + .authenticated + .as_ref() + .and_then(|x| x.refresh_token.as_ref()) + { + if let Some((claims, authenticated_session)) = + try_refresh_token(&oidcclient, refresh_token, &login_session.nonce).await? + { + insert_extensions(&mut parts, claims, &oidcclient, &authenticated_session); + login_session.authenticated = Some(authenticated_session); + }; - for scope in oidcclient.scopes.iter() { - refresh_request = - refresh_request.add_scope(Scope::new(scope.to_string())); - } + // save refreshed session or delete it when the token couldn't be refreshed + let session = parts + .extensions + .get::() + .ok_or(MiddlewareError::SessionNotFound)?; - match refresh_request.request_async(async_http_client).await { - Ok(token_response) => { - // Extract the ID token claims after verifying its authenticity and nonce. - let id_token = token_response - .id_token() - .ok_or(MiddlewareError::IdTokenMissing)?; - let claims = id_token.claims( - &oidcclient.client.id_token_verifier(), - &login_session.nonce, - )?; - - // Verify the access token hash to ensure that the access token hasn't been substituted for - // another user's. - if let Some(expected_access_token_hash) = claims.access_token_hash() - { - let actual_access_token_hash = AccessTokenHash::from_token( - token_response.access_token(), - &id_token.signing_alg()?, - )?; - if actual_access_token_hash != *expected_access_token_hash { - return Err(MiddlewareError::AccessTokenHashInvalid); - } - } - - login_session.id_token = Some(id_token.to_string()); - login_session.access_token = - Some(token_response.access_token().secret().to_string()); - login_session.refresh_token = token_response - .refresh_token() - .map(|x| x.secret().to_string()); - - parts.extensions.insert(OidcClaims(claims.clone())); - parts.extensions.insert(OidcAccessToken( - login_session.access_token.clone().unwrap_or_default(), - )); - } - Err(ServerResponse(e)) - if *e.error() == CoreErrorResponseType::InvalidGrant => - { - // Refresh failed, refresh_token most likely expired or - // invalid, the session can be considered lost - login_session.refresh_token = None; - } - Err(err) => { - return Err(err.into()); - } - }; - - let session = parts - .extensions - .get::() - .ok_or(MiddlewareError::SessionNotFound)?; - - session.insert(SESSION_KEY, login_session).await.unwrap(); - } - (None, None) => {} + session.insert(SESSION_KEY, login_session).await?; } } @@ -458,3 +385,85 @@ pub fn strip_oidc_from_path(base_url: Uri, uri: &Uri) -> Result( + parts: &mut Parts, + claims: IdTokenClaims, + client: &OidcClient, + authenticated_session: &AuthenticatedSession, +) { + parts.extensions.insert(OidcClaims(claims)); + parts.extensions.insert(OidcAccessToken( + authenticated_session.access_token.secret().to_string(), + )); + if let Some(end_session_endpoint) = &client.end_session_endpoint { + parts.extensions.insert(OidcRpInitiatedLogout { + end_session_endpoint: end_session_endpoint.clone(), + id_token_hint: authenticated_session.id_token.to_string(), + client_id: client.client_id.clone(), + post_logout_redirect_uri: None, + state: None, + }); + } +} + +/// Verify the access token hash to ensure that the access token hasn't been substituted for +/// another user's. +/// Returns `Ok` when access token is valid +fn validate_access_token_hash( + id_token: &IdToken, + access_token: &AccessToken, + claims: &IdTokenClaims, +) -> Result<(), MiddlewareError> { + if let Some(expected_access_token_hash) = claims.access_token_hash() { + let actual_access_token_hash = + AccessTokenHash::from_token(access_token, &id_token.signing_alg()?)?; + if actual_access_token_hash == *expected_access_token_hash { + Ok(()) + } else { + Err(MiddlewareError::AccessTokenHashInvalid) + } + } else { + Ok(()) + } +} + +async fn try_refresh_token( + client: &OidcClient, + refresh_token: &RefreshToken, + nonce: &Nonce, +) -> Result, AuthenticatedSession)>, MiddlewareError> +{ + let mut refresh_request = client.client.exchange_refresh_token(refresh_token); + + for scope in client.scopes.iter() { + refresh_request = refresh_request.add_scope(Scope::new(scope.to_string())); + } + + match refresh_request.request_async(async_http_client).await { + Ok(token_response) => { + // Extract the ID token claims after verifying its authenticity and nonce. + let id_token = token_response + .id_token() + .ok_or(MiddlewareError::IdTokenMissing)?; + let claims = id_token.claims(&client.client.id_token_verifier(), nonce)?; + + validate_access_token_hash(id_token, token_response.access_token(), claims)?; + + let authenticated_session = AuthenticatedSession { + id_token: id_token.clone(), + access_token: token_response.access_token().clone(), + refresh_token: token_response.refresh_token().cloned(), + }; + + Ok(Some((claims.clone(), authenticated_session))) + } + Err(ServerResponse(e)) if *e.error() == CoreErrorResponseType::InvalidGrant => { + // Refresh failed, refresh_token most likely expired or + // invalid, the session can be considered lost + Ok(None) + } + Err(err) => Err(err.into()), + } +}