use core::{
cell::UnsafeCell,
future::poll_fn,
mem::MaybeUninit,
pin::Pin,
ptr,
sync::atomic::{fence, Ordering},
task::{Poll, Waker},
};
#[doc(hidden)]
pub use critical_section;
use heapless::Deque;
use rtic_common::waker_registration::CriticalSectionWakerRegistration as WakerRegistration;
use rtic_common::{
dropper::OnDrop,
wait_queue::{Link, WaitQueue},
};
#[cfg(feature = "defmt-03")]
use crate::defmt;
/// A fixed-capacity channel holding up to `N` values of type `T`.
///
/// Values live in `slots`; the two index queues record which slots are free and
/// which hold a value waiting to be received. In this file `Sender` is `Clone`
/// and `Receiver` is not, i.e. many producers feed a single consumer.
///
/// A channel is put to use by calling `split`, which hands out the two halves.
pub struct Channel<T, const N: usize> {
// Indices of `slots` that currently hold no value. Filled with `0..N` by
// `split`; senders pop from it, the receiver pushes indices back.
freeq: UnsafeCell<Deque<u8, N>>,
// Indices of `slots` that hold a sent value, in the order they were sent.
// Senders push to it after writing the slot, the receiver pops from it.
readyq: UnsafeCell<Deque<u8, N>>,
// Waker of the receiving task; registered in `Receiver::recv`, woken when a
// value is published and when the last sender is dropped.
receiver_waker: WakerRegistration,
// Backing storage. A slot is initialized exactly while its index sits in
// `readyq` (or is in flight between the two queues).
slots: [UnsafeCell<MaybeUninit<T>>; N],
// Senders parked in `Sender::send` because no slot was available.
wait_queue: WaitQueue,
// Set to `true` by `Receiver::drop`; read by `Sender::is_closed`.
receiver_dropped: UnsafeCell<bool>,
// Number of live `Sender` handles: set to 1 by `split`, incremented by
// `Sender::clone`, decremented by `Sender::drop`.
num_senders: UnsafeCell<usize>,
}
// The queues and bookkeeping cells are only reached through `Channel::access`,
// which demands a critical-section token, or through `&mut self` in `split`.
// A slot itself is read or written outside the critical section, but only by
// the party that just popped its index from one of the queues.
//
// NOTE(review): neither impl is bounded by `T: Send`, so a channel of a
// non-`Send` payload can still be shared between execution contexts. Confirm
// this is intended before relying on it outside a single-core setup.
unsafe impl<T, const N: usize> Send for Channel<T, N> {}
unsafe impl<T, const N: usize> Sync for Channel<T, N> {}
/// Mutable view of the `UnsafeCell`-wrapped bookkeeping fields of a `Channel`.
///
/// Built by `Channel::access`, which requires a critical-section token.
struct UnsafeAccess<'a, const N: usize> {
// Queue of currently unused slot indices.
freeq: &'a mut Deque<u8, N>,
// Queue of slot indices that hold a value ready to be received.
readyq: &'a mut Deque<u8, N>,
// `true` once the receiver has been dropped.
receiver_dropped: &'a mut bool,
// Count of live senders.
num_senders: &'a mut usize,
}
impl<T, const N: usize> Default for Channel<T, N> {
fn default() -> Self {
Self::new()
}
}
impl<T, const N: usize> Channel<T, N> {
    // Slot indices are stored as `u8`, so the capacity must fit in one.
    const _CHECK: () = assert!(N < 256, "This queue support a maximum of 255 entries");

    /// Creates a new, not yet split channel.
    ///
    /// Instantiating this function with `N >= 256` is a compile-time error.
    pub const fn new() -> Self {
        // An associated const of a generic impl is only evaluated when it is
        // referenced; without this line the capacity check never ran and
        // `split` would silently truncate `N as u8`.
        let () = Self::_CHECK;

        Self {
            freeq: UnsafeCell::new(Deque::new()),
            readyq: UnsafeCell::new(Deque::new()),
            receiver_waker: WakerRegistration::new(),
            slots: [const { UnsafeCell::new(MaybeUninit::uninit()) }; N],
            wait_queue: WaitQueue::new(),
            receiver_dropped: UnsafeCell::new(false),
            num_senders: UnsafeCell::new(0),
        }
    }

