commit 75ed3b861a85cd18d8473a65cc6a90bc38528527 Author: Paul Z Date: Fri Apr 21 15:11:37 2023 +0200 init diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..4fffb2f --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +/target +/Cargo.lock diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..8a5cbb9 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "axum_oidc" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +axum = "0.6" +axum-extra = {version="0.7", features=["cookie", "cookie-private"]} +cookie = "0.17" +openidconnect = "3.0" +async-trait = "0.1" +serde = "1.0" +thiserror = "1.0" +reqwest = { version="0.11", default_features=false} +serde_json = "1.0" diff --git a/src/error.rs b/src/error.rs new file mode 100644 index 0000000..9cc2e37 --- /dev/null +++ b/src/error.rs @@ -0,0 +1,54 @@ +use axum::response::{IntoResponse, Redirect}; +use axum_extra::extract::PrivateCookieJar; +use openidconnect::{ + core::CoreErrorResponseType, url::ParseError, ClaimsVerificationError, DiscoveryError, + SigningError, StandardErrorResponse, +}; +use reqwest::StatusCode; + +type RequestTokenError = openidconnect::RequestTokenError< + openidconnect::reqwest::Error, + StandardErrorResponse, +>; + +#[derive(Debug, thiserror::Error)] +pub enum Error { + #[error("discovery error: {:?}", 0)] + Discovery(#[from] DiscoveryError>), + #[error("parse error: {:?}", 0)] + Parse(#[from] ParseError), + #[error("request token error: {:?}", 0)] + RequestToken(#[from] RequestTokenError), + #[error("claims verification error: {:?}", 0)] + ClaimsVerification(#[from] ClaimsVerificationError), + #[error("signing error: {:?}", 0)] + Signing(#[from] SigningError), + + #[error("json serialization error: {:?}", 0)] + Json(#[from] serde_json::Error), + + #[error("csrf token is invalid")] + CsrfTokenInvalid, + + #[error("id token not found")] + IdTokenNotFound, + + #[error("access token hash is invalid")] + AccessTokenHashInvalid, + + #[error("just a redirect")] + Redirect((PrivateCookieJar, Redirect)), +} + +impl IntoResponse for Error { + fn into_response(self) -> axum::response::Response { + match self { + Self::CsrfTokenInvalid => { + { (StatusCode::BAD_REQUEST, "csrf token is invalid").into_response() } + .into_response() + } + Self::Redirect(redirect) => redirect.into_response(), + _ => (StatusCode::INTERNAL_SERVER_ERROR, "internal server error").into_response(), + } + } +} diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..25092be --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,221 @@ +use async_trait::async_trait; +use axum::{ + extract::{FromRef, FromRequestParts, Query}, + http::request::Parts, + response::Redirect, +}; +use axum_extra::extract::{ + cookie::{Cookie, SameSite}, + PrivateCookieJar, +}; +use cookie::time::{Duration, OffsetDateTime}; +use error::Error; +use openidconnect::{ + core::{ + CoreAuthDisplay, CoreAuthPrompt, CoreAuthenticationFlow, CoreErrorResponseType, + CoreGenderClaim, CoreJsonWebKey, CoreJsonWebKeyType, CoreJsonWebKeyUse, + CoreJweContentEncryptionAlgorithm, CoreJwsSigningAlgorithm, CoreProviderMetadata, + CoreRevocableToken, CoreRevocationErrorResponse, CoreTokenIntrospectionResponse, + CoreTokenType, + }, + reqwest::async_http_client, + AccessTokenHash, AuthorizationCode, Client, ClientId, ClientSecret, CsrfToken, + EmptyExtraTokenFields, IdTokenClaims, IdTokenFields, IssuerUrl, Nonce, OAuth2TokenResponse, + PkceCodeChallenge, PkceCodeVerifier, RedirectUrl, Scope, StandardErrorResponse, + StandardTokenResponse, TokenResponse, +}; +use serde::{Deserialize, Serialize}; + +pub use cookie::Key; + +pub mod error; + +const LOGIN_COOKIE_NAME: &str = "OIDC_LOGIN"; + +pub trait AdditionalClaims: openidconnect::AdditionalClaims + Clone + Sync + Send {} + +type OidcTokenResponse = StandardTokenResponse< + IdTokenFields< + AC, + EmptyExtraTokenFields, + CoreGenderClaim, + CoreJweContentEncryptionAlgorithm, + CoreJwsSigningAlgorithm, + CoreJsonWebKeyType, + >, + CoreTokenType, +>; + +pub type OidcClient = Client< + AC, + CoreAuthDisplay, + CoreGenderClaim, + CoreJweContentEncryptionAlgorithm, + CoreJwsSigningAlgorithm, + CoreJsonWebKeyType, + CoreJsonWebKeyUse, + CoreJsonWebKey, + CoreAuthPrompt, + StandardErrorResponse, + OidcTokenResponse, + CoreTokenType, + CoreTokenIntrospectionResponse, + CoreRevocableToken, + CoreRevocationErrorResponse, +>; + +pub struct OidcApplication { + application_base: String, + issuer: IssuerUrl, + client_id: ClientId, + client_secret: Option, + scopes: Vec, + cookie_key: Key, +} +impl OidcApplication { + pub fn new( + application_base: String, + issuer: String, + client_id: String, + client_secret: Option, + scopes: Vec, + cookie_key: Key, + ) -> Self { + Self { + application_base, + issuer: IssuerUrl::new(issuer).unwrap(), + client_id: ClientId::new(client_id), + client_secret: client_secret.map(ClientSecret::new), + scopes, + cookie_key, + } + } + async fn create_client( + &self, + redirect: String, + ) -> Result, Error> { + let provider_metadata = + CoreProviderMetadata::discover_async(self.issuer.clone(), async_http_client).await?; + let client = OidcClient::::from_provider_metadata( + provider_metadata, + self.client_id.clone(), + self.client_secret.clone(), + ) + .set_redirect_uri(RedirectUrl::new(redirect)?); + Ok(client) + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct EmptyAdditionalClaims {} +impl openidconnect::AdditionalClaims for EmptyAdditionalClaims {} +impl AdditionalClaims for EmptyAdditionalClaims {} + +pub struct ClaimsExtractor(pub IdTokenClaims); + +#[async_trait] +impl FromRequestParts for ClaimsExtractor +where + S: Send + Sync, + AC: AdditionalClaims, + OidcApplication: FromRef, +{ + type Rejection = Error; + + async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { + let application: OidcApplication = OidcApplication::from_ref(state); + let client = application + .create_client(format!( + "{}/{}", + application.application_base, + parts.uri.path() + )) + .await?; + let mut jar = PrivateCookieJar::from_headers(&parts.headers, application.cookie_key); + + let login_session = jar.get(LOGIN_COOKIE_NAME); + let query = Query::::from_request_parts(parts, state) + .await + .ok(); + + if let (Some(login_session), Some(Query(query))) = (login_session, query) { + let login_session: LoginSession = serde_json::from_str(login_session.value())?; + + if login_session.csrf_token.secret() != &query.state { + return Err(Error::CsrfTokenInvalid); + } + + let token_response = client + .exchange_code(AuthorizationCode::new(query.code.to_string())) + // Set the PKCE code verifier. + .set_pkce_verifier(login_session.pkce_verifier) + .request_async(async_http_client) + .await?; + + // Extract the ID token claims after verifying its authenticity and nonce. + let id_token = token_response.id_token().ok_or(Error::IdTokenNotFound)?; + let claims = id_token.claims(&client.id_token_verifier(), &login_session.nonce)?; + + // Verify the access token hash to ensure that the access token hasn't been substituted for + // another user's. + if let Some(expected_access_token_hash) = claims.access_token_hash() { + let actual_access_token_hash = AccessTokenHash::from_token( + token_response.access_token(), + &id_token.signing_alg()?, + )?; + if actual_access_token_hash != *expected_access_token_hash { + return Err(Error::AccessTokenHashInvalid); + } + } + + Ok(Self(claims.clone())) + } else { + let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256(); + let (auth_url, csrf_token, nonce) = { + let mut auth = client.authorize_url( + CoreAuthenticationFlow::AuthorizationCode, + CsrfToken::new_random, + Nonce::new_random, + ); + + for scope in application.scopes.iter() { + auth = auth.add_scope(Scope::new(scope.to_string())); + } + auth.set_pkce_challenge(pkce_challenge).url() + }; + + let login_session = LoginSession { + nonce, + csrf_token, + pkce_verifier, + }; + let login_session = serde_json::to_string(&login_session)?; + let mut cookie = Cookie::new(LOGIN_COOKIE_NAME, login_session); + cookie.set_same_site(SameSite::Lax); + cookie.set_secure(true); + cookie.set_http_only(true); + cookie.set_expires(OffsetDateTime::now_utc() + Duration::hours(1)); + jar = jar.add(cookie); + + Err(Error::Redirect(( + jar, + Redirect::temporary(auth_url.as_str()), + ))) + } + } +} + +#[derive(Debug, Deserialize)] +struct OidcQuery { + code: String, + state: String, + #[allow(dead_code)] + session_state: String, +} + +#[derive(Serialize, Deserialize)] +struct LoginSession { + nonce: Nonce, + csrf_token: CsrfToken, + pkce_verifier: PkceCodeVerifier, +}