remember signin
This commit is contained in:
parent
75ed3b861a
commit
428399951c
3 changed files with 109 additions and 50 deletions
|
@ -8,7 +8,6 @@ edition = "2021"
|
||||||
[dependencies]
|
[dependencies]
|
||||||
axum = "0.6"
|
axum = "0.6"
|
||||||
axum-extra = {version="0.7", features=["cookie", "cookie-private"]}
|
axum-extra = {version="0.7", features=["cookie", "cookie-private"]}
|
||||||
cookie = "0.17"
|
|
||||||
openidconnect = "3.0"
|
openidconnect = "3.0"
|
||||||
async-trait = "0.1"
|
async-trait = "0.1"
|
||||||
serde = "1.0"
|
serde = "1.0"
|
||||||
|
|
|
@ -27,6 +27,9 @@ pub enum Error {
|
||||||
#[error("json serialization error: {:?}", 0)]
|
#[error("json serialization error: {:?}", 0)]
|
||||||
Json(#[from] serde_json::Error),
|
Json(#[from] serde_json::Error),
|
||||||
|
|
||||||
|
#[error("url parsing error: {:?}", 0)]
|
||||||
|
UrlParsing(#[from] axum::http::Error),
|
||||||
|
|
||||||
#[error("csrf token is invalid")]
|
#[error("csrf token is invalid")]
|
||||||
CsrfTokenInvalid,
|
CsrfTokenInvalid,
|
||||||
|
|
||||||
|
|
155
src/lib.rs
155
src/lib.rs
|
@ -1,3 +1,5 @@
|
||||||
|
use std::str::FromStr;
|
||||||
|
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use axum::{
|
use axum::{
|
||||||
extract::{FromRef, FromRequestParts, Query},
|
extract::{FromRef, FromRequestParts, Query},
|
||||||
|
@ -8,7 +10,6 @@ use axum_extra::extract::{
|
||||||
cookie::{Cookie, SameSite},
|
cookie::{Cookie, SameSite},
|
||||||
PrivateCookieJar,
|
PrivateCookieJar,
|
||||||
};
|
};
|
||||||
use cookie::time::{Duration, OffsetDateTime};
|
|
||||||
use error::Error;
|
use error::Error;
|
||||||
use openidconnect::{
|
use openidconnect::{
|
||||||
core::{
|
core::{
|
||||||
|
@ -26,7 +27,8 @@ use openidconnect::{
|
||||||
};
|
};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
pub use cookie::Key;
|
pub use axum::http::Uri;
|
||||||
|
pub use axum_extra::extract::cookie::Key;
|
||||||
|
|
||||||
pub mod error;
|
pub mod error;
|
||||||
|
|
||||||
|
@ -64,45 +66,47 @@ pub type OidcClient<AC> = Client<
|
||||||
CoreRevocationErrorResponse,
|
CoreRevocationErrorResponse,
|
||||||
>;
|
>;
|
||||||
|
|
||||||
pub struct OidcApplication {
|
pub type IdToken<AZ> = openidconnect::IdToken<
|
||||||
application_base: String,
|
AZ,
|
||||||
issuer: IssuerUrl,
|
CoreGenderClaim,
|
||||||
client_id: ClientId,
|
CoreJweContentEncryptionAlgorithm,
|
||||||
client_secret: Option<ClientSecret>,
|
CoreJwsSigningAlgorithm,
|
||||||
|
CoreJsonWebKeyType,
|
||||||
|
>;
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct OidcApplication<AC: AdditionalClaims> {
|
||||||
|
application_base: Uri,
|
||||||
scopes: Vec<String>,
|
scopes: Vec<String>,
|
||||||
cookie_key: Key,
|
cookie_key: Key,
|
||||||
|
client: OidcClient<AC>,
|
||||||
}
|
}
|
||||||
impl OidcApplication {
|
impl<AC: AdditionalClaims> OidcApplication<AC> {
|
||||||
pub fn new(
|
pub async fn create(
|
||||||
application_base: String,
|
application_base: Uri,
|
||||||
issuer: String,
|
issuer: String,
|
||||||
client_id: String,
|
client_id: String,
|
||||||
client_secret: Option<String>,
|
client_secret: Option<String>,
|
||||||
scopes: Vec<String>,
|
scopes: Vec<String>,
|
||||||
cookie_key: Key,
|
cookie_key: Key,
|
||||||
) -> Self {
|
) -> Result<Self, Error> {
|
||||||
Self {
|
let provider_metadata = CoreProviderMetadata::discover_async(
|
||||||
application_base,
|
IssuerUrl::new(issuer).unwrap(),
|
||||||
issuer: IssuerUrl::new(issuer).unwrap(),
|
async_http_client,
|
||||||
client_id: ClientId::new(client_id),
|
)
|
||||||
client_secret: client_secret.map(ClientSecret::new),
|
.await?;
|
||||||
scopes,
|
|
||||||
cookie_key,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
async fn create_client<AC: AdditionalClaims>(
|
|
||||||
&self,
|
|
||||||
redirect: String,
|
|
||||||
) -> Result<OidcClient<AC>, Error> {
|
|
||||||
let provider_metadata =
|
|
||||||
CoreProviderMetadata::discover_async(self.issuer.clone(), async_http_client).await?;
|
|
||||||
let client = OidcClient::<AC>::from_provider_metadata(
|
let client = OidcClient::<AC>::from_provider_metadata(
|
||||||
provider_metadata,
|
provider_metadata,
|
||||||
self.client_id.clone(),
|
ClientId::new(client_id),
|
||||||
self.client_secret.clone(),
|
client_secret.map(ClientSecret::new),
|
||||||
)
|
);
|
||||||
.set_redirect_uri(RedirectUrl::new(redirect)?);
|
|
||||||
Ok(client)
|
Ok(Self {
|
||||||
|
application_base,
|
||||||
|
scopes,
|
||||||
|
cookie_key,
|
||||||
|
client,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -118,19 +122,21 @@ impl<S, AC> FromRequestParts<S> for ClaimsExtractor<AC>
|
||||||
where
|
where
|
||||||
S: Send + Sync,
|
S: Send + Sync,
|
||||||
AC: AdditionalClaims,
|
AC: AdditionalClaims,
|
||||||
OidcApplication: FromRef<S>,
|
OidcApplication<AC>: FromRef<S>,
|
||||||
{
|
{
|
||||||
type Rejection = Error;
|
type Rejection = Error;
|
||||||
|
|
||||||
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
|
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
|
||||||
let application: OidcApplication = OidcApplication::from_ref(state);
|
let application: OidcApplication<AC> = OidcApplication::from_ref(state);
|
||||||
let client = application
|
|
||||||
.create_client(format!(
|
let handler_uri = Uri::builder()
|
||||||
"{}/{}",
|
.scheme(application.application_base.scheme().unwrap().clone())
|
||||||
application.application_base,
|
.authority(application.application_base.authority().unwrap().clone())
|
||||||
parts.uri.path()
|
.path_and_query(strip_oidc_from_path(&parts.uri))
|
||||||
))
|
.build()?;
|
||||||
.await?;
|
|
||||||
|
let mut client = application.client;
|
||||||
|
client = client.set_redirect_uri(RedirectUrl::new(handler_uri.to_string())?);
|
||||||
let mut jar = PrivateCookieJar::from_headers(&parts.headers, application.cookie_key);
|
let mut jar = PrivateCookieJar::from_headers(&parts.headers, application.cookie_key);
|
||||||
|
|
||||||
let login_session = jar.get(LOGIN_COOKIE_NAME);
|
let login_session = jar.get(LOGIN_COOKIE_NAME);
|
||||||
|
@ -138,8 +144,20 @@ where
|
||||||
.await
|
.await
|
||||||
.ok();
|
.ok();
|
||||||
|
|
||||||
if let (Some(login_session), Some(Query(query))) = (login_session, query) {
|
if let Some(login_session) = &login_session {
|
||||||
let login_session: LoginSession = serde_json::from_str(login_session.value())?;
|
let login_session: LoginSession = serde_json::from_str(login_session.value())?;
|
||||||
|
if let Some(access_token) = login_session.access_token {
|
||||||
|
let access_token = IdToken::<AC>::from_str(&access_token).unwrap();
|
||||||
|
if let Ok(claims) =
|
||||||
|
access_token.claims(&client.id_token_verifier(), &login_session.nonce)
|
||||||
|
{
|
||||||
|
return Ok(Self(claims.clone()));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if let (Some(login_session), Some(Query(query))) = (login_session, query) {
|
||||||
|
let mut login_session: LoginSession = serde_json::from_str(login_session.value())?;
|
||||||
|
|
||||||
if login_session.csrf_token.secret() != &query.state {
|
if login_session.csrf_token.secret() != &query.state {
|
||||||
return Err(Error::CsrfTokenInvalid);
|
return Err(Error::CsrfTokenInvalid);
|
||||||
|
@ -148,7 +166,9 @@ where
|
||||||
let token_response = client
|
let token_response = client
|
||||||
.exchange_code(AuthorizationCode::new(query.code.to_string()))
|
.exchange_code(AuthorizationCode::new(query.code.to_string()))
|
||||||
// Set the PKCE code verifier.
|
// Set the PKCE code verifier.
|
||||||
.set_pkce_verifier(login_session.pkce_verifier)
|
.set_pkce_verifier(PkceCodeVerifier::new(
|
||||||
|
login_session.pkce_verifier.secret().to_string(),
|
||||||
|
))
|
||||||
.request_async(async_http_client)
|
.request_async(async_http_client)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
|
@ -168,7 +188,20 @@ where
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(Self(claims.clone()))
|
login_session.access_token = Some(id_token.to_string());
|
||||||
|
|
||||||
|
let login_session = serde_json::to_string(&login_session)?;
|
||||||
|
jar = jar.add(create_cookie(login_session));
|
||||||
|
|
||||||
|
Err(Error::Redirect((
|
||||||
|
jar,
|
||||||
|
Redirect::temporary(
|
||||||
|
handler_uri
|
||||||
|
.path_and_query()
|
||||||
|
.map(|x| x.as_str())
|
||||||
|
.unwrap_or(handler_uri.path()),
|
||||||
|
),
|
||||||
|
)))
|
||||||
} else {
|
} else {
|
||||||
let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
|
let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
|
||||||
let (auth_url, csrf_token, nonce) = {
|
let (auth_url, csrf_token, nonce) = {
|
||||||
|
@ -188,14 +221,10 @@ where
|
||||||
nonce,
|
nonce,
|
||||||
csrf_token,
|
csrf_token,
|
||||||
pkce_verifier,
|
pkce_verifier,
|
||||||
|
access_token: None,
|
||||||
};
|
};
|
||||||
let login_session = serde_json::to_string(&login_session)?;
|
let login_session = serde_json::to_string(&login_session)?;
|
||||||
let mut cookie = Cookie::new(LOGIN_COOKIE_NAME, login_session);
|
jar = jar.add(create_cookie(login_session));
|
||||||
cookie.set_same_site(SameSite::Lax);
|
|
||||||
cookie.set_secure(true);
|
|
||||||
cookie.set_http_only(true);
|
|
||||||
cookie.set_expires(OffsetDateTime::now_utc() + Duration::hours(1));
|
|
||||||
jar = jar.add(cookie);
|
|
||||||
|
|
||||||
Err(Error::Redirect((
|
Err(Error::Redirect((
|
||||||
jar,
|
jar,
|
||||||
|
@ -205,6 +234,33 @@ where
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn create_cookie(login_session: String) -> Cookie<'static> {
|
||||||
|
let mut cookie = Cookie::new(LOGIN_COOKIE_NAME, login_session);
|
||||||
|
cookie.set_same_site(SameSite::None);
|
||||||
|
cookie.set_secure(true);
|
||||||
|
cookie.set_http_only(true);
|
||||||
|
cookie
|
||||||
|
}
|
||||||
|
|
||||||
|
fn strip_oidc_from_path(uri: &Uri) -> String {
|
||||||
|
let query = uri
|
||||||
|
.query()
|
||||||
|
.map(|uri| {
|
||||||
|
uri.split('&')
|
||||||
|
.filter(|x| {
|
||||||
|
!x.starts_with("code")
|
||||||
|
&& !x.starts_with("state")
|
||||||
|
&& !x.starts_with("session_state")
|
||||||
|
})
|
||||||
|
.fold(String::new(), |acc, x| acc + "&" + x)
|
||||||
|
.chars()
|
||||||
|
.skip(1)
|
||||||
|
.collect::<String>()
|
||||||
|
})
|
||||||
|
.unwrap_or_default();
|
||||||
|
uri.path().to_string() + &query
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize)]
|
||||||
struct OidcQuery {
|
struct OidcQuery {
|
||||||
code: String,
|
code: String,
|
||||||
|
@ -218,4 +274,5 @@ struct LoginSession {
|
||||||
nonce: Nonce,
|
nonce: Nonce,
|
||||||
csrf_token: CsrfToken,
|
csrf_token: CsrfToken,
|
||||||
pkce_verifier: PkceCodeVerifier,
|
pkce_verifier: PkceCodeVerifier,
|
||||||
|
access_token: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
Reference in a new issue