    /// Splits the channel into its [`Sender`] and [`Receiver`] halves.
    ///
    /// May be called again once the previous halves are gone (the `&mut self`
    /// borrow guarantees that): values that were sent but never received are
    /// dropped and all bookkeeping is reset, so the new pair starts out empty
    /// and open.
    pub fn split(&mut self) -> (Sender<'_, T, N>, Receiver<'_, T, N>) {
        // Dispose of values left behind by an earlier sender/receiver pair.
        while let Some(idx) = self.readyq.get_mut().pop_front() {
            // SAFETY: an index is pushed to `readyq` only after its slot has
            // been written, and `&mut self` rules out any concurrent access.
            unsafe { self.slots[usize::from(idx)].get_mut().assume_init_drop() };
        }

        // Start from a known-empty free queue; after an earlier split it may
        // already hold indices, which used to trip the assertion below.
        let freeq = self.freeq.get_mut();
        freeq.clear();

        // `N < 256` is enforced by `_CHECK`, so the cast cannot truncate.
        for idx in 0..N as u8 {
            assert!(!freeq.is_full());

            // SAFETY: the queue has capacity `N` and was just checked not to
            // be full.
            unsafe {
                freeq.push_back_unchecked(idx);
            }
        }

        assert!(freeq.is_full());

        // A fresh pair: the receiver exists again and there is one sender.
        *self.receiver_dropped.get_mut() = false;
        *self.num_senders.get_mut() = 1;

        (Sender(self), Receiver(self))
    }

    /// Hands out mutable references to the bookkeeping fields.
    ///
    /// The critical-section token is what justifies creating them from `&self`.
    /// Callers must not keep two results of this function alive at the same
    /// time.
    fn access<'a>(&'a self, _cs: critical_section::CriticalSection) -> UnsafeAccess<'a, N> {
        // SAFETY: the caller holds a critical-section token, so no other
        // context can be touching these cells right now.
        unsafe {
            UnsafeAccess {
                freeq: &mut *self.freeq.get(),
                readyq: &mut *self.readyq.get(),
                receiver_dropped: &mut *self.receiver_dropped.get(),
                num_senders: &mut *self.num_senders.get(),
            }
        }
    }
}
/// Creates a statically allocated channel and returns its split halves.
///
/// `make_channel!(T, N)` expands to a `static` `Channel<T, N>` private to the
/// invocation site and evaluates to the result of calling `split` on it.
///
/// # Panics
///
/// Executing the same invocation site a second time panics, because the
/// backing `static` has already been handed out.
#[macro_export]
macro_rules! make_channel {
    ($type:ty, $size:expr) => {{
        // Backing storage; only reachable from inside this expansion.
        static mut CHANNEL: $crate::channel::Channel<$type, $size> =
            $crate::channel::Channel::new();

        // 0 until this expansion has run once, 1 afterwards.
        static CHECK: $crate::portable_atomic::AtomicU8 = $crate::portable_atomic::AtomicU8::new(0);

        // Load and store happen inside one critical section, which makes the
        // pair behave like a single test-and-set.
        $crate::channel::critical_section::with(|_| {
            assert!(
                CHECK.load(::core::sync::atomic::Ordering::Relaxed) == 0,
                "call to the same `make_channel` instance twice"
            );

            CHECK.store(1, ::core::sync::atomic::Ordering::Relaxed);
        });

        // SAFETY: the guard above lets execution reach this point at most
        // once, and `CHANNEL` is not nameable anywhere else, so this is the
        // only mutable access ever made to it.
        unsafe { (*::core::ptr::addr_of_mut!(CHANNEL)).split() }
    }};
}
/// Error of an asynchronous send: the receiver is gone.
///
/// The value that could not be delivered is handed back in field `0`.
#[cfg_attr(feature = "defmt-03", derive(defmt::Format))]
pub struct NoReceiver<T>(pub T);

