remember signin

This commit is contained in:
Paul Zinselmeyer 2023-04-21 23:29:18 +02:00
parent 75ed3b861a
commit 428399951c
Signed by: pfzetto
GPG key ID: 4EEF46A5B276E648
3 changed files with 109 additions and 50 deletions

View file

@ -8,7 +8,6 @@ edition = "2021"
[dependencies] [dependencies]
axum = "0.6" axum = "0.6"
axum-extra = {version="0.7", features=["cookie", "cookie-private"]} axum-extra = {version="0.7", features=["cookie", "cookie-private"]}
cookie = "0.17"
openidconnect = "3.0" openidconnect = "3.0"
async-trait = "0.1" async-trait = "0.1"
serde = "1.0" serde = "1.0"

View file

@ -27,6 +27,9 @@ pub enum Error {
#[error("json serialization error: {:?}", 0)] #[error("json serialization error: {:?}", 0)]
Json(#[from] serde_json::Error), Json(#[from] serde_json::Error),
#[error("url parsing error: {:?}", 0)]
UrlParsing(#[from] axum::http::Error),
#[error("csrf token is invalid")] #[error("csrf token is invalid")]
CsrfTokenInvalid, CsrfTokenInvalid,

View file

@ -1,3 +1,5 @@
use std::str::FromStr;
use async_trait::async_trait; use async_trait::async_trait;
use axum::{ use axum::{
extract::{FromRef, FromRequestParts, Query}, extract::{FromRef, FromRequestParts, Query},
@ -8,7 +10,6 @@ use axum_extra::extract::{
cookie::{Cookie, SameSite}, cookie::{Cookie, SameSite},
PrivateCookieJar, PrivateCookieJar,
}; };
use cookie::time::{Duration, OffsetDateTime};
use error::Error; use error::Error;
use openidconnect::{ use openidconnect::{
core::{ core::{
@ -26,7 +27,8 @@ use openidconnect::{
}; };
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
pub use cookie::Key; pub use axum::http::Uri;
pub use axum_extra::extract::cookie::Key;
pub mod error; pub mod error;
@ -64,45 +66,47 @@ pub type OidcClient<AC> = Client<
CoreRevocationErrorResponse, CoreRevocationErrorResponse,
>; >;
pub struct OidcApplication { pub type IdToken<AZ> = openidconnect::IdToken<
application_base: String, AZ,
issuer: IssuerUrl, CoreGenderClaim,
client_id: ClientId, CoreJweContentEncryptionAlgorithm,
client_secret: Option<ClientSecret>, CoreJwsSigningAlgorithm,
CoreJsonWebKeyType,
>;
#[derive(Clone)]
pub struct OidcApplication<AC: AdditionalClaims> {
application_base: Uri,
scopes: Vec<String>, scopes: Vec<String>,
cookie_key: Key, cookie_key: Key,
client: OidcClient<AC>,
} }
impl OidcApplication { impl<AC: AdditionalClaims> OidcApplication<AC> {
pub fn new( pub async fn create(
application_base: String, application_base: Uri,
issuer: String, issuer: String,
client_id: String, client_id: String,
client_secret: Option<String>, client_secret: Option<String>,
scopes: Vec<String>, scopes: Vec<String>,
cookie_key: Key, cookie_key: Key,
) -> Self { ) -> Result<Self, Error> {
Self { let provider_metadata = CoreProviderMetadata::discover_async(
application_base, IssuerUrl::new(issuer).unwrap(),
issuer: IssuerUrl::new(issuer).unwrap(), async_http_client,
client_id: ClientId::new(client_id), )
client_secret: client_secret.map(ClientSecret::new), .await?;
scopes,
cookie_key,
}
}
async fn create_client<AC: AdditionalClaims>(
&self,
redirect: String,
) -> Result<OidcClient<AC>, Error> {
let provider_metadata =
CoreProviderMetadata::discover_async(self.issuer.clone(), async_http_client).await?;
let client = OidcClient::<AC>::from_provider_metadata( let client = OidcClient::<AC>::from_provider_metadata(
provider_metadata, provider_metadata,
self.client_id.clone(), ClientId::new(client_id),
self.client_secret.clone(), client_secret.map(ClientSecret::new),
) );
.set_redirect_uri(RedirectUrl::new(redirect)?);
Ok(client) Ok(Self {
application_base,
scopes,
cookie_key,
client,
})
} }
} }
@ -118,19 +122,21 @@ impl<S, AC> FromRequestParts<S> for ClaimsExtractor<AC>
where where
S: Send + Sync, S: Send + Sync,
AC: AdditionalClaims, AC: AdditionalClaims,
OidcApplication: FromRef<S>, OidcApplication<AC>: FromRef<S>,
{ {
type Rejection = Error; type Rejection = Error;
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> { async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
let application: OidcApplication = OidcApplication::from_ref(state); let application: OidcApplication<AC> = OidcApplication::from_ref(state);
let client = application
.create_client(format!( let handler_uri = Uri::builder()
"{}/{}", .scheme(application.application_base.scheme().unwrap().clone())
application.application_base, .authority(application.application_base.authority().unwrap().clone())
parts.uri.path() .path_and_query(strip_oidc_from_path(&parts.uri))
)) .build()?;
.await?;
let mut client = application.client;
client = client.set_redirect_uri(RedirectUrl::new(handler_uri.to_string())?);
let mut jar = PrivateCookieJar::from_headers(&parts.headers, application.cookie_key); let mut jar = PrivateCookieJar::from_headers(&parts.headers, application.cookie_key);
let login_session = jar.get(LOGIN_COOKIE_NAME); let login_session = jar.get(LOGIN_COOKIE_NAME);
@ -138,8 +144,20 @@ where
.await .await
.ok(); .ok();
if let (Some(login_session), Some(Query(query))) = (login_session, query) { if let Some(login_session) = &login_session {
let login_session: LoginSession = serde_json::from_str(login_session.value())?; let login_session: LoginSession = serde_json::from_str(login_session.value())?;
if let Some(access_token) = login_session.access_token {
let access_token = IdToken::<AC>::from_str(&access_token).unwrap();
if let Ok(claims) =
access_token.claims(&client.id_token_verifier(), &login_session.nonce)
{
return Ok(Self(claims.clone()));
}
}
}
if let (Some(login_session), Some(Query(query))) = (login_session, query) {
let mut login_session: LoginSession = serde_json::from_str(login_session.value())?;
if login_session.csrf_token.secret() != &query.state { if login_session.csrf_token.secret() != &query.state {
return Err(Error::CsrfTokenInvalid); return Err(Error::CsrfTokenInvalid);
@ -148,7 +166,9 @@ where
let token_response = client let token_response = client
.exchange_code(AuthorizationCode::new(query.code.to_string())) .exchange_code(AuthorizationCode::new(query.code.to_string()))
// Set the PKCE code verifier. // Set the PKCE code verifier.
.set_pkce_verifier(login_session.pkce_verifier) .set_pkce_verifier(PkceCodeVerifier::new(
login_session.pkce_verifier.secret().to_string(),
))
.request_async(async_http_client) .request_async(async_http_client)
.await?; .await?;
@ -168,7 +188,20 @@ where
} }
} }
Ok(Self(claims.clone())) login_session.access_token = Some(id_token.to_string());
let login_session = serde_json::to_string(&login_session)?;
jar = jar.add(create_cookie(login_session));
Err(Error::Redirect((
jar,
Redirect::temporary(
handler_uri
.path_and_query()
.map(|x| x.as_str())
.unwrap_or(handler_uri.path()),
),
)))
} else { } else {
let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256(); let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
let (auth_url, csrf_token, nonce) = { let (auth_url, csrf_token, nonce) = {
@ -188,14 +221,10 @@ where
nonce, nonce,
csrf_token, csrf_token,
pkce_verifier, pkce_verifier,
access_token: None,
}; };
let login_session = serde_json::to_string(&login_session)?; let login_session = serde_json::to_string(&login_session)?;
let mut cookie = Cookie::new(LOGIN_COOKIE_NAME, login_session); jar = jar.add(create_cookie(login_session));
cookie.set_same_site(SameSite::Lax);
cookie.set_secure(true);
cookie.set_http_only(true);
cookie.set_expires(OffsetDateTime::now_utc() + Duration::hours(1));
jar = jar.add(cookie);
Err(Error::Redirect(( Err(Error::Redirect((
jar, jar,
@ -205,6 +234,33 @@ where
} }
} }
fn create_cookie(login_session: String) -> Cookie<'static> {
let mut cookie = Cookie::new(LOGIN_COOKIE_NAME, login_session);
cookie.set_same_site(SameSite::None);
cookie.set_secure(true);
cookie.set_http_only(true);
cookie
}
fn strip_oidc_from_path(uri: &Uri) -> String {
let query = uri
.query()
.map(|uri| {
uri.split('&')
.filter(|x| {
!x.starts_with("code")
&& !x.starts_with("state")
&& !x.starts_with("session_state")
})
.fold(String::new(), |acc, x| acc + "&" + x)
.chars()
.skip(1)
.collect::<String>()
})
.unwrap_or_default();
uri.path().to_string() + &query
}
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
struct OidcQuery { struct OidcQuery {
code: String, code: String,
@ -218,4 +274,5 @@ struct LoginSession {
nonce: Nonce, nonce: Nonce,
csrf_token: CsrfToken, csrf_token: CsrfToken,
pkce_verifier: PkceCodeVerifier, pkce_verifier: PkceCodeVerifier,
access_token: Option<String>,
} }