From 428399951cc2d70a2bf5f0f4e8a84cf7cc142feb Mon Sep 17 00:00:00 2001 From: Paul Z Date: Fri, 21 Apr 2023 23:29:18 +0200 Subject: [PATCH] remember signin --- Cargo.toml | 1 - src/error.rs | 3 + src/lib.rs | 155 +++++++++++++++++++++++++++++++++++---------------- 3 files changed, 109 insertions(+), 50 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 8a5cbb9..8ed59ef 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,7 +8,6 @@ edition = "2021" [dependencies] axum = "0.6" axum-extra = {version="0.7", features=["cookie", "cookie-private"]} -cookie = "0.17" openidconnect = "3.0" async-trait = "0.1" serde = "1.0" diff --git a/src/error.rs b/src/error.rs index 9cc2e37..2d20c6c 100644 --- a/src/error.rs +++ b/src/error.rs @@ -27,6 +27,9 @@ pub enum Error { #[error("json serialization error: {:?}", 0)] Json(#[from] serde_json::Error), + #[error("url parsing error: {:?}", 0)] + UrlParsing(#[from] axum::http::Error), + #[error("csrf token is invalid")] CsrfTokenInvalid, diff --git a/src/lib.rs b/src/lib.rs index 25092be..3cffc4e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,5 @@ +use std::str::FromStr; + use async_trait::async_trait; use axum::{ extract::{FromRef, FromRequestParts, Query}, @@ -8,7 +10,6 @@ use axum_extra::extract::{ cookie::{Cookie, SameSite}, PrivateCookieJar, }; -use cookie::time::{Duration, OffsetDateTime}; use error::Error; use openidconnect::{ core::{ @@ -26,7 +27,8 @@ use openidconnect::{ }; use serde::{Deserialize, Serialize}; -pub use cookie::Key; +pub use axum::http::Uri; +pub use axum_extra::extract::cookie::Key; pub mod error; @@ -64,45 +66,47 @@ pub type OidcClient = Client< CoreRevocationErrorResponse, >; -pub struct OidcApplication { - application_base: String, - issuer: IssuerUrl, - client_id: ClientId, - client_secret: Option, +pub type IdToken = openidconnect::IdToken< + AZ, + CoreGenderClaim, + CoreJweContentEncryptionAlgorithm, + CoreJwsSigningAlgorithm, + CoreJsonWebKeyType, +>; + +#[derive(Clone)] +pub struct OidcApplication { + application_base: Uri, scopes: Vec, cookie_key: Key, + client: OidcClient, } -impl OidcApplication { - pub fn new( - application_base: String, +impl OidcApplication { + pub async fn create( + application_base: Uri, issuer: String, client_id: String, client_secret: Option, scopes: Vec, cookie_key: Key, - ) -> Self { - Self { - application_base, - issuer: IssuerUrl::new(issuer).unwrap(), - client_id: ClientId::new(client_id), - client_secret: client_secret.map(ClientSecret::new), - scopes, - cookie_key, - } - } - async fn create_client( - &self, - redirect: String, - ) -> Result, Error> { - let provider_metadata = - CoreProviderMetadata::discover_async(self.issuer.clone(), async_http_client).await?; + ) -> Result { + let provider_metadata = CoreProviderMetadata::discover_async( + IssuerUrl::new(issuer).unwrap(), + async_http_client, + ) + .await?; let client = OidcClient::::from_provider_metadata( provider_metadata, - self.client_id.clone(), - self.client_secret.clone(), - ) - .set_redirect_uri(RedirectUrl::new(redirect)?); - Ok(client) + ClientId::new(client_id), + client_secret.map(ClientSecret::new), + ); + + Ok(Self { + application_base, + scopes, + cookie_key, + client, + }) } } @@ -118,19 +122,21 @@ impl FromRequestParts for ClaimsExtractor where S: Send + Sync, AC: AdditionalClaims, - OidcApplication: FromRef, + OidcApplication: FromRef, { type Rejection = Error; async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { - let application: OidcApplication = OidcApplication::from_ref(state); - let client = application - .create_client(format!( - "{}/{}", - application.application_base, - parts.uri.path() - )) - .await?; + let application: OidcApplication = OidcApplication::from_ref(state); + + let handler_uri = Uri::builder() + .scheme(application.application_base.scheme().unwrap().clone()) + .authority(application.application_base.authority().unwrap().clone()) + .path_and_query(strip_oidc_from_path(&parts.uri)) + .build()?; + + let mut client = application.client; + client = client.set_redirect_uri(RedirectUrl::new(handler_uri.to_string())?); let mut jar = PrivateCookieJar::from_headers(&parts.headers, application.cookie_key); let login_session = jar.get(LOGIN_COOKIE_NAME); @@ -138,8 +144,20 @@ where .await .ok(); - if let (Some(login_session), Some(Query(query))) = (login_session, query) { + if let Some(login_session) = &login_session { let login_session: LoginSession = serde_json::from_str(login_session.value())?; + if let Some(access_token) = login_session.access_token { + let access_token = IdToken::::from_str(&access_token).unwrap(); + if let Ok(claims) = + access_token.claims(&client.id_token_verifier(), &login_session.nonce) + { + return Ok(Self(claims.clone())); + } + } + } + + if let (Some(login_session), Some(Query(query))) = (login_session, query) { + let mut login_session: LoginSession = serde_json::from_str(login_session.value())?; if login_session.csrf_token.secret() != &query.state { return Err(Error::CsrfTokenInvalid); @@ -148,7 +166,9 @@ where let token_response = client .exchange_code(AuthorizationCode::new(query.code.to_string())) // Set the PKCE code verifier. - .set_pkce_verifier(login_session.pkce_verifier) + .set_pkce_verifier(PkceCodeVerifier::new( + login_session.pkce_verifier.secret().to_string(), + )) .request_async(async_http_client) .await?; @@ -168,7 +188,20 @@ where } } - Ok(Self(claims.clone())) + login_session.access_token = Some(id_token.to_string()); + + let login_session = serde_json::to_string(&login_session)?; + jar = jar.add(create_cookie(login_session)); + + Err(Error::Redirect(( + jar, + Redirect::temporary( + handler_uri + .path_and_query() + .map(|x| x.as_str()) + .unwrap_or(handler_uri.path()), + ), + ))) } else { let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256(); let (auth_url, csrf_token, nonce) = { @@ -188,14 +221,10 @@ where nonce, csrf_token, pkce_verifier, + access_token: None, }; let login_session = serde_json::to_string(&login_session)?; - let mut cookie = Cookie::new(LOGIN_COOKIE_NAME, login_session); - cookie.set_same_site(SameSite::Lax); - cookie.set_secure(true); - cookie.set_http_only(true); - cookie.set_expires(OffsetDateTime::now_utc() + Duration::hours(1)); - jar = jar.add(cookie); + jar = jar.add(create_cookie(login_session)); Err(Error::Redirect(( jar, @@ -205,6 +234,33 @@ where } } +fn create_cookie(login_session: String) -> Cookie<'static> { + let mut cookie = Cookie::new(LOGIN_COOKIE_NAME, login_session); + cookie.set_same_site(SameSite::None); + cookie.set_secure(true); + cookie.set_http_only(true); + cookie +} + +fn strip_oidc_from_path(uri: &Uri) -> String { + let query = uri + .query() + .map(|uri| { + uri.split('&') + .filter(|x| { + !x.starts_with("code") + && !x.starts_with("state") + && !x.starts_with("session_state") + }) + .fold(String::new(), |acc, x| acc + "&" + x) + .chars() + .skip(1) + .collect::() + }) + .unwrap_or_default(); + uri.path().to_string() + &query +} + #[derive(Debug, Deserialize)] struct OidcQuery { code: String, @@ -218,4 +274,5 @@ struct LoginSession { nonce: Nonce, csrf_token: CsrfToken, pkce_verifier: PkceCodeVerifier, + access_token: Option, }