/// Error of a non-blocking send. Every variant hands the unsent value back.
#[cfg_attr(feature = "defmt-03", derive(defmt::Format))]
pub enum TrySendError<T> {
    /// The receiver has been dropped.
    NoReceiver(T),
    /// The value could not be queued right now.
    Full(T),
}

impl<T> core::fmt::Debug for NoReceiver<T>
where
    T: core::fmt::Debug,
{
    /// Writes `NoReceiver(<payload>)`; the payload is always formatted with a
    /// plain `{:?}`, whatever flags the outer formatter carries.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let Self(payload) = self;
        f.write_fmt(format_args!("NoReceiver({:?})", payload))
    }
}

impl<T> core::fmt::Debug for TrySendError<T>
where
    T: core::fmt::Debug,
{
    /// Writes `<Variant>(<payload>)`; the payload is always formatted with a
    /// plain `{:?}`, whatever flags the outer formatter carries.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let (variant, payload) = match self {
            Self::NoReceiver(payload) => ("NoReceiver", payload),
            Self::Full(payload) => ("Full", payload),
        };
        write!(f, "{variant}({payload:?})")
    }
}

impl<T> PartialEq for TrySendError<T>
where
    T: PartialEq,
{
    /// Two errors are equal when they are the same variant and carry equal
    /// payloads.
    fn eq(&self, other: &Self) -> bool {
        match (self, other) {
            (Self::NoReceiver(lhs), Self::NoReceiver(rhs)) | (Self::Full(lhs), Self::Full(rhs)) => {
                lhs == rhs
            }
            (Self::NoReceiver(_), Self::Full(_)) | (Self::Full(_), Self::NoReceiver(_)) => false,
        }
    }
}
/// The sending half of a `Channel`, obtained from `Channel::split`.
///
/// Cloning a sender yields another handle to the same channel; the channel
/// keeps count of the live handles.
pub struct Sender<'a, T, const N: usize>(&'a Channel<T, N>);
// NOTE(review): this impl carries no `T: Send` bound, so a sender of a
// non-`Send` payload can be moved to another execution context. Confirm this
// is intended.
unsafe impl<'a, T, const N: usize> Send for Sender<'a, T, N> {}
/// Raw pointer to the optional wait-queue link owned by a `send` future.
///
/// Cloning copies the pointer, not the link, so both copies alias the same
/// `Option`.
#[derive(Clone)]
struct LinkPtr(*mut Option<Link<Waker>>);

impl LinkPtr {
    /// Turns the pointer back into a mutable reference.
    ///
    /// # Safety
    ///
    /// The pointee must still be alive, and no other reference to it may be in
    /// use for as long as the returned reference is.
    unsafe fn get(&mut self) -> &mut Option<Link<Waker>> {
        // SAFETY: validity and uniqueness are the caller's obligation.
        unsafe { &mut *self.0 }
    }
}

// The pointer is only dereferenced through the `unsafe` accessor above, whose
// callers take responsibility for exclusive access.
unsafe impl Send for LinkPtr {}

unsafe impl Sync for LinkPtr {}
impl<'a, T, const N: usize> core::fmt::Debug for Sender<'a, T, N> {
    /// Prints the fixed text `Sender`; no channel state is exposed.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str("Sender")
    }
}
#[cfg(feature = "defmt-03")]
impl<'a, T, const N: usize> defmt::Format for Sender<'a, T, N> {
    /// Logs the fixed text `Sender`; no channel state is exposed.
    fn format(&self, fmt: defmt::Formatter) {
        defmt::write!(fmt, "Sender")
    }
}
impl<'a, T, const N: usize> Sender<'a, T, N> {
    /// Stores `val` in slot `idx`, publishes the slot and wakes the receiver.
    ///
    /// `idx` must have been popped from the free queue by the caller, which
    /// makes the caller the only party allowed to touch that slot.
    #[inline(always)]
    fn send_footer(&mut self, idx: u8, val: T) {
        // SAFETY: `idx` came from the free queue, which only ever holds
        // indices below `N`, and the slot is exclusively ours until the index
        // is pushed to the ready queue below.
        unsafe {
            ptr::write(
                self.0.slots.get_unchecked(idx as usize).get() as *mut T,
                val,
            )
        }

        // Publish the slot to the receiver.
        critical_section::with(|cs| {
            assert!(!self.0.access(cs).readyq.is_full());
            // SAFETY: the queue was just checked not to be full.
            unsafe { self.0.access(cs).readyq.push_back_unchecked(idx) }
        });

        fence(Ordering::SeqCst);

        self.0.receiver_waker.wake();
    }

