// TODO: consider whether this needs to be handled by a supervisor or could be handled by luz directly
use std::{
collections::HashMap,
ops::{Deref, DerefMut},
sync::Arc,
time::Duration,
};
use jid::JID;
use luz::{connection::Tls, jabber_stream::bound_stream::BoundJabberStream};
use read::{ReadControl, ReadControlHandle, ReadState};
use stanza::client::Stanza;
use tokio::{
sync::{mpsc, oneshot, Mutex},
task::{JoinHandle, JoinSet},
};
use tracing::info;
use write::{WriteControl, WriteControlHandle, WriteHandle, WriteMessage, WriteState};
use crate::{
error::{ConnectionError, WriteError},
Connected, Logic,
};
mod read;
pub(crate) mod write;
/// Actor that owns the control handles of the connection's reader and writer tasks and
/// rebuilds both on a fresh connection when either one fails or asks for a reconnect.
pub struct Supervisor<Lgc> {
    /// Commands sent to the supervisor (through a `SupervisorSender`).
    command_recv: mpsc::Receiver<SupervisorCommand>,
    /// Resolves with the reader's state when the reader reports a crash; replaced with a
    /// fresh receiver after every successful reconnect.
    reader_crash: oneshot::Receiver<ReadState>,
    /// Resolves with the message the writer was handling plus the writer's state when the
    /// writer reports a crash; replaced with a fresh receiver after every successful reconnect.
    writer_crash: oneshot::Receiver<(WriteMessage, WriteState)>,
    /// Control channel and task handle of the current reader.
    read_control_handle: ReadControlHandle,
    /// Control channel and task handle of the current writer.
    write_control_handle: WriteControlHandle,
    /// Signalled once when the supervisor's run loop exits, whatever the reason.
    on_crash: oneshot::Sender<()>,
    /// The JID in `connected` stays the same over the life of the supervisor (the connection
    /// session).
    connected: Connected,
    /// Password used to log in again when reconnecting.
    password: Arc<String>,
    /// Application callbacks: cloned into every new reader, and notified of disconnects and
    /// connection errors.
    logic: Lgc,
}
/// Commands accepted by the supervisor's run loop.
pub enum SupervisorCommand {
    /// Shut down the reader and writer, then stop the supervisor.
    Disconnect,
    /// Re-establish the connection, e.g. after a stream error. Carries the state of the child
    /// that requested it; the supervisor retrieves the other child's state itself.
    // NOTE(review): stream errors could arguably be reported as a child crash instead.
    Reconnect(ChildState),
}
/// State handed back by a child task so that it can be respawned on a new connection.
pub enum ChildState {
    /// State of the writer task.
    Write(WriteState),
    /// State of the reader task.
    Read(ReadState),
}
impl<Lgc: Logic + Clone + Send + 'static> Supervisor<Lgc> {
fn new(
command_recv: mpsc::Receiver<SupervisorCommand>,
reader_crash: oneshot::Receiver<ReadState>,
writer_crash: oneshot::Receiver<(WriteMessage, WriteState)>,
read_control_handle: ReadControlHandle,
write_control_handle: WriteControlHandle,
on_crash: oneshot::Sender<()>,
connected: Connected,
password: Arc<String>,
logic: Lgc,
) -> Self {
Self {
command_recv,
reader_crash,
writer_crash,
read_control_handle,
write_control_handle,
on_crash,
connected,
password,
logic,
}
}
async fn run(mut self) {
loop {
tokio::select! {
Some(msg) = self.command_recv.recv() => {
match msg {
SupervisorCommand::Disconnect => {
info!("disconnecting");
self.logic
.handle_disconnect(self.connected.clone())
.await;
let _ = self.write_control_handle.send(WriteControl::Disconnect).await;
let _ = self.read_control_handle.send(ReadControl::Disconnect).await;
info!("sent disconnect command");
tokio::select! {
_ = async { tokio::join!(
async { let _ = (&mut self.write_control_handle.handle).await; },
async { let _ = (&mut self.read_control_handle.handle).await; }
) } => {},
// TODO: config timeout
_ = async { tokio::time::sleep(Duration::from_secs(5)) } => {
(&mut self.read_control_handle.handle).abort();
(&mut self.write_control_handle.handle).abort();
}
}
info!("disconnected");
break;
},
// TODO: Reconnect without aborting, gentle reconnect.
SupervisorCommand::Reconnect(state) => {
// TODO: please omfg
// send abort to read stream, as already done, consider
let (read_state, mut write_state);
match state {
ChildState::Write(receiver) => {
write_state = receiver;
let (send, recv) = oneshot::channel();
let _ = self.read_control_handle.send(ReadControl::Abort(send)).await;
// TODO: need a tokio select, in case the state arrives from somewhere else
if let Ok(state) = recv.await {
read_state = state;
} else {
break
}
},
ChildState::Read(read) => {
read_state = read;
let (send, recv) = oneshot::channel();
let _ = self.write_control_handle.send(WriteControl::Abort(send)).await;
// TODO: need a tokio select, in case the state arrives from somewhere else
if let Ok(state) = recv.await {
write_state = state;
} else {
break
}
},
}
let mut jid = self.connected.jid.clone();
let mut domain = jid.domainpart.clone();
// TODO: make sure connect_and_login does not modify the jid, but instead returns a jid. or something like that
let connection = luz::connect_and_login(&mut jid, &*self.password, &mut domain).await;
match connection {
Ok(c) => {
let (read, write) = c.split();
let (send, recv) = oneshot::channel();
self.writer_crash = recv;
self.write_control_handle =
WriteControlHandle::reconnect(write, send, write_state.stanza_recv);
let (send, recv) = oneshot::channel();
self.reader_crash = recv;
self.read_control_handle = ReadControlHandle::reconnect(
read,
read_state.tasks,
self.connected.clone(),
self.logic.clone(),
read_state.supervisor_control,
send,
);
},
Err(e) => {
// if reconnection failure, respond to all current write messages with lost connection error. the received processes should complete themselves.
write_state.stanza_recv.close();
while let Some(msg) = write_state.stanza_recv.recv().await {
let _ = msg.respond_to.send(Err(WriteError::LostConnection));
}
// TODO: is this the correct error?
self.logic.handle_connection_error(ConnectionError::LostConnection).await;
break;
},
}
},
}
},
Ok((write_msg, mut write_state)) = &mut self.writer_crash => {
// consider awaiting/aborting the read and write threads
let (send, recv) = oneshot::channel();
let _ = self.read_control_handle.send(ReadControl::Abort(send)).await;
let read_state = tokio::select! {
Ok(s) = recv => s,
Ok(s) = &mut self.reader_crash => s,
// in case, just break as irrecoverable
else => break,
};
let mut jid = self.connected.jid.clone();
let mut domain = jid.domainpart.clone();
// TODO: same here
let connection = luz::connect_and_login(&mut jid, &*self.password, &mut domain).await;
match connection {
Ok(c) => {
let (read, write) = c.split();
let (send, recv) = oneshot::channel();
self.writer_crash = recv;
self.write_control_handle =
WriteControlHandle::reconnect_retry(write, send, write_state.stanza_recv, write_msg);
let (send, recv) = oneshot::channel();
self.reader_crash = recv;
self.read_control_handle = ReadControlHandle::reconnect(
read,
read_state.tasks,
self.connected.clone(),
self.logic.clone(),
read_state.supervisor_control,
send,
);
},
Err(e) => {
// if reconnection failure, respond to all current write messages with lost connection error. the received processes should complete themselves.
write_state.stanza_recv.close();
let _ = write_msg.respond_to.send(Err(WriteError::LostConnection));
while let Some(msg) = write_state.stanza_recv.recv().await {
let _ = msg.respond_to.send(Err(WriteError::LostConnection));
}
// TODO: is this the correct error to send?
self.logic.handle_connection_error(ConnectionError::LostConnection).await;
break;
},
}
},
Ok(read_state) = &mut self.reader_crash => {
let (send, recv) = oneshot::channel();
let _ = self.write_control_handle.send(WriteControl::Abort(send)).await;
let (retry_msg, mut write_state) = tokio::select! {
Ok(s) = recv => (None, s),
Ok(s) = &mut self.writer_crash => (Some(s.0), s.1),
// in case, just break as irrecoverable
else => break,
};
let mut jid = self.connected.jid.clone();
let mut domain = jid.domainpart.clone();
let connection = luz::connect_and_login(&mut jid, &*self.password, &mut domain).await;
match connection {
Ok(c) => {
let (read, write) = c.split();
let (send, recv) = oneshot::channel();
self.writer_crash = recv;
if let Some(msg) = retry_msg {
self.write_control_handle =
WriteControlHandle::reconnect_retry(write, send, write_state.stanza_recv, msg);
} else {
self.write_control_handle = WriteControlHandle::reconnect(write, send, write_state.stanza_recv)
}
let (send, recv) = oneshot::channel();
self.reader_crash = recv;
self.read_control_handle = ReadControlHandle::reconnect(
read,
read_state.tasks,
self.connected.clone(),
self.logic.clone(),
read_state.supervisor_control,
send,
);
},
Err(e) => {
// if reconnection failure, respond to all current messages with lost connection error.
write_state.stanza_recv.close();
if let Some(msg) = retry_msg {
msg.respond_to.send(Err(WriteError::LostConnection));
}
while let Some(msg) = write_state.stanza_recv.recv().await {
msg.respond_to.send(Err(WriteError::LostConnection));
}
// TODO: is this the correct error?
self.logic.handle_connection_error(ConnectionError::LostConnection).await;
break;
},
}
},
else => break,
}
}
// TODO: maybe don't just on_crash
let _ = self.on_crash.send(());
}
}
/// Owning handle to the spawned supervisor task.
///
/// Dereferences to its `SupervisorSender`, so commands can be sent through it directly.
pub struct SupervisorHandle {
    sender: SupervisorSender,
    // join handle of the spawned supervisor task
    handle: JoinHandle<()>,
}

