diff --git a/Cargo.lock b/Cargo.lock index 37fd85e..2b5e8fd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -113,6 +113,7 @@ dependencies = [ "matchit", "memchr", "mime", + "multer", "percent-encoding", "pin-project-lite", "rustversion", @@ -1087,7 +1088,7 @@ version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" dependencies = [ - "spin", + "spin 0.5.2", ] [[package]] @@ -1173,6 +1174,24 @@ dependencies = [ "windows-sys", ] +[[package]] +name = "multer" +version = "2.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "01acbdc23469fd8fe07ab135923371d5f5a422fbf9c522158677c8eb15bc51c2" +dependencies = [ + "bytes", + "encoding_rs", + "futures-util", + "http", + "httparse", + "log", + "memchr", + "mime", + "spin 0.9.8", + "version_check", +] + [[package]] name = "num-bigint" version = "0.4.4" @@ -1681,7 +1700,7 @@ dependencies = [ "cc", "libc", "once_cell", - "spin", + "spin 0.5.2", "untrusted", "web-sys", "winapi", @@ -2030,6 +2049,12 @@ version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d" +[[package]] +name = "spin" +version = "0.9.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" + [[package]] name = "spki" version = "0.7.2" diff --git a/Cargo.toml b/Cargo.toml index 9d77969..f26cc6c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,7 +9,7 @@ edition = "2021" tokio = { version = "1.33", features = ["full"] } tokio-util = { version="0.7", features = ["io"]} futures-util = "0.3" -axum = {version="0.6", features=["macros", "headers"]} +axum = {version="0.6", features=["macros", "headers", "multipart"]} serde = "1.0" toml = "0.8" render = { git="https://github.com/render-rs/render.rs" } diff --git a/src/error.rs b/src/error.rs index 3630708..b6e9822 100644 --- a/src/error.rs +++ b/src/error.rs @@ -32,6 +32,9 @@ pub enum Error { #[error("oidc redirect")] Oidc(Response), + + #[error("invalid multipart")] + InvalidMultipart, } impl IntoResponse for Error { @@ -44,6 +47,9 @@ impl IntoResponse for Error { } Self::ParseTtl => (StatusCode::BAD_REQUEST, "invalid ttl class\n").into_response(), Self::Oidc(response) => response.into_response(), + Self::InvalidMultipart => { + (StatusCode::BAD_REQUEST, "invalid multipart data").into_response() + } _ => { error!("{:?}", self); (StatusCode::INTERNAL_SERVER_ERROR, "internal server error\n").into_response() diff --git a/src/main.rs b/src/main.rs index 3028cc2..e637a8f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -7,11 +7,15 @@ use std::{ }; use axum::{ - body::StreamBody, + async_trait, + body::{HttpBody, StreamBody}, debug_handler, - extract::{BodyStream, FromRef, Path, Query, State}, + extract::{BodyStream, FromRef, FromRequest, Multipart, Path, Query, State}, headers::ContentType, - http::{header, HeaderMap, StatusCode}, + http::{ + header::{self, CONTENT_TYPE}, + HeaderMap, Request, StatusCode, + }, response::{Html, IntoResponse, Redirect, Response}, routing::get, Router, TypedHeader, @@ -20,6 +24,7 @@ use axum_oidc::{ jwt::{Claims, JwtApplication}, oidc::{self, EmptyAdditionalClaims, OidcApplication, OidcExtractor}, }; +use bytes::Bytes; use chacha20::{ cipher::{KeyIvInit, StreamCipher}, ChaCha20, @@ -199,7 +204,7 @@ async fn post_item( Query(params): Query, State(app_state): State, content_type: Option>, - mut stream: BodyStream, + data: MultipartOrStream, ) -> HandlerResult { let phrase = Phrase::from_str(&phrase)?; let id = Id::from_phrase(&phrase, &app_state.id_salt); @@ -221,14 +226,34 @@ async fn post_item( let mut etag_hasher = Sha3_256::new(); let mut size = 0; - while let Some(chunk) = stream.next().await { - let mut buf = chunk.unwrap_or_default().to_vec(); - etag_hasher.update(&buf); - size += buf.len() as u64; - cipher.apply_keystream(&mut buf); - writer.write_all(&buf).await?; + match data { + MultipartOrStream::Stream(mut stream) => { + while let Some(chunk) = stream.next().await { + let mut buf = chunk.unwrap_or_default().to_vec(); + etag_hasher.update(&buf); + size += buf.len() as u64; + cipher.apply_keystream(&mut buf); + writer.write_all(&buf).await?; + } + } + MultipartOrStream::Multipart(mut multipart) => { + while let Some(mut field) = multipart + .next_field() + .await + .map_err(|_| Error::InvalidMultipart)? + { + if field.name().unwrap_or_default() == "file" { + while let Some(chunk) = field.chunk().await.unwrap_or_default() { + let mut buf = chunk.to_vec(); + etag_hasher.update(&buf); + size += buf.len() as u64; + cipher.apply_keystream(&mut buf); + writer.write_all(&buf).await?; + } + } + } + } } - writer.flush().await?; let now = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs(); @@ -325,3 +350,41 @@ async fn get_item( Ok((StatusCode::OK, headers, body).into_response()) } } + +enum MultipartOrStream { + Multipart(Multipart), + Stream(BodyStream), +} + +#[async_trait] +impl FromRequest for MultipartOrStream +where + B: Send + 'static + HttpBody, + S: Send + Sync, + Bytes: From<::Data>, + ::Error: + Send + Sync + Into>, +{ + type Rejection = Response; + + async fn from_request(req: Request, state: &S) -> Result { + let is_multipart = req + .headers() + .get(CONTENT_TYPE) + .map(|x| x == "multipart/form-data") + .unwrap_or_default(); + if is_multipart { + Ok(Self::Multipart( + Multipart::from_request(req, state) + .await + .map_err(|x| x.into_response())?, + )) + } else { + Ok(Self::Stream( + BodyStream::from_request(req, state) + .await + .map_err(|x| x.into_response())?, + )) + } + } +}