    /// Tries to send `val` without waiting.
    ///
    /// # Errors
    ///
    /// * [`TrySendError::Full`] if other senders are already waiting for a
    ///   slot or no slot is free.
    /// * [`TrySendError::NoReceiver`] if the receiver has been dropped.
    ///
    /// Both hand `val` back.
    pub fn try_send(&mut self, val: T) -> Result<(), TrySendError<T>> {
        // Senders that are already waiting go first.
        if !self.0.wait_queue.is_empty() {
            return Err(TrySendError::Full(val));
        }

        if self.is_closed() {
            return Err(TrySendError::NoReceiver(val));
        }

        let Some(idx) = critical_section::with(|cs| self.0.access(cs).freeq.pop_front()) else {
            return Err(TrySendError::Full(val));
        };

        self.send_footer(idx, val);

        Ok(())
    }

    /// Sends `val`, waiting for a free slot if necessary.
    ///
    /// Waiting senders are parked in the channel's wait queue; the receiver
    /// wakes the head of that queue each time it frees a slot.
    ///
    /// # Errors
    ///
    /// Returns [`NoReceiver`] carrying `val` if the receiver has been dropped.
    pub async fn send(&mut self, val: T) -> Result<(), NoReceiver<T>> {
        // The link lives in this future's state; it is deliberately shadowed
        // so that it can only be reached through the two raw-pointer handles.
        let mut link_ptr: Option<Link<Waker>> = None;
        let mut link_ptr = LinkPtr(&mut link_ptr as *mut Option<Link<Waker>>);
        let mut link_ptr2 = link_ptr.clone();

        // Unlinks us from the wait queue if this future is dropped while
        // parked (and, harmlessly, on the normal path as well).
        //
        // NOTE(review): if this future is dropped after it has been woken but
        // before it is polled again, the slot stays in the free queue while
        // the next waiter is not woken until the receiver frees another slot.
        let dropper = OnDrop::new(|| {
            // SAFETY: the link outlives this guard and the poll closure below
            // is not running while the guard is being dropped.
            if let Some(link) = unsafe { link_ptr2.get() } {
                link.remove_from_list(&self.0.wait_queue);
            }
        });

        let idx = poll_fn(|cx| {
            critical_section::with(|cs| {
                // Checked inside the same critical section that may enqueue
                // us: `Receiver::drop` sets the flag before it wakes the
                // waiters, so we either see the flag here or are enqueued
                // early enough to be woken by it. Checking it beforehand left
                // a window in which this future could be parked forever.
                if *self.0.access(cs).receiver_dropped {
                    return Poll::Ready(Err(()));
                }

                let wq_empty = self.0.wait_queue.is_empty();
                let fq_empty = self.0.access(cs).freeq.is_empty();

                if !wq_empty || fq_empty {
                    // SAFETY: see `dropper`; the two handles are never used
                    // at the same time.
                    let link = unsafe { link_ptr.get() };

                    match link.as_ref().map(|queued| queued.is_popped()) {
                        // Still parked in the wait queue: keep waiting.
                        Some(false) => return Poll::Pending,
                        // Woken by the receiver and the slot is still there:
                        // fall through and take it.
                        Some(true) if !fq_empty => {}
                        // Either not parked yet, or woken only to find that
                        // another sender took the slot in the meantime (this
                        // case used to fail the assertion below): queue up.
                        _ => {
                            let link_ref = link.insert(Link::new(cx.waker().clone()));

                            // SAFETY: the link is stored in this future, which
                            // is pinned while polled, and it is removed from
                            // the queue before it goes away (see `dropper`).
                            unsafe { self.0.wait_queue.push(Pin::new_unchecked(link_ref)) };

                            return Poll::Pending;
                        }
                    }
                }

                assert!(!self.0.access(cs).freeq.is_empty());
                // SAFETY: the queue was just checked not to be empty.
                let idx = unsafe { self.0.access(cs).freeq.pop_front_unchecked() };

                Poll::Ready(Ok(idx))
            })
        })
        .await;

        // No await point below, so the slot cannot be lost to cancellation.
        drop(dropper);

        if let Ok(idx) = idx {
            self.send_footer(idx, val);

            Ok(())
        } else {
            Err(NoReceiver(val))
        }
    }