impl Deref for SupervisorHandle {
    type Target = SupervisorSender;

    /// Borrows the supervisor's command sender.
    fn deref(&self) -> &SupervisorSender {
        let Self { sender, .. } = self;
        sender
    }
}

impl DerefMut for SupervisorHandle {
    /// Mutably borrows the supervisor's command sender.
    fn deref_mut(&mut self) -> &mut SupervisorSender {
        let Self { sender, .. } = self;
        sender
    }
}
/// Cloneable sender for `SupervisorCommand`s addressed to the supervisor task.
///
/// Dereferences to the underlying `mpsc::Sender`.
#[derive(Clone)]
pub struct SupervisorSender {
    sender: mpsc::Sender<SupervisorCommand>,
}

impl Deref for SupervisorSender {
    type Target = mpsc::Sender<SupervisorCommand>;

    /// Borrows the underlying channel sender.
    fn deref(&self) -> &mpsc::Sender<SupervisorCommand> {
        let Self { sender } = self;
        sender
    }
}

impl DerefMut for SupervisorSender {
    /// Mutably borrows the underlying channel sender.
    fn deref_mut(&mut self) -> &mut mpsc::Sender<SupervisorCommand> {
        let Self { sender } = self;
        sender
    }
}
impl SupervisorHandle {
pub fn new<Lgc: Logic + Clone + Send + 'static>(
streams: BoundJabberStream<Tls>,
on_crash: oneshot::Sender<()>,
jid: JID,
password: Arc<String>,
logic: Lgc,
) -> (WriteHandle, Self) {
let (command_send, command_recv) = mpsc::channel(20);
let (writer_crash_send, writer_crash_recv) = oneshot::channel();
let (reader_crash_send, reader_crash_recv) = oneshot::channel();
let (read_stream, write_stream) = streams.split();
let (write_handle, write_control_handle) =
WriteControlHandle::new(write_stream, writer_crash_send);
let connected = Connected {
jid,
write_handle: write_handle.clone(),
};
let supervisor_sender = SupervisorSender {
sender: command_send,
};
let read_control_handle = ReadControlHandle::new(
read_stream,
connected.clone(),
logic.clone(),
supervisor_sender.clone(),
reader_crash_send,
);
let actor = Supervisor::new(
command_recv,
reader_crash_recv,
writer_crash_recv,
read_control_handle,
write_control_handle,
on_crash,
connected,
password,
logic,
);
let handle = tokio::spawn(async move { actor.run().await });
(
write_handle,
Self {
sender: supervisor_sender,
handle,
},
)
}
pub fn sender(&self) -> SupervisorSender {
self.sender.clone()
}
}