use std::ops::{Deref, DerefMut};
use luz::{connection::Tls, jabber_stream::bound_stream::BoundJabberWriter};
use stanza::client::Stanza;
use tokio::{
sync::{mpsc, oneshot},
task::JoinHandle,
};
use crate::error::WriteError;
/// Actor that owns the write half of a bound Jabber stream and writes the stanzas queued to it.
///
/// Each queued [`WriteMessage`] is answered on its `respond_to` channel. If a write fails with
/// `peanuts::Error::ReadError` (treated as a lost connection), the actor instead hands the failed
/// message and its queue back to the supervisor through `on_crash` and stops, so the supervisor
/// can spawn a new stream and continue the session.
pub struct Write {
    /// Write half of the bound stream that stanzas are written to.
    stream: BoundJabberWriter<Tls>,
    /// connection session write queue
    stanza_receiver: mpsc::Receiver<WriteMessage>,
    /// Control commands (see [`WriteControl`]) steering this actor.
    control_receiver: mpsc::Receiver<WriteControl>,
    /// One-shot report of a lost connection: the message whose write failed plus the queue.
    on_crash: oneshot::Sender<(WriteMessage, WriteState)>,
}
/// The write actor's stanza queue, handed back when the actor stops because of a lost
/// connection or an abort, so that a replacement actor can keep serving the same
/// connection session.
///
/// On a crash it travels alongside the message whose write failed, as the
/// `(WriteMessage, WriteState)` tuple sent on the actor's `on_crash` channel. On a
/// [`WriteControl::Abort`] it is sent on its own.
#[derive(Debug)]
pub struct WriteState {
    /// Receiving end of the session's write queue. It can be passed to
    /// `WriteControlHandle::reconnect` / `WriteControlHandle::reconnect_retry`.
    pub stanza_recv: mpsc::Receiver<WriteMessage>,
}
/// A stanza queued for the write actor, together with the channel on which the outcome of
/// writing it is reported.
#[derive(Debug)]
pub struct WriteMessage {
    /// The stanza to write.
    pub stanza: Stanza,
    /// Answered with `Ok(())` after a successful write, otherwise with a `WriteError`.
    /// If the write hits a lost connection, the message may first be handed back to the
    /// supervisor (unanswered) so it can be retried on a new stream.
    pub respond_to: oneshot::Sender<Result<(), WriteError>>,
}
/// Control commands understood by the write actor.
pub enum WriteControl {
    /// Close the stanza queue to new messages, write the ones already queued, then try to
    /// close the stream. Queued messages are failed with `WriteError::LostConnection` if the
    /// connection drops part-way through.
    Disconnect,
    /// Stop immediately, without writing anything further or calling `try_close` on the
    /// stream, and send the stanza queue back on the given channel so it can be reused.
    Abort(oneshot::Sender<WriteState>),
}
impl Write {
fn new(
stream: BoundJabberWriter<Tls>,
stanza_receiver: mpsc::Receiver<WriteMessage>,
control_receiver: mpsc::Receiver<WriteControl>,
on_crash: oneshot::Sender<(WriteMessage, WriteState)>,
) -> Self {
Self {
stream,
stanza_receiver,
control_receiver,
on_crash,
}
}
async fn write(&mut self, stanza: &Stanza) -> Result<(), peanuts::Error> {
Ok(self.stream.write(stanza).await?)
}
async fn run_reconnected(mut self, retry_msg: WriteMessage) {
// try to retry sending the message that failed to send previously
let result = self.stream.write(&retry_msg.stanza).await;
match result {
Err(e) => match &e {
peanuts::Error::ReadError(_error) => {
// make sure message is not lost from error, supervisor handles retry and reporting
// TODO: upon reconnect, make sure we are not stuck in a reconnection loop
let _ = self.on_crash.send((
retry_msg,
WriteState {
stanza_recv: self.stanza_receiver,
},
));
return;
}
_ => {
let _ = retry_msg.respond_to.send(Err(e.into()));
}
},
_ => {
let _ = retry_msg.respond_to.send(Ok(()));
}
}
// return to normal loop
self.run().await
}
async fn run(mut self) {
loop {
tokio::select! {
Some(msg) = self.control_receiver.recv() => {
match msg {
WriteControl::Disconnect => {
// close the stanza_receiver channel and drain out all of the remaining stanzas to send
self.stanza_receiver.close();
// TODO: put this in some kind of function to avoid code duplication
while let Some(msg) = self.stanza_receiver.recv().await {
let result = self.stream.write(&msg.stanza).await;
match result {
Err(e) => match &e {
peanuts::Error::ReadError(_error) => {
// if connection lost during disconnection, just send lost connection error to the write requests
let _ = msg.respond_to.send(Err(WriteError::LostConnection));
while let Some(msg) = self.stanza_receiver.recv().await {
let _ = msg.respond_to.send(Err(WriteError::LostConnection));
}
break;
}
// otherwise complete sending all the stanzas currently in the queue
_ => {
let _ = msg.respond_to.send(Err(e.into()));
}
},
_ => {
let _ = msg.respond_to.send(Ok(()));
}
}
}
let _ = self.stream.try_close().await;
break;
},
// in case of abort, stream is already fucked, just send the receiver ready for a reconnection at the same resource
WriteControl::Abort(sender) => {
let _ = sender.send(WriteState { stanza_recv: self.stanza_receiver });
break;
},
}
},
Some(msg) = self.stanza_receiver.recv() => {
let result = self.stream.write(&msg.stanza).await;
match result {
Err(e) => match &e {
peanuts::Error::ReadError(_error) => {
// make sure message is not lost from error, supervisor handles retry and reporting
let _ = self.on_crash.send((msg, WriteState { stanza_recv: self.stanza_receiver }));
break;
}
_ => {
let _ = msg.respond_to.send(Err(e.into()));
}
},
_ => {
let _ = msg.respond_to.send(Ok(()));
}
}
},
else => break,
}
}
}
}
/// Cloneable handle for queuing stanzas on a write actor's stanza queue.
#[derive(Clone)]
pub struct WriteHandle {
    /// Sending end of the actor's stanza queue.
    sender: mpsc::Sender<WriteMessage>,
}
impl WriteHandle {
    /// Queues `stanza` on the write actor and waits for the outcome of writing it.
    ///
    /// # Errors
    ///
    /// - `WriteError::Actor` if the stanza queue is closed, or if the reply channel is
    ///   dropped without an answer.
    /// - Otherwise, whatever error the actor reported for this write.
    pub async fn write(&self, stanza: Stanza) -> Result<(), WriteError> {
        let (respond_to, response) = oneshot::channel();
        let message = WriteMessage { stanza, respond_to };

        if let Err(send_err) = self.sender.send(message).await {
            return Err(WriteError::Actor(send_err.into()));
        }

        // TODO: timeout
        match response.await {
            Ok(outcome) => outcome,
            Err(recv_err) => Err(WriteError::Actor(recv_err.into())),
        }
    }
}
/// Lets a `WriteHandle` be used directly as the `mpsc::Sender` feeding the actor's stanza queue.
impl Deref for WriteHandle {
    type Target = mpsc::Sender<WriteMessage>;

