diff --git a/Cargo.lock b/Cargo.lock index 38531ff..79fc346 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -162,14 +162,12 @@ dependencies = [ "async-trait", "axum", "axum-extra", - "futures-util", "jsonwebtoken", "openidconnect", "reqwest", "serde", "serde_json", "thiserror", - "tower", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 636d575..3b40a2d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,10 +18,8 @@ serde = "1.0" serde_json = "1.0" jsonwebtoken = {version="^8.3", optional=true} -tower = {version="^0.4", optional=true} -futures-util = {version="^0.3",optional=true} [features] default = [ "jwt", "oidc" ] oidc = [ "openidconnect", "axum-extra" ] -jwt = [ "tower", "jsonwebtoken", "futures-util", "reqwest/json", "reqwest/rustls-tls", "serde/derive" ] +jwt = [ "jsonwebtoken", "reqwest/json", "reqwest/rustls-tls", "serde/derive" ] diff --git a/src/error.rs b/src/error.rs index 6c4efcd..cffb029 100644 --- a/src/error.rs +++ b/src/error.rs @@ -65,6 +65,10 @@ pub enum Error { #[cfg(feature = "jwt")] #[error("jsonwebtoken: {0}")] JsonWebToken(#[from] jsonwebtoken::errors::Error), + + #[cfg(feature = "jwt")] + #[error("jwt invalid")] + JwtInvalid, } impl IntoResponse for Error { @@ -78,6 +82,10 @@ impl IntoResponse for Error { #[cfg(feature = "oidc")] Self::Redirect(redirect) => redirect.into_response(), + + #[cfg(feature = "jwt")] + Self::JwtInvalid => (StatusCode::UNAUTHORIZED, "access token invalid").into_response(), + _ => (StatusCode::INTERNAL_SERVER_ERROR, "internal server error").into_response(), } } diff --git a/src/jwt.rs b/src/jwt.rs index fd03005..ce89a22 100644 --- a/src/jwt.rs +++ b/src/jwt.rs @@ -1,18 +1,12 @@ -use std::{ - marker::PhantomData, - task::{Context, Poll}, -}; +use std::marker::PhantomData; +use async_trait::async_trait; use axum::{ - body::Body, - http::Request, - response::{IntoResponse, Response}, + extract::{FromRef, FromRequestParts}, + http::request::Parts, }; -use futures_util::future::BoxFuture; use jsonwebtoken::{decode, Algorithm, DecodingKey, Validation}; -use reqwest::StatusCode; -use serde::Deserialize; -use tower::{Layer, Service}; +use serde::{de::DeserializeOwned, Deserialize}; use crate::error::Error; @@ -35,10 +29,8 @@ pub struct Claims { } #[derive(Clone)] -pub struct JwtLayer { - algorithm: Algorithm, - issuer: Vec, - audience: Vec, +pub struct JwtApplication { + validation: Validation, pubkey: DecodingKey, _a: PhantomData, } @@ -48,7 +40,7 @@ struct IssuerDiscovery { public_key: String, } -impl JwtLayer { +impl JwtApplication { pub async fn new(issuer: String, audience: String) -> Result { let issuer_key = reqwest::get(&issuer) .await? @@ -63,78 +55,45 @@ impl JwtLayer { let pubkey = DecodingKey::from_rsa_pem(pem.as_bytes())?; + let mut validation = Validation::new(Algorithm::RS256); + validation.set_issuer(&[issuer]); + validation.set_audience(&[audience]); + validation.validate_nbf = true; + Ok(Self { - algorithm: Algorithm::RS256, - issuer: vec![issuer], - audience: vec![audience], + validation, pubkey, _a: PhantomData, }) } } -impl Layer for JwtLayer { - type Service = JwtService; - - fn layer(&self, inner: S) -> Self::Service { - let mut validation = Validation::new(self.algorithm); - validation.set_issuer(&self.issuer); - validation.set_audience(&self.audience); - validation.validate_nbf = true; - JwtService { - validation, - pubkey: self.pubkey.clone(), - inner, - _a: PhantomData, - } - } -} - -#[derive(Clone)] -pub struct JwtService { - validation: Validation, - pubkey: DecodingKey, - inner: S, - _a: PhantomData, -} - -impl Service> for JwtService +#[async_trait] +impl FromRequestParts for Claims where - S: Service, Response = Response> + Send + 'static, - S::Future: Send + 'static, - A: Clone + for<'a> Deserialize<'a> + 'static + Sync + Send, + S: Send + Sync, + A: Clone + DeserializeOwned, + JwtApplication: FromRef, { - type Response = S::Response; - type Error = S::Error; - type Future = BoxFuture<'static, Result>; + type Rejection = Error; - fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - self.inner.poll_ready(cx) - } + async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { + let application: JwtApplication = JwtApplication::from_ref(state); - fn call(&mut self, mut req: Request) -> Self::Future { - let token = req - .headers() + let token = parts + .headers .get("Authorization") .and_then(|x| x.to_str().ok()) .map(|x| x.chars().skip(7).collect::()); - let token = - token.and_then(|x| decode::>(&x, &self.pubkey, &self.validation).ok()); - let token_exists = token.is_some(); + let token = token.and_then(|x| { + decode::>(&x, &application.pubkey, &application.validation).ok() + }); if let Some(token) = token { - req.extensions_mut().insert(token.claims); + Ok(token.claims) + } else { + Err(Error::JwtInvalid) } - - let future = self.inner.call(req); - Box::pin(async move { - if token_exists { - let response: Response = future.await?; - Ok(response) - } else { - Ok((StatusCode::UNAUTHORIZED, "access token invalid").into_response()) - } - }) } }