bin/src/web_util.rs

134 lines
4.2 KiB
Rust
Raw Normal View History

2023-10-18 16:14:44 +02:00
use std::{
borrow::Borrow,
pin::Pin,
task::{Context, Poll},
};
use bytes::{Bytes, BytesMut};
use chacha20::{
cipher::{KeyIvInit, StreamCipher},
ChaCha20,
};
use futures_util::Stream;
use log::{debug, warn};
use pin_project_lite::pin_project;
use sha3::{Digest, Sha3_256};
use tokio::io::AsyncRead;
use tokio_util::io::poll_read_buf;
use crate::{
metadata::Metadata,
util::{Id, Key, Nonce},
};
pin_project! {
    /// Adapter that turns an encrypted [`AsyncRead`] into a stream of decrypted chunks.
    ///
    /// Every chunk is run through the ChaCha20 keystream and fed into a SHA3-256 hasher;
    /// once `size` bytes have been processed the digest is compared with `target_hash`.
    pub(crate) struct DecryptingStream<R> {
        // source of the encrypted bytes; set to `None` as soon as the stream has ended
        #[pin]
        reader: Option<R>,
        // read buffer the chunks are split off from
        buf: BytesMut,
        // number of bytes reserved whenever `buf` has no capacity left (the chunk size)
        capacity: usize,
        // ChaCha20 keystream applied to every chunk
        cipher: ChaCha20,
        // running SHA3-256 state over the decrypted bytes
        hasher: Sha3_256,
        // hex-encoded digest the finished hash is compared against
        target_hash: String,
        // id of the bin, only used in log messages
        id: Id,
        // expected total number of bytes
        size: u64,
        // number of bytes read so far
        progress: u64,
    }
}
impl<R: AsyncRead> DecryptingStream<R> {
    /// Wraps `reader` in a stream that decrypts with ChaCha20 using `key` and `nonce`.
    ///
    /// The expected hash and total size are taken from `metadata`. A missing `etag`
    /// becomes an empty string and a missing `size` becomes `0`.
    ///
    /// NOTE(review): with those fallbacks the hash comparison can presumably never
    /// succeed for a non-empty file; confirm that callers always provide both values.
    pub(crate) fn new(reader: R, id: Id, metadata: &Metadata, key: &Key, nonce: &Nonce) -> Self {
        let target_hash = metadata.etag.clone().unwrap_or_default();
        let size = metadata.size.unwrap_or_default();

        Self {
            reader: Some(reader),
            buf: BytesMut::new(),
            // 4 MiB per chunk
            capacity: 1 << 22,
            cipher: ChaCha20::new(key.borrow(), nonce.borrow()),
            hasher: Sha3_256::new(),
            target_hash,
            id,
            size,
            progress: 0,
        }
    }
}
impl<R: AsyncRead> Stream for DecryptingStream<R> {
type Item = std::io::Result<Bytes>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let mut this = self.as_mut().project();
let reader = match this.reader.as_pin_mut() {
Some(r) => r,
None => return Poll::Ready(None),
};
if this.buf.capacity() == 0 {
this.buf.reserve(*this.capacity);
}
match poll_read_buf(reader, cx, &mut this.buf) {
Poll::Pending => Poll::Pending,
Poll::Ready(Err(err)) => {
debug!("failed to send bin {}", this.id);
self.project().reader.set(None);
Poll::Ready(Some(Err(err)))
}
Poll::Ready(Ok(0)) => {
if self.progress_check() == DecryptingStreamProgress::Failed {
// The hash is invalid, the file has been tampered with. Close reader and stream causing the download to fail
self.project().reader.set(None);
return Poll::Ready(None);
};
self.project().reader.set(None);
Poll::Ready(None)
}
Poll::Ready(Ok(n)) => {
let mut chunk = this.buf.split();
// decrypt the chunk using chacha
this.cipher.apply_keystream(&mut chunk);
// update the sha3 hasher
this.hasher.update(&chunk);
// track progress
*this.progress += n as u64;
if self.progress_check() == DecryptingStreamProgress::Failed {
// The hash is invalid, the file has been tampered with. Close reader and stream causing the download to fail
warn!("bin {} is corrupted! transmission failed", self.id);
self.project().reader.set(None);
return Poll::Ready(None);
};
Poll::Ready(Some(Ok(chunk.freeze())))
}
}
}
}
impl<R: AsyncRead> DecryptingStream<R> {
    /// Reports whether the stream is still running or, once at least `size` bytes have
    /// been read, whether the hash of everything read so far equals `target_hash`.
    fn progress_check(&self) -> DecryptingStreamProgress {
        // not all bytes seen yet, nothing to verify
        if self.progress < self.size {
            return DecryptingStreamProgress::Running;
        }

        // finalize a clone so the running hasher state stays usable
        let digest = hex::encode(self.hasher.clone().finalize());
        if digest == self.target_hash {
            DecryptingStreamProgress::Finished
        } else {
            DecryptingStreamProgress::Failed
        }
    }
}
/// Result of `DecryptingStream::progress_check`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum DecryptingStreamProgress {
    /// At least `size` bytes were read and the computed hash equals the target hash.
    Finished,
    /// At least `size` bytes were read but the computed hash differs from the target hash.
    Failed,
    /// Fewer than `size` bytes have been read so far.
    Running,
}