From 1844b880c165b474f5f33d5b385fc4bf2f8db688 Mon Sep 17 00:00:00 2001 From: Paul Zinselmeyer Date: Mon, 25 Mar 2024 17:20:44 +0100 Subject: [PATCH 1/2] Added first implementation of RP Initiated Logout Created a new extractor for RP-Initiated-Logout and modified example to use it. --- .gitignore | 1 + Cargo.toml | 1 + examples/basic/Cargo.toml | 2 + examples/basic/src/main.rs | 20 +++++++-- src/error.rs | 13 +++++- src/extractor.rs | 87 ++++++++++++++++++++++++++++++++++++-- src/lib.rs | 52 +++++++++++++++++++---- src/middleware.rs | 12 +++++- 8 files changed, 171 insertions(+), 17 deletions(-) diff --git a/.gitignore b/.gitignore index a9d37c5..e08f5fc 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ target Cargo.lock +.env diff --git a/Cargo.toml b/Cargo.toml index 0e19f01..fbcbb81 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,3 +24,4 @@ openidconnect = "3.5" serde = "1.0" futures-util = "0.3" reqwest = { version = "0.11", default-features = false } +urlencoding = "2.1.3" diff --git a/examples/basic/Cargo.toml b/examples/basic/Cargo.toml index bb54e8d..fcd75e1 100644 --- a/examples/basic/Cargo.toml +++ b/examples/basic/Cargo.toml @@ -11,3 +11,5 @@ axum = "0.7.4" axum-oidc = { path = "./../.." } tower = "0.4.13" tower-sessions = "0.11.0" + +dotenvy = "0.15.7" diff --git a/examples/basic/src/main.rs b/examples/basic/src/main.rs index 4da1acf..1ece2be 100644 --- a/examples/basic/src/main.rs +++ b/examples/basic/src/main.rs @@ -3,6 +3,7 @@ use axum::{ }; use axum_oidc::{ error::MiddlewareError, EmptyAdditionalClaims, OidcAuthLayer, OidcClaims, OidcLoginLayer, + OidcRpInitiatedLogout, }; use tokio::net::TcpListener; use tower::ServiceBuilder; @@ -13,6 +14,12 @@ use tower_sessions::{ #[tokio::main] async fn main() { + dotenvy::dotenv().ok(); + let app_url = std::env::var("APP_URL").expect("APP_URL env variable"); + let issuer = std::env::var("ISSUER").expect("ISSUER env variable"); + let client_id = std::env::var("CLIENT_ID").expect("CLIENT_ID env variable"); + let client_secret = std::env::var("CLIENT_SECRET").ok(); + let session_store = MemoryStore::default(); let session_layer = SessionManagerLayer::new(session_store) .with_secure(false) @@ -31,10 +38,10 @@ async fn main() { })) .layer( OidcAuthLayer::::discover_client( - Uri::from_static("https://app.example.com"), - "https://auth.example.com/auth/realms/example".to_string(), - "my-client".to_string(), - Some("123456".to_owned()), + Uri::from_maybe_shared(app_url).expect("valid APP_URL"), + issuer, + client_id, + client_secret, vec![], ) .await @@ -43,6 +50,7 @@ async fn main() { let app = Router::new() .route("/foo", get(authenticated)) + .route("/logout", get(logout)) .layer(oidc_login_service) .route("/bar", get(maybe_authenticated)) .layer(oidc_auth_service) @@ -70,3 +78,7 @@ async fn maybe_authenticated( "Hello anon!".to_string() } } + +async fn logout(logout: OidcRpInitiatedLogout) -> impl IntoResponse { + logout.with_post_logout_redirect(Uri::from_static("https://google.de")) +} diff --git a/src/error.rs b/src/error.rs index c91ea66..6d8997e 100644 --- a/src/error.rs +++ b/src/error.rs @@ -10,6 +10,9 @@ use thiserror::Error; pub enum ExtractorError { #[error("unauthorized")] Unauthorized, + + #[error("rp initiated logout information not found")] + RpInitiatedLogoutInformationNotFound, } #[derive(Debug, Error)] @@ -65,6 +68,9 @@ pub enum Error { #[error("url parsing: {0:?}")] UrlParsing(#[from] openidconnect::url::ParseError), + #[error("invalid end_session_endpoint uri: {0:?}")] + InvalidEndSessionEndpoint(http::uri::InvalidUri), + #[error("discovery: {0:?}")] Discovery(#[from] openidconnect::DiscoveryError>), @@ -77,7 +83,12 @@ pub enum Error { impl IntoResponse for ExtractorError { fn into_response(self) -> axum_core::response::Response { - (StatusCode::UNAUTHORIZED, "unauthorized").into_response() + match self { + Self::Unauthorized => (StatusCode::UNAUTHORIZED, "unauthorized").into_response(), + Self::RpInitiatedLogoutInformationNotFound => { + (StatusCode::INTERNAL_SERVER_ERROR, "intenal server error").into_response() + } + } } } diff --git a/src/extractor.rs b/src/extractor.rs index 5c18d78..2161ee3 100644 --- a/src/extractor.rs +++ b/src/extractor.rs @@ -1,9 +1,13 @@ -use std::ops::Deref; +use std::{borrow::Cow, ops::Deref}; use crate::{error::ExtractorError, AdditionalClaims}; use async_trait::async_trait; -use axum_core::extract::FromRequestParts; -use http::request::Parts; +use axum::response::Redirect; +use axum_core::{ + extract::FromRequestParts, + response::{IntoResponse, Response}, +}; +use http::{request::Parts, uri::PathAndQuery, Uri}; use openidconnect::{core::CoreGenderClaim, IdTokenClaims}; /// Extractor for the OpenID Connect Claims. @@ -81,3 +85,80 @@ impl AsRef for OidcAccessToken { self.0.as_str() } } + +#[derive(Clone)] +pub struct OidcRpInitiatedLogout { + pub(crate) end_session_endpoint: Uri, + pub(crate) id_token_hint: String, + pub(crate) client_id: String, + pub(crate) post_logout_redirect_uri: Option, + pub(crate) state: Option, +} + +impl OidcRpInitiatedLogout { + pub fn with_post_logout_redirect(mut self, uri: Uri) -> Self { + self.post_logout_redirect_uri = Some(uri); + self + } + pub fn with_state(mut self, state: String) -> Self { + self.state = Some(state); + self + } + pub fn uri(self) -> Uri { + let mut parts = self.end_session_endpoint.into_parts(); + + let query = { + let mut query = Vec::with_capacity(4); + query.push(("id_token_hint", Cow::Borrowed(&self.id_token_hint))); + query.push(("client_id", Cow::Borrowed(&self.client_id))); + + if let Some(post_logout_redirect_uri) = &self.post_logout_redirect_uri { + query.push(( + "post_logout_redirect_uri", + Cow::Owned(post_logout_redirect_uri.to_string()), + )); + } + if let Some(state) = &self.state { + query.push(("state", Cow::Borrowed(state))); + } + + query + .into_iter() + .map(|(k, v)| format!("{}={}", k, urlencoding::encode(v.as_str()))) + .collect::>() + .join("&") + }; + + let path_and_query = match parts.path_and_query { + Some(path_and_query) => { + PathAndQuery::from_maybe_shared(format!("{}?{}", path_and_query.path(), query)) + } + None => PathAndQuery::from_maybe_shared(format!("?{}", query)), + }; + parts.path_and_query = Some(path_and_query.unwrap()); + + Uri::from_parts(parts).unwrap() + } +} +#[async_trait] +impl FromRequestParts for OidcRpInitiatedLogout +where + S: Send + Sync, +{ + type Rejection = ExtractorError; + + async fn from_request_parts(parts: &mut Parts, _: &S) -> Result { + parts + .extensions + .get::() + .cloned() + .ok_or(ExtractorError::Unauthorized) + } +} + +#[async_trait] +impl IntoResponse for OidcRpInitiatedLogout { + fn into_response(self) -> Response { + Redirect::temporary(&self.uri().to_string()).into_response() + } +} diff --git a/src/lib.rs b/src/lib.rs index 8458314..29fa9a4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -6,10 +6,12 @@ use crate::error::Error; use http::Uri; use openidconnect::{ core::{ - CoreAuthDisplay, CoreAuthPrompt, CoreErrorResponseType, CoreGenderClaim, CoreJsonWebKey, - CoreJsonWebKeyType, CoreJsonWebKeyUse, CoreJweContentEncryptionAlgorithm, - CoreJwsSigningAlgorithm, CoreProviderMetadata, CoreRevocableToken, - CoreRevocationErrorResponse, CoreTokenIntrospectionResponse, CoreTokenType, + CoreAuthDisplay, CoreAuthPrompt, CoreClaimName, CoreClaimType, CoreClientAuthMethod, + CoreErrorResponseType, CoreGenderClaim, CoreGrantType, CoreJsonWebKey, CoreJsonWebKeyType, + CoreJsonWebKeyUse, CoreJweContentEncryptionAlgorithm, CoreJweKeyManagementAlgorithm, + CoreJwsSigningAlgorithm, CoreProviderMetadata, CoreResponseMode, CoreResponseType, + CoreRevocableToken, CoreRevocationErrorResponse, CoreSubjectIdentifierType, + CoreTokenIntrospectionResponse, CoreTokenType, }, reqwest::async_http_client, ClientId, ClientSecret, CsrfToken, EmptyExtraTokenFields, IdTokenFields, IssuerUrl, Nonce, @@ -21,7 +23,7 @@ pub mod error; mod extractor; mod middleware; -pub use extractor::{OidcAccessToken, OidcClaims}; +pub use extractor::{OidcAccessToken, OidcClaims, OidcRpInitiatedLogout}; pub use middleware::{OidcAuthLayer, OidcAuthMiddleware, OidcLoginLayer, OidcLoginMiddleware}; const SESSION_KEY: &str = "axum-oidc"; @@ -66,14 +68,34 @@ type Client = openidconnect::Client< CoreRevocationErrorResponse, >; +type ProviderMetadata = openidconnect::ProviderMetadata< + AdditionalProviderMetadata, + CoreAuthDisplay, + CoreClientAuthMethod, + CoreClaimName, + CoreClaimType, + CoreGrantType, + CoreJweContentEncryptionAlgorithm, + CoreJweKeyManagementAlgorithm, + CoreJwsSigningAlgorithm, + CoreJsonWebKeyType, + CoreJsonWebKeyUse, + CoreJsonWebKey, + CoreResponseMode, + CoreResponseType, + CoreSubjectIdentifierType, +>; + pub type BoxError = Box; /// OpenID Connect Client #[derive(Clone)] pub struct OidcClient { scopes: Vec, + client_id: String, client: Client, application_base_url: Uri, + end_session_endpoint: Option, } impl OidcClient { @@ -85,17 +107,25 @@ impl OidcClient { scopes: Vec, ) -> Result { let provider_metadata = - CoreProviderMetadata::discover_async(IssuerUrl::new(issuer)?, async_http_client) - .await?; + ProviderMetadata::discover_async(IssuerUrl::new(issuer)?, async_http_client).await?; + let end_session_endpoint = provider_metadata + .additional_metadata() + .end_session_endpoint + .clone() + .map(Uri::from_maybe_shared) + .transpose() + .map_err(Error::InvalidEndSessionEndpoint)?; let client = Client::from_provider_metadata( provider_metadata, - ClientId::new(client_id), + ClientId::new(client_id.clone()), client_secret.map(ClientSecret::new), ); Ok(Self { scopes, client, + client_id, application_base_url, + end_session_endpoint, }) } } @@ -138,3 +168,9 @@ impl OidcSession { .map(|x| RefreshToken::new(x.to_string())) } } + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct AdditionalProviderMetadata { + end_session_endpoint: Option, +} +impl openidconnect::AdditionalProviderMetadata for AdditionalProviderMetadata {} diff --git a/src/middleware.rs b/src/middleware.rs index 4be9446..280aae0 100644 --- a/src/middleware.rs +++ b/src/middleware.rs @@ -25,7 +25,7 @@ use openidconnect::{ use crate::{ error::{Error, MiddlewareError}, - extractor::{OidcAccessToken, OidcClaims}, + extractor::{OidcAccessToken, OidcClaims, OidcRpInitiatedLogout}, AdditionalClaims, BoxError, OidcClient, OidcQuery, OidcSession, SESSION_KEY, }; @@ -334,6 +334,16 @@ where parts.extensions.insert(OidcAccessToken( login_session.access_token.clone().unwrap_or_default(), )); + if let Some(end_session_endpoint) = oidcclient.end_session_endpoint.clone() + { + parts.extensions.insert(OidcRpInitiatedLogout { + end_session_endpoint, + id_token_hint: login_session.id_token.clone().unwrap(), + client_id: oidcclient.client_id.clone(), + post_logout_redirect_uri: None, + state: None, + }); + } } // stored id token is invalid and can't be uses, but we have a refresh token // and can use it and try to get another id token. From 6528a6f247658b4b76d2046893bc50b11055c083 Mon Sep 17 00:00:00 2001 From: Paul Zinselmeyer Date: Tue, 26 Mar 2024 21:06:50 +0100 Subject: [PATCH 2/2] Cleanup of RP-Initiated Logout Added comments Removed unwraps Reworked Session container and middlewares --- README.md | 2 + examples/basic/src/main.rs | 12 +- src/extractor.rs | 35 +++--- src/lib.rs | 48 ++++---- src/middleware.rs | 237 +++++++++++++++++++------------------ 5 files changed, 175 insertions(+), 159 deletions(-) diff --git a/README.md b/README.md index 11a3fe7..9dd061a 100644 --- a/README.md +++ b/README.md @@ -13,6 +13,8 @@ The extractors will always return a value. The `OidcClaims`-extractor can be used to get the OpenId Conenct Claims. The `OidcAccessToken`-extractor can be used to get the OpenId Connect Access Token. +The `OidcRpInitializedLogout`-extractor can be used to get the rp initialized logout uri. + Your OIDC-Client must be allowed to redirect to **every** subpath of your application base url. # Examples diff --git a/examples/basic/src/main.rs b/examples/basic/src/main.rs index 1ece2be..da5165f 100644 --- a/examples/basic/src/main.rs +++ b/examples/basic/src/main.rs @@ -1,5 +1,9 @@ use axum::{ - error_handling::HandleErrorLayer, http::Uri, response::IntoResponse, routing::get, Router, + error_handling::HandleErrorLayer, + http::Uri, + response::{IntoResponse, Redirect}, + routing::get, + Router, }; use axum_oidc::{ error::MiddlewareError, EmptyAdditionalClaims, OidcAuthLayer, OidcClaims, OidcLoginLayer, @@ -80,5 +84,9 @@ async fn maybe_authenticated( } async fn logout(logout: OidcRpInitiatedLogout) -> impl IntoResponse { - logout.with_post_logout_redirect(Uri::from_static("https://google.de")) + let logout_uri = logout + .with_post_logout_redirect(Uri::from_static("https://pfzetto.de")) + .uri() + .unwrap(); + Redirect::temporary(&logout_uri.to_string()) } diff --git a/src/extractor.rs b/src/extractor.rs index 2161ee3..bf477c8 100644 --- a/src/extractor.rs +++ b/src/extractor.rs @@ -2,11 +2,7 @@ use std::{borrow::Cow, ops::Deref}; use crate::{error::ExtractorError, AdditionalClaims}; use async_trait::async_trait; -use axum::response::Redirect; -use axum_core::{ - extract::FromRequestParts, - response::{IntoResponse, Response}, -}; +use axum_core::extract::FromRequestParts; use http::{request::Parts, uri::PathAndQuery, Uri}; use openidconnect::{core::CoreGenderClaim, IdTokenClaims}; @@ -82,10 +78,13 @@ impl Deref for OidcAccessToken { impl AsRef for OidcAccessToken { fn as_ref(&self) -> &str { - self.0.as_str() + &self.0 } } +/// Extractor for the [OpenID Connect RP-Initialized Logout](https://openid.net/specs/openid-connect-rpinitiated-1_0.html) URL +/// +/// This Extractor will only succed when the cached session is valid, [crate::middleware::OidcAuthMiddleware] is loaded and the issuer supports RP-Initialized Logout. #[derive(Clone)] pub struct OidcRpInitiatedLogout { pub(crate) end_session_endpoint: Uri, @@ -96,19 +95,23 @@ pub struct OidcRpInitiatedLogout { } impl OidcRpInitiatedLogout { + /// set uri that the user is redirected to after logout. + /// This uri must be in the allowed by issuer. pub fn with_post_logout_redirect(mut self, uri: Uri) -> Self { self.post_logout_redirect_uri = Some(uri); self } + /// set the state parameter that is appended as a query to the post logout redirect uri. pub fn with_state(mut self, state: String) -> Self { self.state = Some(state); self } - pub fn uri(self) -> Uri { - let mut parts = self.end_session_endpoint.into_parts(); + /// get the uri that the client needs to access for logout + pub fn uri(&self) -> Result { + let mut parts = self.end_session_endpoint.clone().into_parts(); let query = { - let mut query = Vec::with_capacity(4); + let mut query: Vec<(&str, Cow<'_, str>)> = Vec::with_capacity(4); query.push(("id_token_hint", Cow::Borrowed(&self.id_token_hint))); query.push(("client_id", Cow::Borrowed(&self.client_id))); @@ -124,7 +127,7 @@ impl OidcRpInitiatedLogout { query .into_iter() - .map(|(k, v)| format!("{}={}", k, urlencoding::encode(v.as_str()))) + .map(|(k, v)| format!("{}={}", k, urlencoding::encode(&v))) .collect::>() .join("&") }; @@ -135,11 +138,12 @@ impl OidcRpInitiatedLogout { } None => PathAndQuery::from_maybe_shared(format!("?{}", query)), }; - parts.path_and_query = Some(path_and_query.unwrap()); + parts.path_and_query = Some(path_and_query?); - Uri::from_parts(parts).unwrap() + Ok(Uri::from_parts(parts)?) } } + #[async_trait] impl FromRequestParts for OidcRpInitiatedLogout where @@ -155,10 +159,3 @@ where .ok_or(ExtractorError::Unauthorized) } } - -#[async_trait] -impl IntoResponse for OidcRpInitiatedLogout { - fn into_response(self) -> Response { - Redirect::temporary(&self.uri().to_string()).into_response() - } -} diff --git a/src/lib.rs b/src/lib.rs index 29fa9a4..9eb2551 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,7 +1,8 @@ +#![deny(unsafe_code)] +#![deny(clippy::unwrap_used)] +#![deny(warnings)] #![doc = include_str!("../README.md")] -use std::str::FromStr; - use crate::error::Error; use http::Uri; use openidconnect::{ @@ -9,15 +10,15 @@ use openidconnect::{ CoreAuthDisplay, CoreAuthPrompt, CoreClaimName, CoreClaimType, CoreClientAuthMethod, CoreErrorResponseType, CoreGenderClaim, CoreGrantType, CoreJsonWebKey, CoreJsonWebKeyType, CoreJsonWebKeyUse, CoreJweContentEncryptionAlgorithm, CoreJweKeyManagementAlgorithm, - CoreJwsSigningAlgorithm, CoreProviderMetadata, CoreResponseMode, CoreResponseType, - CoreRevocableToken, CoreRevocationErrorResponse, CoreSubjectIdentifierType, - CoreTokenIntrospectionResponse, CoreTokenType, + CoreJwsSigningAlgorithm, CoreResponseMode, CoreResponseType, CoreRevocableToken, + CoreRevocationErrorResponse, CoreSubjectIdentifierType, CoreTokenIntrospectionResponse, + CoreTokenType, }, reqwest::async_http_client, - ClientId, ClientSecret, CsrfToken, EmptyExtraTokenFields, IdTokenFields, IssuerUrl, Nonce, - PkceCodeVerifier, RefreshToken, StandardErrorResponse, StandardTokenResponse, + AccessToken, ClientId, ClientSecret, CsrfToken, EmptyExtraTokenFields, IdTokenFields, + IssuerUrl, Nonce, PkceCodeVerifier, RefreshToken, StandardErrorResponse, StandardTokenResponse, }; -use serde::{Deserialize, Serialize}; +use serde::{de::DeserializeOwned, Deserialize, Serialize}; pub mod error; mod extractor; @@ -28,7 +29,10 @@ pub use middleware::{OidcAuthLayer, OidcAuthMiddleware, OidcLoginLayer, OidcLogi const SESSION_KEY: &str = "axum-oidc"; -pub trait AdditionalClaims: openidconnect::AdditionalClaims + Clone + Sync + Send {} +pub trait AdditionalClaims: + openidconnect::AdditionalClaims + Clone + Sync + Send + Serialize + DeserializeOwned +{ +} type OidcTokenResponse = StandardTokenResponse< IdTokenFields< @@ -147,28 +151,24 @@ struct OidcQuery { /// oidc session #[derive(Serialize, Deserialize, Debug)] -struct OidcSession { +#[serde(bound = "AC: Serialize + DeserializeOwned")] +struct OidcSession { nonce: Nonce, csrf_token: CsrfToken, pkce_verifier: PkceCodeVerifier, - id_token: Option, - access_token: Option, - refresh_token: Option, + authenticated: Option>, } -impl OidcSession { - pub(crate) fn id_token(&self) -> Option> { - self.id_token - .as_ref() - .map(|x| IdToken::::from_str(x).unwrap()) - } - pub(crate) fn refresh_token(&self) -> Option { - self.refresh_token - .as_ref() - .map(|x| RefreshToken::new(x.to_string())) - } +#[derive(Serialize, Deserialize, Debug)] +#[serde(bound = "AC: Serialize + DeserializeOwned")] +struct AuthenticatedSession { + id_token: IdToken, + access_token: AccessToken, + refresh_token: Option, } +/// additional metadata that is discovered on client creation via the +/// `.well-knwon/openid-configuration` endpoint. #[derive(Debug, Clone, Serialize, Deserialize)] struct AdditionalProviderMetadata { end_session_endpoint: Option, diff --git a/src/middleware.rs b/src/middleware.rs index 280aae0..b823e66 100644 --- a/src/middleware.rs +++ b/src/middleware.rs @@ -9,16 +9,16 @@ use axum::{ }; use axum_core::{extract::FromRequestParts, response::Response}; use futures_util::future::BoxFuture; -use http::{uri::PathAndQuery, Request, Uri}; +use http::{request::Parts, uri::PathAndQuery, Request, Uri}; use tower_layer::Layer; use tower_service::Service; use tower_sessions::Session; use openidconnect::{ - core::{CoreAuthenticationFlow, CoreErrorResponseType}, + core::{CoreAuthenticationFlow, CoreErrorResponseType, CoreGenderClaim}, reqwest::async_http_client, - AccessTokenHash, AuthorizationCode, CsrfToken, Nonce, OAuth2TokenResponse, PkceCodeChallenge, - PkceCodeVerifier, RedirectUrl, + AccessToken, AccessTokenHash, AuthorizationCode, CsrfToken, IdTokenClaims, Nonce, + OAuth2TokenResponse, PkceCodeChallenge, PkceCodeVerifier, RedirectUrl, RefreshToken, RequestTokenError::ServerResponse, Scope, TokenResponse, }; @@ -26,7 +26,8 @@ use openidconnect::{ use crate::{ error::{Error, MiddlewareError}, extractor::{OidcAccessToken, OidcClaims, OidcRpInitiatedLogout}, - AdditionalClaims, BoxError, OidcClient, OidcQuery, OidcSession, SESSION_KEY, + AdditionalClaims, AuthenticatedSession, BoxError, IdToken, OidcClient, OidcQuery, OidcSession, + SESSION_KEY, }; /// Layer for the [OidcLoginMiddleware]. @@ -121,7 +122,7 @@ where .extensions .get::() .ok_or(MiddlewareError::SessionNotFound)?; - let login_session: Option = session + let login_session: Option> = session .get(SESSION_KEY) .await .map_err(MiddlewareError::from)?; @@ -158,26 +159,15 @@ where let claims = id_token .claims(&oidcclient.client.id_token_verifier(), &login_session.nonce)?; - // Verify the access token hash to ensure that the access token hasn't been substituted for - // another user's. - if let Some(expected_access_token_hash) = claims.access_token_hash() { - let actual_access_token_hash = AccessTokenHash::from_token( - token_response.access_token(), - &id_token.signing_alg()?, - )?; - if actual_access_token_hash != *expected_access_token_hash { - return Err(MiddlewareError::AccessTokenHashInvalid); - } - } + validate_access_token_hash(id_token, token_response.access_token(), claims)?; - login_session.id_token = Some(id_token.to_string()); - login_session.access_token = - Some(token_response.access_token().secret().to_string()); - login_session.refresh_token = token_response - .refresh_token() - .map(|x| x.secret().to_string()); + login_session.authenticated = Some(AuthenticatedSession { + id_token: id_token.clone(), + access_token: token_response.access_token().clone(), + refresh_token: token_response.refresh_token().cloned(), + }); - session.insert(SESSION_KEY, login_session).await.unwrap(); + session.insert(SESSION_KEY, login_session).await?; Ok(Redirect::temporary(&handler_uri.to_string()).into_response()) } else { @@ -198,16 +188,14 @@ where auth.set_pkce_challenge(pkce_challenge).url() }; - let oidc_session = OidcSession { + let oidc_session = OidcSession:: { nonce, csrf_token, pkce_verifier, - id_token: None, - access_token: None, - refresh_token: None, + authenticated: None, }; - session.insert(SESSION_KEY, oidc_session).await.unwrap(); + session.insert(SESSION_KEY, oidc_session).await?; Ok(Redirect::temporary(auth_url.as_str()).into_response()) } @@ -307,7 +295,7 @@ where .extensions .get::() .ok_or(MiddlewareError::SessionNotFound)?; - let mut login_session: Option = session + let mut login_session: Option> = session .get(SESSION_KEY) .await .map_err(MiddlewareError::from)?; @@ -320,98 +308,37 @@ where .set_redirect_uri(RedirectUrl::new(handler_uri.to_string())?); if let Some(login_session) = &mut login_session { - let id_token_claims = login_session.id_token::().and_then(|id_token| { - id_token + let id_token_claims = login_session.authenticated.as_ref().and_then(|session| { + session + .id_token .claims(&oidcclient.client.id_token_verifier(), &login_session.nonce) .ok() .cloned() + .map(|claims| (session, claims)) }); - match (id_token_claims, login_session.refresh_token()) { + if let Some((session, claims)) = id_token_claims { // stored id token is valid and can be used - (Some(claims), _) => { - parts.extensions.insert(OidcClaims(claims)); - parts.extensions.insert(OidcAccessToken( - login_session.access_token.clone().unwrap_or_default(), - )); - if let Some(end_session_endpoint) = oidcclient.end_session_endpoint.clone() - { - parts.extensions.insert(OidcRpInitiatedLogout { - end_session_endpoint, - id_token_hint: login_session.id_token.clone().unwrap(), - client_id: oidcclient.client_id.clone(), - post_logout_redirect_uri: None, - state: None, - }); - } - } - // stored id token is invalid and can't be uses, but we have a refresh token - // and can use it and try to get another id token. - (_, Some(refresh_token)) => { - let mut refresh_request = - oidcclient.client.exchange_refresh_token(&refresh_token); + insert_extensions(&mut parts, claims.clone(), &oidcclient, session); + } else if let Some(refresh_token) = login_session + .authenticated + .as_ref() + .and_then(|x| x.refresh_token.as_ref()) + { + if let Some((claims, authenticated_session)) = + try_refresh_token(&oidcclient, refresh_token, &login_session.nonce).await? + { + insert_extensions(&mut parts, claims, &oidcclient, &authenticated_session); + login_session.authenticated = Some(authenticated_session); + }; - for scope in oidcclient.scopes.iter() { - refresh_request = - refresh_request.add_scope(Scope::new(scope.to_string())); - } + // save refreshed session or delete it when the token couldn't be refreshed + let session = parts + .extensions + .get::() + .ok_or(MiddlewareError::SessionNotFound)?; - match refresh_request.request_async(async_http_client).await { - Ok(token_response) => { - // Extract the ID token claims after verifying its authenticity and nonce. - let id_token = token_response - .id_token() - .ok_or(MiddlewareError::IdTokenMissing)?; - let claims = id_token.claims( - &oidcclient.client.id_token_verifier(), - &login_session.nonce, - )?; - - // Verify the access token hash to ensure that the access token hasn't been substituted for - // another user's. - if let Some(expected_access_token_hash) = claims.access_token_hash() - { - let actual_access_token_hash = AccessTokenHash::from_token( - token_response.access_token(), - &id_token.signing_alg()?, - )?; - if actual_access_token_hash != *expected_access_token_hash { - return Err(MiddlewareError::AccessTokenHashInvalid); - } - } - - login_session.id_token = Some(id_token.to_string()); - login_session.access_token = - Some(token_response.access_token().secret().to_string()); - login_session.refresh_token = token_response - .refresh_token() - .map(|x| x.secret().to_string()); - - parts.extensions.insert(OidcClaims(claims.clone())); - parts.extensions.insert(OidcAccessToken( - login_session.access_token.clone().unwrap_or_default(), - )); - } - Err(ServerResponse(e)) - if *e.error() == CoreErrorResponseType::InvalidGrant => - { - // Refresh failed, refresh_token most likely expired or - // invalid, the session can be considered lost - login_session.refresh_token = None; - } - Err(err) => { - return Err(err.into()); - } - }; - - let session = parts - .extensions - .get::() - .ok_or(MiddlewareError::SessionNotFound)?; - - session.insert(SESSION_KEY, login_session).await.unwrap(); - } - (None, None) => {} + session.insert(SESSION_KEY, login_session).await?; } } @@ -458,3 +385,85 @@ pub fn strip_oidc_from_path(base_url: Uri, uri: &Uri) -> Result( + parts: &mut Parts, + claims: IdTokenClaims, + client: &OidcClient, + authenticated_session: &AuthenticatedSession, +) { + parts.extensions.insert(OidcClaims(claims)); + parts.extensions.insert(OidcAccessToken( + authenticated_session.access_token.secret().to_string(), + )); + if let Some(end_session_endpoint) = &client.end_session_endpoint { + parts.extensions.insert(OidcRpInitiatedLogout { + end_session_endpoint: end_session_endpoint.clone(), + id_token_hint: authenticated_session.id_token.to_string(), + client_id: client.client_id.clone(), + post_logout_redirect_uri: None, + state: None, + }); + } +} + +/// Verify the access token hash to ensure that the access token hasn't been substituted for +/// another user's. +/// Returns `Ok` when access token is valid +fn validate_access_token_hash( + id_token: &IdToken, + access_token: &AccessToken, + claims: &IdTokenClaims, +) -> Result<(), MiddlewareError> { + if let Some(expected_access_token_hash) = claims.access_token_hash() { + let actual_access_token_hash = + AccessTokenHash::from_token(access_token, &id_token.signing_alg()?)?; + if actual_access_token_hash == *expected_access_token_hash { + Ok(()) + } else { + Err(MiddlewareError::AccessTokenHashInvalid) + } + } else { + Ok(()) + } +} + +async fn try_refresh_token( + client: &OidcClient, + refresh_token: &RefreshToken, + nonce: &Nonce, +) -> Result, AuthenticatedSession)>, MiddlewareError> +{ + let mut refresh_request = client.client.exchange_refresh_token(refresh_token); + + for scope in client.scopes.iter() { + refresh_request = refresh_request.add_scope(Scope::new(scope.to_string())); + } + + match refresh_request.request_async(async_http_client).await { + Ok(token_response) => { + // Extract the ID token claims after verifying its authenticity and nonce. + let id_token = token_response + .id_token() + .ok_or(MiddlewareError::IdTokenMissing)?; + let claims = id_token.claims(&client.client.id_token_verifier(), nonce)?; + + validate_access_token_hash(id_token, token_response.access_token(), claims)?; + + let authenticated_session = AuthenticatedSession { + id_token: id_token.clone(), + access_token: token_response.access_token().clone(), + refresh_token: token_response.refresh_token().cloned(), + }; + + Ok(Some((claims.clone(), authenticated_session))) + } + Err(ServerResponse(e)) if *e.error() == CoreErrorResponseType::InvalidGrant => { + // Refresh failed, refresh_token most likely expired or + // invalid, the session can be considered lost + Ok(None) + } + Err(err) => Err(err.into()), + } +}