// TODO: consider if this needs to be handled by a supervisor or could be handled by luz directly use std::{ collections::HashMap, ops::{Deref, DerefMut}, sync::Arc, time::Duration, }; use jid::JID; use luz::{connection::Tls, jabber_stream::bound_stream::BoundJabberStream}; use read::{ReadControl, ReadControlHandle, ReadState}; use stanza::client::Stanza; use tokio::{ sync::{mpsc, oneshot, Mutex}, task::{JoinHandle, JoinSet}, }; use tracing::info; use write::{WriteControl, WriteControlHandle, WriteHandle, WriteMessage, WriteState}; use crate::{ error::{ConnectionError, WriteError}, Connected, Logic, }; mod read; pub(crate) mod write; pub struct Supervisor { command_recv: mpsc::Receiver, reader_crash: oneshot::Receiver, writer_crash: oneshot::Receiver<(WriteMessage, WriteState)>, read_control_handle: ReadControlHandle, write_control_handle: WriteControlHandle, on_crash: oneshot::Sender<()>, // jid in connected stays the same over the life of the supervisor (the connection session) connected: Connected, password: Arc, logic: Lgc, } pub enum SupervisorCommand { Disconnect, // for if there was a stream error, require to reconnect // couldn't stream errors just cause a crash? lol Reconnect(ChildState), } pub enum ChildState { Write(WriteState), Read(ReadState), } impl Supervisor { fn new( command_recv: mpsc::Receiver, reader_crash: oneshot::Receiver, writer_crash: oneshot::Receiver<(WriteMessage, WriteState)>, read_control_handle: ReadControlHandle, write_control_handle: WriteControlHandle, on_crash: oneshot::Sender<()>, connected: Connected, password: Arc, logic: Lgc, ) -> Self { Self { command_recv, reader_crash, writer_crash, read_control_handle, write_control_handle, on_crash, connected, password, logic, } } async fn run(mut self) { loop { tokio::select! { Some(msg) = self.command_recv.recv() => { match msg { SupervisorCommand::Disconnect => { info!("disconnecting"); self.logic .handle_disconnect(self.connected.clone()) .await; let _ = self.write_control_handle.send(WriteControl::Disconnect).await; let _ = self.read_control_handle.send(ReadControl::Disconnect).await; info!("sent disconnect command"); tokio::select! { _ = async { tokio::join!( async { let _ = (&mut self.write_control_handle.handle).await; }, async { let _ = (&mut self.read_control_handle.handle).await; } ) } => {}, // TODO: config timeout _ = async { tokio::time::sleep(Duration::from_secs(5)) } => { (&mut self.read_control_handle.handle).abort(); (&mut self.write_control_handle.handle).abort(); } } info!("disconnected"); break; }, // TODO: Reconnect without aborting, gentle reconnect. SupervisorCommand::Reconnect(state) => { // TODO: please omfg // send abort to read stream, as already done, consider let (read_state, mut write_state); match state { ChildState::Write(receiver) => { write_state = receiver; let (send, recv) = oneshot::channel(); let _ = self.read_control_handle.send(ReadControl::Abort(send)).await; // TODO: need a tokio select, in case the state arrives from somewhere else if let Ok(state) = recv.await { read_state = state; } else { break } }, ChildState::Read(read) => { read_state = read; let (send, recv) = oneshot::channel(); let _ = self.write_control_handle.send(WriteControl::Abort(send)).await; // TODO: need a tokio select, in case the state arrives from somewhere else if let Ok(state) = recv.await { write_state = state; } else { break } }, } let mut jid = self.connected.jid.clone(); let mut domain = jid.domainpart.clone(); // TODO: make sure connect_and_login does not modify the jid, but instead returns a jid. or something like that let connection = luz::connect_and_login(&mut jid, &*self.password, &mut domain).await; match connection { Ok(c) => { let (read, write) = c.split(); let (send, recv) = oneshot::channel(); self.writer_crash = recv; self.write_control_handle = WriteControlHandle::reconnect(write, send, write_state.stanza_recv); let (send, recv) = oneshot::channel(); self.reader_crash = recv; self.read_control_handle = ReadControlHandle::reconnect( read, read_state.tasks, self.connected.clone(), self.logic.clone(), read_state.supervisor_control, send, ); }, Err(e) => { // if reconnection failure, respond to all current write messages with lost connection error. the received processes should complete themselves. write_state.stanza_recv.close(); while let Some(msg) = write_state.stanza_recv.recv().await { let _ = msg.respond_to.send(Err(WriteError::LostConnection)); } // TODO: is this the correct error? self.logic.handle_connection_error(ConnectionError::LostConnection).await; break; }, } }, } }, Ok((write_msg, mut write_state)) = &mut self.writer_crash => { // consider awaiting/aborting the read and write threads let (send, recv) = oneshot::channel(); let _ = self.read_control_handle.send(ReadControl::Abort(send)).await; let read_state = tokio::select! { Ok(s) = recv => s, Ok(s) = &mut self.reader_crash => s, // in case, just break as irrecoverable else => break, }; let mut jid = self.connected.jid.clone(); let mut domain = jid.domainpart.clone(); // TODO: same here let connection = luz::connect_and_login(&mut jid, &*self.password, &mut domain).await; match connection { Ok(c) => { let (read, write) = c.split(); let (send, recv) = oneshot::channel(); self.writer_crash = recv; self.write_control_handle = WriteControlHandle::reconnect_retry(write, send, write_state.stanza_recv, write_msg); let (send, recv) = oneshot::channel(); self.reader_crash = recv; self.read_control_handle = ReadControlHandle::reconnect( read, read_state.tasks, self.connected.clone(), self.logic.clone(), read_state.supervisor_control, send, ); }, Err(e) => { // if reconnection failure, respond to all current write messages with lost connection error. the received processes should complete themselves. write_state.stanza_recv.close(); let _ = write_msg.respond_to.send(Err(WriteError::LostConnection)); while let Some(msg) = write_state.stanza_recv.recv().await { let _ = msg.respond_to.send(Err(WriteError::LostConnection)); } // TODO: is this the correct error to send? self.logic.handle_connection_error(ConnectionError::LostConnection).await; break; }, } }, Ok(read_state) = &mut self.reader_crash => { let (send, recv) = oneshot::channel(); let _ = self.write_control_handle.send(WriteControl::Abort(send)).await; let (retry_msg, mut write_state) = tokio::select! { Ok(s) = recv => (None, s), Ok(s) = &mut self.writer_crash => (Some(s.0), s.1), // in case, just break as irrecoverable else => break, }; let mut jid = self.connected.jid.clone(); let mut domain = jid.domainpart.clone(); let connection = luz::connect_and_login(&mut jid, &*self.password, &mut domain).await; match connection { Ok(c) => { let (read, write) = c.split(); let (send, recv) = oneshot::channel(); self.writer_crash = recv; if let Some(msg) = retry_msg { self.write_control_handle = WriteControlHandle::reconnect_retry(write, send, write_state.stanza_recv, msg); } else { self.write_control_handle = WriteControlHandle::reconnect(write, send, write_state.stanza_recv) } let (send, recv) = oneshot::channel(); self.reader_crash = recv; self.read_control_handle = ReadControlHandle::reconnect( read, read_state.tasks, self.connected.clone(), self.logic.clone(), read_state.supervisor_control, send, ); }, Err(e) => { // if reconnection failure, respond to all current messages with lost connection error. write_state.stanza_recv.close(); if let Some(msg) = retry_msg { msg.respond_to.send(Err(WriteError::LostConnection)); } while let Some(msg) = write_state.stanza_recv.recv().await { msg.respond_to.send(Err(WriteError::LostConnection)); } // TODO: is this the correct error? self.logic.handle_connection_error(ConnectionError::LostConnection).await; break; }, } }, else => break, } } // TODO: maybe don't just on_crash let _ = self.on_crash.send(()); } } pub struct SupervisorHandle { sender: SupervisorSender, handle: JoinHandle<()>, } impl Deref for SupervisorHandle { type Target = SupervisorSender; fn deref(&self) -> &Self::Target { &self.sender } } impl DerefMut for SupervisorHandle { fn deref_mut(&mut self) -> &mut Self::Target { &mut self.sender } } #[derive(Clone)] pub struct SupervisorSender { sender: mpsc::Sender, } impl Deref for SupervisorSender { type Target = mpsc::Sender; fn deref(&self) -> &Self::Target { &self.sender } } impl DerefMut for SupervisorSender { fn deref_mut(&mut self) -> &mut Self::Target { &mut self.sender } } impl SupervisorHandle { pub fn new( streams: BoundJabberStream, on_crash: oneshot::Sender<()>, jid: JID, password: Arc, logic: Lgc, ) -> (WriteHandle, Self) { let (command_send, command_recv) = mpsc::channel(20); let (writer_crash_send, writer_crash_recv) = oneshot::channel(); let (reader_crash_send, reader_crash_recv) = oneshot::channel(); let (read_stream, write_stream) = streams.split(); let (write_handle, write_control_handle) = WriteControlHandle::new(write_stream, writer_crash_send); let connected = Connected { jid, write_handle: write_handle.clone(), }; let supervisor_sender = SupervisorSender { sender: command_send, }; let read_control_handle = ReadControlHandle::new( read_stream, connected.clone(), logic.clone(), supervisor_sender.clone(), reader_crash_send, ); let actor = Supervisor::new( command_recv, reader_crash_recv, writer_crash_recv, read_control_handle, write_control_handle, on_crash, connected, password, logic, ); let handle = tokio::spawn(async move { actor.run().await }); ( write_handle, Self { sender: supervisor_sender, handle, }, ) } pub fn sender(&self) -> SupervisorSender { self.sender.clone() } }