    /// Returns `true` if the receiver has been dropped.
    pub fn is_closed(&self) -> bool {
        critical_section::with(|cs| *self.0.access(cs).receiver_dropped)
    }

    /// Returns `true` if no slot is free at the moment.
    pub fn is_full(&self) -> bool {
        critical_section::with(|cs| self.0.access(cs).freeq.is_empty())
    }

    /// Returns `true` if every slot is free at the moment.
    pub fn is_empty(&self) -> bool {
        critical_section::with(|cs| self.0.access(cs).freeq.is_full())
    }
}
impl<'a, T, const N: usize> Drop for Sender<'a, T, N> {
fn drop(&mut self) {
let num_senders = critical_section::with(|cs| {
*self.0.access(cs).num_senders -= 1;
*self.0.access(cs).num_senders
});
if num_senders == 0 {
self.0.receiver_waker.wake();
}
}
}
impl<'a, T, const N: usize> Clone for Sender<'a, T, N> {
fn clone(&self) -> Self {
critical_section::with(|cs| *self.0.access(cs).num_senders += 1);
Self(self.0)
}
}
/// The receiving half of a `Channel`, obtained from `Channel::split`.
///
/// No `Clone` impl is visible in this file, so each split yields a single
/// receiver.
pub struct Receiver<'a, T, const N: usize>(&'a Channel<T, N>);
// NOTE(review): this impl carries no `T: Send` bound, so a receiver of a
// non-`Send` payload can be moved to another execution context. Confirm this
// is intended.
unsafe impl<'a, T, const N: usize> Send for Receiver<'a, T, N> {}
impl<'a, T, const N: usize> core::fmt::Debug for Receiver<'a, T, N> {
    /// Prints the fixed text `Receiver`; no channel state is exposed.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str("Receiver")
    }
}
#[cfg(feature = "defmt-03")]
impl<'a, T, const N: usize> defmt::Format for Receiver<'a, T, N> {
    /// Logs the fixed text `Receiver`; no channel state is exposed.
    fn format(&self, fmt: defmt::Formatter) {
        defmt::write!(fmt, "Receiver")
    }
}
/// Reasons a receive operation yields no value.
#[cfg_attr(feature = "defmt-03", derive(defmt::Format))]
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum ReceiveError {
/// No value is queued and every sender has been dropped.
NoSender,
/// No value is queued at the moment.
Empty,
}
impl<'a, T, const N: usize> Receiver<'a, T, N> {
pub fn try_recv(&mut self) -> Result<T, ReceiveError> {
let ready_slot = critical_section::with(|cs| self.0.access(cs).readyq.pop_front());
if let Some(rs) = ready_slot {
let r = unsafe { ptr::read(self.0.slots.get_unchecked(rs as usize).get() as *const T) };
critical_section::with(|cs| {
assert!(!self.0.access(cs).freeq.is_full());
unsafe { self.0.access(cs).freeq.push_back_unchecked(rs) }
});
fence(Ordering::SeqCst);
if let Some(wait_head) = self.0.wait_queue.pop() {
wait_head.wake();
}
Ok(r)
} else if self.is_closed() {
Err(ReceiveError::NoSender)
} else {
Err(ReceiveError::Empty)
}
}
pub async fn recv(&mut self) -> Result<T, ReceiveError> {
poll_fn(|cx| {
self.0.receiver_waker.register(cx.waker());
match self.try_recv() {
Ok(val) => {
return Poll::Ready(Ok(val));
}
Err(ReceiveError::NoSender) => {
return Poll::Ready(Err(ReceiveError::NoSender));
}
_ => {}
}
Poll::Pending
})
.await
}
pub fn is_closed(&self) -> bool {
critical_section::with(|cs| *self.0.access(cs).num_senders == 0)
}
pub fn is_full(&self) -> bool {
critical_section::with(|cs| self.0.access(cs).readyq.is_full())
}
pub fn is_empty(&self) -> bool {
critical_section::with(|cs| self.0.access(cs).readyq.is_empty())
}
}
impl<'a, T, const N: usize> Drop for Receiver<'a, T, N> {
fn drop(&mut self) {
critical_section::with(|cs| *self.0.access(cs).receiver_dropped = true);
while let Some(waker) = self.0.wait_queue.pop() {
waker.wake();
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Both halves report "empty" exactly while no value is queued.
    #[test]
    fn empty() {
        let (mut tx, mut rx) = make_channel!(u32, 10);

        assert!(tx.is_empty());
        assert!(rx.is_empty());

        tx.try_send(1).unwrap();

        assert!(!tx.is_empty());
        assert!(!rx.is_empty());

        rx.try_recv().unwrap();

        assert!(tx.is_empty());
        assert!(rx.is_empty());
    }

    /// Both halves report "full" only once every slot is occupied.
    #[test]
    fn full() {
        let (mut tx, mut rx) = make_channel!(u32, 3);

        for _ in 0..3 {
            assert!(!tx.is_full());
            assert!(!rx.is_full());

            tx.try_send(1).unwrap();
        }

        assert!(tx.is_full());
        assert!(rx.is_full());

        for _ in 0..3 {
            rx.try_recv().unwrap();

            assert!(!tx.is_full());
            assert!(!rx.is_full());
        }
    }

    /// Values arrive in the order they were sent; one send too many is
    /// rejected and one receive too many reports an empty queue.
    #[test]
    fn send_recieve() {
        let (mut tx, mut rx) = make_channel!(u32, 10);

        for value in 0..10 {
            tx.try_send(value).unwrap();
        }

        assert_eq!(tx.try_send(11), Err(TrySendError::Full(11)));

        for expected in 0..10 {
            assert_eq!(rx.try_recv(), Ok(expected));
        }

        assert_eq!(rx.try_recv(), Err(ReceiveError::Empty));
    }

    /// Dropping the only sender closes the channel for the receiver.
    #[test]
    fn closed_recv() {
        let (tx, mut rx) = make_channel!(u32, 10);

        drop(tx);

        assert!(rx.is_closed());

        assert_eq!(rx.try_recv(), Err(ReceiveError::NoSender));
    }

    /// Dropping the receiver closes the channel for the sender.
    #[test]
    fn closed_sender() {
        let (mut tx, rx) = make_channel!(u32, 10);

        drop(rx);

        assert!(tx.is_closed());

        assert_eq!(tx.try_send(11), Err(TrySendError::NoReceiver(11)));
    }

    /// Many more concurrent senders than slots: every value must arrive
    /// exactly once.
    #[tokio::test]
    async fn stress_channel() {
        const NUM_RUNS: usize = 1_000;
        const QUEUE_SIZE: usize = 10;

        // `tx` itself stays alive until the end of the test, so the receiver
        // never observes a closed channel.
        let (tx, mut rx) = make_channel!(u32, QUEUE_SIZE);

        let handles: std::vec::Vec<_> = (0..NUM_RUNS)
            .map(|i| {
                let mut tx = tx.clone();

                tokio::spawn(async move {
                    tx.send(i as _).await.unwrap();
                })
            })
            .collect();

        let mut seen = std::collections::BTreeSet::new();

        for _ in 0..NUM_RUNS {
            seen.insert(rx.recv().await.unwrap());
        }

        assert_eq!(seen.len(), NUM_RUNS);

        for handle in handles {
            handle.await.unwrap();
        }
    }

    /// A single `make_channel!` invocation site, to be executed repeatedly.
    fn make() {
        let _ = make_channel!(u32, 10);
    }

    /// Executing the same invocation site twice must panic.
    #[test]
    #[should_panic]
    fn double_make_channel() {
        make();
        make();
    }

    /// The macro accepts types that contain commas.
    #[test]
    fn tuple_channel() {
        let _ = make_channel!((i32, u32), 10);
    }
}