diff --git a/src/lib.rs b/src/lib.rs index 609060f..8458314 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,7 @@ #![doc = include_str!("../README.md")] +use std::str::FromStr; + use crate::error::Error; use http::Uri; use openidconnect::{ @@ -11,14 +13,13 @@ use openidconnect::{ }, reqwest::async_http_client, ClientId, ClientSecret, CsrfToken, EmptyExtraTokenFields, IdTokenFields, IssuerUrl, Nonce, - PkceCodeVerifier, StandardErrorResponse, StandardTokenResponse, + PkceCodeVerifier, RefreshToken, StandardErrorResponse, StandardTokenResponse, }; use serde::{Deserialize, Serialize}; pub mod error; mod extractor; mod middleware; -mod util; pub use extractor::{OidcAccessToken, OidcClaims}; pub use middleware::{OidcAuthLayer, OidcAuthMiddleware, OidcLoginLayer, OidcLoginMiddleware}; @@ -89,7 +90,7 @@ impl OidcClient { let client = Client::from_provider_metadata( provider_metadata, ClientId::new(client_id), - client_secret.map(|x| ClientSecret::new(x)), + client_secret.map(ClientSecret::new), ); Ok(Self { scopes, @@ -122,4 +123,18 @@ struct OidcSession { pkce_verifier: PkceCodeVerifier, id_token: Option, access_token: Option, + refresh_token: Option, +} + +impl OidcSession { + pub(crate) fn id_token(&self) -> Option> { + self.id_token + .as_ref() + .map(|x| IdToken::::from_str(x).unwrap()) + } + pub(crate) fn refresh_token(&self) -> Option { + self.refresh_token + .as_ref() + .map(|x| RefreshToken::new(x.to_string())) + } } diff --git a/src/middleware.rs b/src/middleware.rs index 70c9054..c9390f2 100644 --- a/src/middleware.rs +++ b/src/middleware.rs @@ -10,21 +10,25 @@ use axum::{ }; use axum_core::{extract::FromRequestParts, response::Response}; use futures_util::future::BoxFuture; -use http::{Request, Uri}; +use http::{uri::PathAndQuery, Request, Uri}; use tower_layer::Layer; use tower_service::Service; use tower_sessions::Session; use openidconnect::{ - core::CoreAuthenticationFlow, reqwest::async_http_client, AccessTokenHash, AuthorizationCode, - CsrfToken, Nonce, OAuth2TokenResponse, PkceCodeChallenge, PkceCodeVerifier, RedirectUrl, Scope, - TokenResponse, + core::{ + CoreAuthenticationFlow, CoreGenderClaim, CoreIdTokenFields, CoreJsonWebKeyType, + CoreJweContentEncryptionAlgorithm, CoreJwsSigningAlgorithm, + }, + reqwest::async_http_client, + AccessTokenHash, AuthorizationCode, CsrfToken, ExtraTokenFields, IdTokenFields, Nonce, + OAuth2TokenResponse, PkceCodeChallenge, PkceCodeVerifier, RedirectUrl, RefreshToken, Scope, + StandardTokenResponse, TokenResponse, TokenType, }; use crate::{ error::{Error, MiddlewareError}, extractor::{OidcAccessToken, OidcClaims}, - util::strip_oidc_from_path, AdditionalClaims, BoxError, IdToken, OidcClient, OidcQuery, OidcSession, SESSION_KEY, }; @@ -93,14 +97,16 @@ where let mut inner = std::mem::replace(&mut self.inner, inner); if request.extensions().get::().is_some() { + // the OidcAuthMiddleware had a valid id token Box::pin(async move { let response: Response = inner .call(request) .await .map_err(|e| MiddlewareError::NextMiddleware(e.into()))?; - return Ok(response); + Ok(response) }) } else { + // no valid id token or refresh token was found and the user has to login Box::pin(async move { let (mut parts, _) = request.into_parts(); @@ -129,6 +135,9 @@ where .set_redirect_uri(RedirectUrl::new(handler_uri.to_string())?); if let (Some(mut login_session), Some(query)) = (login_session, query) { + // the request has the request headers of the oidc redirect + // parse the headers and exchange the code for a valid token + if login_session.csrf_token.secret() != &query.state { return Err(MiddlewareError::CsrfTokenInvalid); } @@ -165,11 +174,16 @@ where login_session.id_token = Some(id_token.to_string()); login_session.access_token = Some(token_response.access_token().secret().to_string()); + login_session.refresh_token = token_response + .refresh_token() + .map(|x| x.secret().to_string()); session.insert(SESSION_KEY, login_session).unwrap(); Ok(Redirect::temporary(&handler_uri.to_string()).into_response()) } else { + // generate a login url and redirect the user to it + let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256(); let (auth_url, csrf_token, nonce) = { let mut auth = oidcclient.client.authorize_url( @@ -191,6 +205,7 @@ where pkce_verifier, id_token: None, access_token: None, + refresh_token: None, }; session.insert(SESSION_KEY, oidc_session).unwrap(); @@ -293,7 +308,7 @@ where .extensions .get::() .ok_or(MiddlewareError::SessionNotFound)?; - let login_session: Option = + let mut login_session: Option = session.get(SESSION_KEY).map_err(MiddlewareError::from)?; let handler_uri = @@ -303,20 +318,75 @@ where .client .set_redirect_uri(RedirectUrl::new(handler_uri.to_string())?); - if let Some(OidcSession { - nonce, - csrf_token: _, - pkce_verifier: _, - id_token: Some(id_token), - access_token, - }) = &login_session - { - let id_token = IdToken::::from_str(&id_token).unwrap(); - if let Ok(claims) = id_token.claims(&oidcclient.client.id_token_verifier(), nonce) { - parts.extensions.insert(OidcClaims(claims.clone())); - parts - .extensions - .insert(OidcAccessToken(access_token.clone().unwrap_or_default())); + if let Some(login_session) = &mut login_session { + let id_token_claims = login_session.id_token::().and_then(|id_token| { + id_token + .claims(&oidcclient.client.id_token_verifier(), &login_session.nonce) + .ok() + .cloned() + }); + + match (id_token_claims, login_session.refresh_token()) { + // stored id token is valid and can be used + (Some(claims), _) => { + parts.extensions.insert(OidcClaims(claims)); + parts.extensions.insert(OidcAccessToken( + login_session.access_token.clone().unwrap_or_default(), + )); + } + // stored id token is invalid and can't be uses, but we have a refresh token + // and can use it and try to get another id token. + (_, Some(refresh_token)) => { + let mut refresh_request = + oidcclient.client.exchange_refresh_token(&refresh_token); + + for scope in oidcclient.scopes.iter() { + refresh_request = + refresh_request.add_scope(Scope::new(scope.to_string())); + } + + let token_response = + refresh_request.request_async(async_http_client).await?; + + // Extract the ID token claims after verifying its authenticity and nonce. + let id_token = token_response + .id_token() + .ok_or(MiddlewareError::IdTokenMissing)?; + let claims = id_token + .claims(&oidcclient.client.id_token_verifier(), &login_session.nonce)?; + + // Verify the access token hash to ensure that the access token hasn't been substituted for + // another user's. + if let Some(expected_access_token_hash) = claims.access_token_hash() { + let actual_access_token_hash = AccessTokenHash::from_token( + token_response.access_token(), + &id_token.signing_alg()?, + )?; + if actual_access_token_hash != *expected_access_token_hash { + return Err(MiddlewareError::AccessTokenHashInvalid); + } + } + + login_session.id_token = Some(id_token.to_string()); + login_session.access_token = + Some(token_response.access_token().secret().to_string()); + login_session.refresh_token = token_response + .refresh_token() + .map(|x| x.secret().to_string()); + + parts.extensions.insert(OidcClaims(claims.clone())); + parts.extensions.insert(OidcAccessToken( + login_session.access_token.clone().unwrap_or_default(), + )); + + let session = parts + .extensions + .get::() + .ok_or(MiddlewareError::SessionNotFound)?; + + session.insert(SESSION_KEY, login_session).unwrap(); + } + (None, None) => {} } } @@ -328,7 +398,37 @@ where .await .map_err(|e| MiddlewareError::NextMiddleware(e.into()))? .into_response(); - return Ok(response); + Ok(response) }) } } + +/// Helper function to remove the OpenID Connect authentication response query attributes from a +/// [`Uri`]. +pub fn strip_oidc_from_path(base_url: Uri, uri: &Uri) -> Result { + let mut base_url = base_url.into_parts(); + + base_url.path_and_query = uri + .path_and_query() + .map(|path_and_query| { + let query = path_and_query + .query() + .and_then(|uri| { + uri.split('&') + .filter(|x| { + !x.starts_with("code") + && !x.starts_with("state") + && !x.starts_with("session_state") + }) + .map(|x| x.to_string()) + .reduce(|acc, x| acc + "&" + &x) + }) + .map(|x| format!("?{x}")) + .unwrap_or_default(); + + PathAndQuery::from_maybe_shared(format!("{}{}", path_and_query.path(), query)) + }) + .transpose()?; + + Ok(Uri::from_parts(base_url)?) +} diff --git a/src/util.rs b/src/util.rs deleted file mode 100644 index e5f8017..0000000 --- a/src/util.rs +++ /dev/null @@ -1,33 +0,0 @@ -use http::{uri::PathAndQuery, Uri}; - -use crate::error::MiddlewareError; - -/// Helper function to remove the OpenID Connect authentication response query attributes from a -/// [`Uri`]. -pub fn strip_oidc_from_path(base_url: Uri, uri: &Uri) -> Result { - let mut base_url = base_url.into_parts(); - - base_url.path_and_query = uri - .path_and_query() - .map(|path_and_query| { - let query = path_and_query - .query() - .and_then(|uri| { - uri.split('&') - .filter(|x| { - !x.starts_with("code") - && !x.starts_with("state") - && !x.starts_with("session_state") - }) - .map(|x| x.to_string()) - .reduce(|acc, x| acc + "&" + &x) - }) - .map(|x| "?" + x) - .unwrap_or_default(); - - PathAndQuery::from_maybe_shared(format!("{}{}", path_and_query.path(), query)) - }) - .transpose()?; - - Ok(Uri::from_parts(base_url)?) -}