jwt as extractor

This commit is contained in:
Paul Zinselmeyer 2023-10-20 00:26:59 +02:00
parent 89da8cc07f
commit 3b9438d1f3
Signed by: pfzetto
GPG key ID: 4EEF46A5B276E648
4 changed files with 39 additions and 76 deletions

2
Cargo.lock generated
View file

@ -162,14 +162,12 @@ dependencies = [
"async-trait",
"axum",
"axum-extra",
"futures-util",
"jsonwebtoken",
"openidconnect",
"reqwest",
"serde",
"serde_json",
"thiserror",
"tower",
]
[[package]]

View file

@ -18,10 +18,8 @@ serde = "1.0"
serde_json = "1.0"
jsonwebtoken = {version="^8.3", optional=true}
tower = {version="^0.4", optional=true}
futures-util = {version="^0.3",optional=true}
[features]
default = [ "jwt", "oidc" ]
oidc = [ "openidconnect", "axum-extra" ]
jwt = [ "tower", "jsonwebtoken", "futures-util", "reqwest/json", "reqwest/rustls-tls", "serde/derive" ]
jwt = [ "jsonwebtoken", "reqwest/json", "reqwest/rustls-tls", "serde/derive" ]

View file

@ -65,6 +65,10 @@ pub enum Error {
#[cfg(feature = "jwt")]
#[error("jsonwebtoken: {0}")]
JsonWebToken(#[from] jsonwebtoken::errors::Error),
#[cfg(feature = "jwt")]
#[error("jwt invalid")]
JwtInvalid,
}
impl IntoResponse for Error {
@ -78,6 +82,10 @@ impl IntoResponse for Error {
#[cfg(feature = "oidc")]
Self::Redirect(redirect) => redirect.into_response(),
#[cfg(feature = "jwt")]
Self::JwtInvalid => (StatusCode::UNAUTHORIZED, "access token invalid").into_response(),
_ => (StatusCode::INTERNAL_SERVER_ERROR, "internal server error").into_response(),
}
}

View file

@ -1,18 +1,12 @@
use std::{
marker::PhantomData,
task::{Context, Poll},
};
use std::marker::PhantomData;
use async_trait::async_trait;
use axum::{
body::Body,
http::Request,
response::{IntoResponse, Response},
extract::{FromRef, FromRequestParts},
http::request::Parts,
};
use futures_util::future::BoxFuture;
use jsonwebtoken::{decode, Algorithm, DecodingKey, Validation};
use reqwest::StatusCode;
use serde::Deserialize;
use tower::{Layer, Service};
use serde::{de::DeserializeOwned, Deserialize};
use crate::error::Error;
@ -35,10 +29,8 @@ pub struct Claims<A: Clone> {
}
#[derive(Clone)]
pub struct JwtLayer<A: Clone> {
algorithm: Algorithm,
issuer: Vec<String>,
audience: Vec<String>,
pub struct JwtApplication<A: Clone> {
validation: Validation,
pubkey: DecodingKey,
_a: PhantomData<A>,
}
@ -48,7 +40,7 @@ struct IssuerDiscovery {
public_key: String,
}
impl<A: Clone> JwtLayer<A> {
impl<A: Clone> JwtApplication<A> {
pub async fn new(issuer: String, audience: String) -> Result<Self, Error> {
let issuer_key = reqwest::get(&issuer)
.await?
@ -63,78 +55,45 @@ impl<A: Clone> JwtLayer<A> {
let pubkey = DecodingKey::from_rsa_pem(pem.as_bytes())?;
let mut validation = Validation::new(Algorithm::RS256);
validation.set_issuer(&[issuer]);
validation.set_audience(&[audience]);
validation.validate_nbf = true;
Ok(Self {
algorithm: Algorithm::RS256,
issuer: vec![issuer],
audience: vec![audience],
validation,
pubkey,
_a: PhantomData,
})
}
}
impl<S, A: Clone> Layer<S> for JwtLayer<A> {
type Service = JwtService<S, A>;
fn layer(&self, inner: S) -> Self::Service {
let mut validation = Validation::new(self.algorithm);
validation.set_issuer(&self.issuer);
validation.set_audience(&self.audience);
validation.validate_nbf = true;
JwtService {
validation,
pubkey: self.pubkey.clone(),
inner,
_a: PhantomData,
}
}
}
#[derive(Clone)]
pub struct JwtService<S, A: Clone> {
validation: Validation,
pubkey: DecodingKey,
inner: S,
_a: PhantomData<A>,
}
impl<S, A: Clone> Service<Request<Body>> for JwtService<S, A>
#[async_trait]
impl<S, A> FromRequestParts<S> for Claims<A>
where
S: Service<Request<Body>, Response = Response> + Send + 'static,
S::Future: Send + 'static,
A: Clone + for<'a> Deserialize<'a> + 'static + Sync + Send,
S: Send + Sync,
A: Clone + DeserializeOwned,
JwtApplication<A>: FromRef<S>,
{
type Response = S::Response;
type Error = S::Error;
type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
type Rejection = Error;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
let application: JwtApplication<A> = JwtApplication::from_ref(state);
fn call(&mut self, mut req: Request<Body>) -> Self::Future {
let token = req
.headers()
let token = parts
.headers
.get("Authorization")
.and_then(|x| x.to_str().ok())
.map(|x| x.chars().skip(7).collect::<String>());
let token =
token.and_then(|x| decode::<Claims<A>>(&x, &self.pubkey, &self.validation).ok());
let token_exists = token.is_some();
let token = token.and_then(|x| {
decode::<Claims<A>>(&x, &application.pubkey, &application.validation).ok()
});
if let Some(token) = token {
req.extensions_mut().insert(token.claims);
}
let future = self.inner.call(req);
Box::pin(async move {
if token_exists {
let response: Response = future.await?;
Ok(response)
Ok(token.claims)
} else {
Ok((StatusCode::UNAUTHORIZED, "access token invalid").into_response())
Err(Error::JwtInvalid)
}
})
}
}