    fn deref(&self) -> &mpsc::Sender<WriteMessage> {
        let Self { sender } = self;
        sender
    }
}
impl DerefMut for WriteHandle {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.sender
}
}
/// Handle for sending [`WriteControl`] commands to a spawned write actor.
pub struct WriteControlHandle {
    /// Sending end of the actor's control queue.
    sender: mpsc::Sender<WriteControl>,
    /// Join handle of the spawned actor task.
    pub(crate) handle: JoinHandle<()>,
}
/// Lets a `WriteControlHandle` be used directly as the `mpsc::Sender` of the control queue.
impl Deref for WriteControlHandle {
    type Target = mpsc::Sender<WriteControl>;

    fn deref(&self) -> &mpsc::Sender<WriteControl> {
        let Self { sender, .. } = self;
        sender
    }
}
impl DerefMut for WriteControlHandle {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.sender
}
}
impl WriteControlHandle {
    /// Capacity of every stanza and control queue created here.
    const CHANNEL_CAPACITY: usize = 20;

    /// Spawns a fresh write actor over `stream` with a brand-new stanza queue.
    ///
    /// Returns the cloneable [`WriteHandle`] for queuing stanzas, together with the control
    /// handle of the spawned task. `on_crash` receives the failed message and the queue if
    /// the actor loses its connection.
    pub fn new(
        stream: BoundJabberWriter<Tls>,
        on_crash: oneshot::Sender<(WriteMessage, WriteState)>,
    ) -> (WriteHandle, Self) {
        let (sender, stanza_receiver) = mpsc::channel(Self::CHANNEL_CAPACITY);
        let control = Self::reconnect(stream, on_crash, stanza_receiver);
        (WriteHandle { sender }, control)
    }

    /// Spawns a write actor over a new `stream`.
    ///
    /// The actor first retries `retry_msg` (the write that failed on the previous stream),
    /// then keeps serving the existing `stanza_receiver`.
    pub fn reconnect_retry(
        stream: BoundJabberWriter<Tls>,
        on_crash: oneshot::Sender<(WriteMessage, WriteState)>,
        stanza_receiver: mpsc::Receiver<WriteMessage>,
        retry_msg: WriteMessage,
    ) -> Self {
        let (actor, sender) = Self::assemble(stream, on_crash, stanza_receiver);
        let handle = tokio::spawn(actor.run_reconnected(retry_msg));
        Self { sender, handle }
    }

    /// Spawns a write actor over a new `stream` that continues serving an existing
    /// `stanza_receiver`, e.g. one recovered from a [`WriteState`].
    pub fn reconnect(
        stream: BoundJabberWriter<Tls>,
        on_crash: oneshot::Sender<(WriteMessage, WriteState)>,
        stanza_receiver: mpsc::Receiver<WriteMessage>,
    ) -> Self {
        let (actor, sender) = Self::assemble(stream, on_crash, stanza_receiver);
        let handle = tokio::spawn(actor.run());
        Self { sender, handle }
    }

    /// Creates a control channel and builds (but does not spawn) an actor around it.
    fn assemble(
        stream: BoundJabberWriter<Tls>,
        on_crash: oneshot::Sender<(WriteMessage, WriteState)>,
        stanza_receiver: mpsc::Receiver<WriteMessage>,
    ) -> (Write, mpsc::Sender<WriteControl>) {
        let (sender, control_receiver) = mpsc::channel(Self::CHANNEL_CAPACITY);
        let actor = Write::new(stream, stanza_receiver, control_receiver, on_crash);
        (actor, sender)
    }
}