change to axum-oidc

This commit is contained in:
Paul Zinselmeyer 2023-12-19 10:56:47 +01:00
parent 04a6d296b5
commit c626995a32
6 changed files with 790 additions and 498 deletions

974
Cargo.lock generated

File diff suppressed because it is too large Load diff

View file

@ -7,11 +7,19 @@
]
},
"locked": {
<<<<<<< HEAD
"lastModified": 1699548976,
"narHash": "sha256-xnpxms0koM8mQpxIup9JnT0F7GrKdvv0QvtxvRuOYR4=",
"owner": "ipetkov",
"repo": "crane",
"rev": "6849911446e18e520970cc6b7a691e64ee90d649",
=======
"lastModified": 1702956644,
"narHash": "sha256-6XxZSkhb/OkxIx705RHTTLYZ2qemmEC7tODD8f21gKw=",
"owner": "ipetkov",
"repo": "crane",
"rev": "537ebb11db883f9076e37d83e3c7ee69a4abb48c",
>>>>>>> 881c599 (change to axum-oidc)
"type": "github"
},
"original": {
@ -25,11 +33,11 @@
"systems": "systems"
},
"locked": {
"lastModified": 1694529238,
"narHash": "sha256-zsNZZGTGnMOf9YpHKJqMSsa0dXbfmxeoJ7xHlrt+xmY=",
"lastModified": 1701680307,
"narHash": "sha256-kAuep2h5ajznlPMD9rnQyffWG8EM/C73lejGofXvdM8=",
"owner": "numtide",
"repo": "flake-utils",
"rev": "ff7b65b44d01cf9ba6a71320833626af21126384",
"rev": "4022d587cbbfd70fe950c1e2083a02621806a725",
"type": "github"
},
"original": {
@ -40,11 +48,19 @@
},
"nixpkgs": {
"locked": {
<<<<<<< HEAD
"lastModified": 1699099776,
"narHash": "sha256-X09iKJ27mGsGambGfkKzqvw5esP1L/Rf8H3u3fCqIiU=",
"owner": "NixOS",
"repo": "nixpkgs",
"rev": "85f1ba3e51676fa8cc604a3d863d729026a6b8eb",
=======
"lastModified": 1702830618,
"narHash": "sha256-lvhwIvRwhOLgzbRuYkqHy4M5cQHYs4ktL6/hyuBS6II=",
"owner": "NixOS",
"repo": "nixpkgs",
"rev": "91a00709aebb3602f172a0bf47ba1ef013e34835",
>>>>>>> 881c599 (change to axum-oidc)
"type": "github"
},
"original": {
@ -72,11 +88,19 @@
]
},
"locked": {
<<<<<<< HEAD
"lastModified": 1699582387,
"narHash": "sha256-sPmUXPDl+cEi+zFtM5lnAs7dWOdRn0ptZ4a/qHwvNDk=",
"owner": "oxalica",
"repo": "rust-overlay",
"rev": "41f7b0618052430d3a050e8f937030d00a2fcced",
=======
"lastModified": 1702952173,
"narHash": "sha256-24kUnTZgXP5B/fs1/f61tJuHyFrJ8824rn1B/0hL1og=",
"owner": "oxalica",
"repo": "rust-overlay",
"rev": "20fd62b0891707a1db8117d09fc3e65a1cd0f6d7",
>>>>>>> 881c599 (change to axum-oidc)
"type": "github"
},
"original": {

View file

@ -6,10 +6,10 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
tokio = { version = "1.33", features = ["full"] }
tokio = { version = "1.35", features = ["full"] }
tokio-util = { version="0.7", features = ["io"]}
futures-util = "0.3"
axum = {version="0.6", features=["macros", "headers", "multipart"]}
axum = {version="0.7", features=["macros", "multipart"]}
serde = "1.0"
toml = "0.8"
duration-str = "0.7.0"
@ -18,11 +18,13 @@ thiserror = "1.0"
rand = "0.8"
dotenvy = "0.15"
markdown = "0.3"
axum_oidc = {git="https://git2.zettoit.eu/pfz4/axum_oidc"}
axum-oidc = "0.2.1"
tower-sessions = "0.7.0"
tower = "0.4.13"
log = "0.4"
env_logger = "0.10"
sailfish = "0.8.3"
tower-http = { version="0.4.4", features=["fs"], default-features=false }
tower-http = { version="0.5", features=["fs"], default-features=false }
prometheus-client = "0.22.0"
chacha20 = "0.9"
@ -30,3 +32,6 @@ sha3 = "0.10"
hex = "0.4"
bytes = "1.5"
pin-project-lite = "0.2"
reqwest = { version="0.11", default_features=false, features=["rustls-tls", "json"] }
jsonwebtoken = "9.2.0"

View file

@ -2,7 +2,7 @@ use axum::{
http::StatusCode,
response::{IntoResponse, Response},
};
use log::error;
use log::{debug, error};
#[derive(Debug, thiserror::Error)]
pub enum Error {
@ -44,10 +44,20 @@ pub enum Error {
#[error("prometheus: {0:?}")]
Prometheus(std::fmt::Error),
#[error("jwt invalid")]
JwtInvalid,
#[error("reqwest {0:?}")]
Reqwest(#[from] reqwest::Error),
#[error("reqwest {0:?}")]
JwtDecode(#[from] jsonwebtoken::errors::Error),
}
impl IntoResponse for Error {
fn into_response(self) -> axum::response::Response {
debug!("{:?}", &self);
match self {
Self::PhraseInvalid => (StatusCode::BAD_REQUEST, "url is not valid\n").into_response(),
Self::BinNotFound => (StatusCode::NOT_FOUND, "bin does not exist\n").into_response(),
@ -60,6 +70,7 @@ impl IntoResponse for Error {
}
Self::InvalidTtl => (StatusCode::BAD_REQUEST, "invalid ttl specified").into_response(),
Self::Unauthorized => (StatusCode::UNAUTHORIZED, "unauthorized\n").into_response(),
Self::JwtInvalid => (StatusCode::BAD_REQUEST, "jwt token is invalid\n").into_response(),
Self::Forbidden => (StatusCode::FORBIDDEN, "forbidden\n").into_response(),
_ => {
error!("{:?}", self);

99
server/src/jwt.rs Normal file
View file

@ -0,0 +1,99 @@
use std::marker::PhantomData;
use axum::{
async_trait,
extract::{FromRef, FromRequestParts},
http::request::Parts,
};
use jsonwebtoken::{decode, Algorithm, DecodingKey, Validation};
use serde::{de::DeserializeOwned, Deserialize};
use crate::error::Error;
#[derive(Debug, Clone, Deserialize)]
pub struct Claims<A: Clone> {
pub aud: Vec<String>,
pub exp: usize,
pub iat: usize,
pub iss: String,
pub sub: String,
pub azp: String,
pub name: String,
pub preferred_username: String,
pub given_name: String,
pub family_name: String,
pub email: String,
#[serde(flatten)]
pub additional: A,
}
#[derive(Clone)]
pub struct JwtApplication<A: Clone> {
validation: Validation,
pubkey: DecodingKey,
_a: PhantomData<A>,
}
#[derive(Deserialize)]
struct IssuerDiscovery {
public_key: String,
}
impl<A: Clone + DeserializeOwned> JwtApplication<A> {
pub async fn new(issuer: String, audience: String) -> Result<Self, Error> {
let issuer_key = reqwest::get(&issuer)
.await?
.json::<IssuerDiscovery>()
.await?
.public_key;
let pem = format!(
"-----BEGIN PUBLIC KEY-----\n{}\n-----END PUBLIC KEY-----",
issuer_key
);
let pubkey = DecodingKey::from_rsa_pem(pem.as_bytes())?;
let mut validation = Validation::new(Algorithm::RS256);
validation.set_issuer(&[issuer]);
validation.set_audience(&[audience]);
validation.validate_nbf = true;
Ok(Self {
validation,
pubkey,
_a: PhantomData,
})
}
}
#[async_trait]
impl<S, A> FromRequestParts<S> for Claims<A>
where
S: Send + Sync,
A: Clone + DeserializeOwned,
JwtApplication<A>: FromRef<S>,
{
type Rejection = Error;
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
let application: JwtApplication<A> = JwtApplication::from_ref(state);
let token = parts
.headers
.get("Authorization")
.and_then(|x| x.to_str().ok())
.map(|x| x.chars().skip(7).collect::<String>());
let token = token.and_then(|x| {
decode::<Claims<A>>(&x, &application.pubkey, &application.validation).ok()
});
if let Some(token) = token {
Ok(token.claims)
} else {
Err(Error::JwtInvalid)
}
}
}

View file

@ -1,5 +1,10 @@
#![deny(clippy::unwrap_used)]
use axum_oidc::{
error::{ExtractorError, MiddlewareError},
EmptyAdditionalClaims, OidcAuthLayer, OidcClaims, OidcLoginLayer,
};
use duration_str::{deserialize_option_duration, parse_std};
use jwt::JwtApplication;
use prometheus_client::{
encoding::EncodeLabelSet,
metrics::{counter::Counter, family::Family, gauge::Gauge, histogram::Histogram},
@ -15,23 +20,24 @@ use std::{
};
use tower_http::services::ServeDir;
use tower::ServiceBuilder;
use tower_sessions::{cookie::SameSite, MemoryStore, SessionManagerLayer};
use axum::{
async_trait,
body::{HttpBody, StreamBody},
extract::{BodyStream, FromRef, FromRequest, Multipart, Path, Query, State},
headers::{ContentType, Range},
body::{Body, HttpBody},
debug_handler,
error_handling::HandleErrorLayer,
extract::{FromRef, FromRequest, Multipart, Path, Query, State},
http::{
header::{self, CONTENT_TYPE},
HeaderMap, HeaderValue, Request, StatusCode,
HeaderMap, HeaderValue, Request, StatusCode, Uri,
},
response::{Html, IntoResponse, Redirect, Response},
routing::get,
Router, TypedHeader,
};
use axum_oidc::{
jwt::{Claims, JwtApplication},
oidc::{self, EmptyAdditionalClaims, OidcApplication, OidcExtractor},
routing::{get, post},
BoxError, Router,
};
use bytes::Bytes;
use chacha20::{
cipher::{KeyIvInit, StreamCipher},
@ -45,17 +51,20 @@ use sha3::{Digest, Sha3_256};
use tokio::{
fs::{self, File},
io::{AsyncWriteExt, BufReader, BufWriter},
net::TcpListener,
};
use util::{IdSalt, KeySalt};
use crate::{
error::Error,
jwt::Claims,
metadata::Metadata,
util::{Id, Key, Nonce, Phrase},
web_util::DecryptingStream,
};
mod error;
mod garbage_collector;
mod jwt;
mod metadata;
mod util;
mod web_util;
@ -69,9 +78,8 @@ type HandlerResult<T> = Result<T, Error>;
#[derive(Clone)]
pub struct AppState {
application_base: String,
oidc_application: OidcApplication<EmptyAdditionalClaims>,
jwt_application: JwtApplication<EmptyAdditionalClaims>,
data: String,
jwt_application: JwtApplication<EmptyAdditionalClaims>,
key_salt: KeySalt,
id_salt: IdSalt,
garbage_collector: GarbageCollector,
@ -91,12 +99,6 @@ pub struct BinDownloadLabels {
bin: Id,
}
impl FromRef<AppState> for OidcApplication<EmptyAdditionalClaims> {
fn from_ref(input: &AppState) -> Self {
input.oidc_application.clone()
}
}
impl FromRef<AppState> for JwtApplication<EmptyAdditionalClaims> {
fn from_ref(input: &AppState) -> Self {
input.jwt_application.clone()
@ -124,25 +126,42 @@ async fn main() {
.map(|x| x.to_owned())
.collect::<Vec<_>>();
let oidc_application = OidcApplication::<EmptyAdditionalClaims>::create(
application_base
.parse()
.expect("valid APPLICATION_BASE url"),
issuer.to_string(),
client_id.to_string(),
client_secret.to_owned(),
scopes.clone(),
oidc::Key::generate(),
)
.await
.expect("Oidc Authentication Client");
let session_store = MemoryStore::default();
let session_service = ServiceBuilder::new()
.layer(HandleErrorLayer::new(|e: BoxError| async {
dbg!("{:?}", e);
StatusCode::BAD_REQUEST
}))
.layer(SessionManagerLayer::new(session_store).with_same_site(SameSite::Lax));
let jwt_application = JwtApplication::new(
issuer.to_string(),
env::var("AUDIENCE").expect("AUDIENCE env var"),
)
.await
.expect("Jwt Authentication Client");
let oidc_login_service = ServiceBuilder::new()
.layer(HandleErrorLayer::new(|e: MiddlewareError| async {
dbg!("{:?}", &e);
e.into_response()
}))
.layer(OidcLoginLayer::<EmptyAdditionalClaims>::new());
let oidc_auth_service = ServiceBuilder::new()
.layer(HandleErrorLayer::new(|e: MiddlewareError| async {
dbg!("{:?}", &e);
e.into_response()
}))
.layer(
OidcAuthLayer::<EmptyAdditionalClaims>::discover_client(
Uri::from_maybe_shared(application_base.clone()).expect("valid base url"),
issuer.clone(),
client_id,
client_secret,
scopes,
)
.await
.expect("Oidc Layer"),
);
let jwt_application =
JwtApplication::new(issuer, env::var("AUDIENCE").expect("AUDIENCE env var"))
.await
.expect("Jwt Authentication Client");
let data_path = env::var("DATA_PATH").expect("DATA_PATH env var");
@ -174,9 +193,8 @@ async fn main() {
let state: AppState = AppState {
application_base,
oidc_application,
jwt_application,
data: data_path,
jwt_application,
key_salt: KeySalt::from_str(&env::var("KEY_SALT").expect("KEY_SALT env var"))
.expect("KEY SALT valid hex"),
id_salt: IdSalt::from_str(&env::var("ID_SALT").expect("ID_SALT env var"))
@ -192,31 +210,30 @@ async fn main() {
let app = Router::new()
.route("/", get(get_index))
.route(
"/:id",
get(get_item)
.post(upload_bin)
.put(upload_bin)
.delete(delete_bin),
)
.route("/:id", post(upload_bin).put(upload_bin).delete(delete_bin))
.route("/:id/delete", get(delete_bin_interactive).post(delete_bin))
.layer(oidc_login_service)
.route("/:id", get(get_item))
.route("/metrics", get(metrics))
.nest_service("/static", ServeDir::new("static"))
.with_state(state);
axum::Server::bind(&"[::]:8080".parse().expect("valid listen address"))
.serve(app.into_make_service())
.with_state(state)
.layer(oidc_auth_service)
.layer(session_service);
let listener = TcpListener::bind("[::]:8080")
.await
.expect("Axum Server");
.expect("valid listen address");
axum::serve(listener, app).await.expect("axum server");
}
async fn get_index(
State(app_state): State<AppState>,
oidc_extractor: Result<OidcExtractor<EmptyAdditionalClaims>, axum_oidc::error::Error>,
oidc_extractor: Result<OidcClaims<EmptyAdditionalClaims>, ExtractorError>,
jwt_claims: Option<Claims<EmptyAdditionalClaims>>,
) -> Result<impl IntoResponse, Error> {
) -> HandlerResult<impl IntoResponse> {
let subject = match (oidc_extractor, jwt_claims) {
(_, Some(claims)) => claims.sub.to_string(),
(Ok(oidc), None) => oidc.claims.subject().to_string(),
(Ok(oidc), None) => oidc.0.subject().to_string(),
(Err(err), None) => return Err(Error::Oidc(err.into_response())),
};
@ -252,12 +269,12 @@ async fn get_index(
async fn delete_bin(
Path(phrase): Path<String>,
State(app_state): State<AppState>,
oidc_extractor: Result<OidcExtractor<EmptyAdditionalClaims>, axum_oidc::error::Error>,
oidc_extractor: Result<OidcClaims<EmptyAdditionalClaims>, ExtractorError>,
jwt_claims: Option<Claims<EmptyAdditionalClaims>>,
) -> HandlerResult<impl IntoResponse> {
let subject = match (oidc_extractor, jwt_claims) {
(_, Some(claims)) => claims.sub.to_string(),
(Ok(oidc), None) => oidc.claims.subject().to_string(),
(Ok(oidc), None) => oidc.0.subject().to_string(),
(Err(_), None) => return Err(Error::Unauthorized),
};
@ -290,7 +307,7 @@ async fn delete_bin(
async fn delete_bin_interactive(
_: Path<String>,
_: OidcExtractor<EmptyAdditionalClaims>,
_: OidcClaims<EmptyAdditionalClaims>,
) -> HandlerResult<impl IntoResponse> {
Ok(Html(DeleteTemplate.render_once()?))
}
@ -304,7 +321,7 @@ async fn upload_bin(
Path(phrase): Path<String>,
Query(params): Query<PostQuery>,
State(app_state): State<AppState>,
content_type: Option<TypedHeader<ContentType>>,
headers: HeaderMap,
data: MultipartOrStream,
) -> HandlerResult<impl IntoResponse> {
let phrase = Phrase::from_str(&phrase)?;
@ -334,7 +351,8 @@ async fn upload_bin(
.unwrap_or(Duration::from_secs(24 * 3600));
match data {
MultipartOrStream::Stream(mut stream) => {
MultipartOrStream::Stream(stream) => {
let mut stream = stream.into_data_stream();
while let Some(chunk) = stream.next().await {
let mut buf = chunk.unwrap_or_default().to_vec();
etag_hasher.update(&buf);
@ -342,8 +360,14 @@ async fn upload_bin(
cipher.apply_keystream(&mut buf);
writer.write_all(&buf).await?;
}
metadata.content_type = match content_type {
Some(content_type) => Some(content_type.to_string()),
metadata.content_type = match headers.get(CONTENT_TYPE) {
Some(content_type) => Some(
content_type
.to_owned()
.to_str()
.unwrap_or_default()
.to_string(),
),
None => Some("application/octet-stream".to_string()),
};
}
@ -444,7 +468,7 @@ async fn get_item(
let file = File::open(&path).await?;
let reader = BufReader::new(file);
let body = StreamBody::new(DecryptingStream::new(
let body = Body::from_stream(DecryptingStream::new(
reader,
id.clone(),
&metadata,
@ -500,21 +524,14 @@ async fn metrics(State(app_state): State<AppState>) -> HandlerResult<impl IntoRe
enum MultipartOrStream {
Multipart(Multipart),
Stream(BodyStream),
Stream(Body),
}
#[async_trait]
impl<S, B> FromRequest<S, B> for MultipartOrStream
where
B: Send + 'static + HttpBody,
S: Send + Sync,
Bytes: From<<B as HttpBody>::Data>,
<B as HttpBody>::Error:
Send + Sync + Into<Box<(dyn std::error::Error + Send + Sync + 'static)>>,
{
impl<S: Sync + Send> FromRequest<S> for MultipartOrStream {
type Rejection = Response;
async fn from_request(req: Request<B>, state: &S) -> Result<Self, Self::Rejection> {
async fn from_request(req: Request<Body>, state: &S) -> Result<Self, Self::Rejection> {
let is_multipart = req
.headers()
.get(CONTENT_TYPE)
@ -532,7 +549,7 @@ where
))
} else {
Ok(Self::Stream(
BodyStream::from_request(req, state)
Body::from_request(req, state)
.await
.map_err(|x| x.into_response())?,
))