diff options
Diffstat (limited to 'lampada/src/connection/mod.rs')
-rw-r--r-- | lampada/src/connection/mod.rs | 374 |
1 files changed, 374 insertions, 0 deletions
diff --git a/lampada/src/connection/mod.rs b/lampada/src/connection/mod.rs new file mode 100644 index 0000000..1e767b0 --- /dev/null +++ b/lampada/src/connection/mod.rs @@ -0,0 +1,374 @@ +// TODO: consider if this needs to be handled by a supervisor or could be handled by luz directly + +use std::{ + collections::HashMap, + ops::{Deref, DerefMut}, + sync::Arc, + time::Duration, +}; + +use jid::JID; +use luz::{connection::Tls, jabber_stream::bound_stream::BoundJabberStream}; +use read::{ReadControl, ReadControlHandle, ReadState}; +use stanza::client::Stanza; +use tokio::{ + sync::{mpsc, oneshot, Mutex}, + task::{JoinHandle, JoinSet}, +}; +use tracing::info; +use write::{WriteControl, WriteControlHandle, WriteHandle, WriteMessage, WriteState}; + +use crate::{ + error::{ConnectionError, WriteError}, + Connected, Logic, +}; + +mod read; +pub(crate) mod write; + +pub struct Supervisor<Lgc> { + command_recv: mpsc::Receiver<SupervisorCommand>, + reader_crash: oneshot::Receiver<ReadState>, + writer_crash: oneshot::Receiver<(WriteMessage, WriteState)>, + read_control_handle: ReadControlHandle, + write_control_handle: WriteControlHandle, + on_crash: oneshot::Sender<()>, + // jid in connected stays the same over the life of the supervisor (the connection session) + connected: Connected, + password: Arc<String>, + logic: Lgc, +} + +pub enum SupervisorCommand { + Disconnect, + // for if there was a stream error, require to reconnect + // couldn't stream errors just cause a crash? lol + Reconnect(ChildState), +} + +pub enum ChildState { + Write(WriteState), + Read(ReadState), +} + +impl<Lgc: Logic + Clone + Send + 'static> Supervisor<Lgc> { + fn new( + command_recv: mpsc::Receiver<SupervisorCommand>, + reader_crash: oneshot::Receiver<ReadState>, + writer_crash: oneshot::Receiver<(WriteMessage, WriteState)>, + read_control_handle: ReadControlHandle, + write_control_handle: WriteControlHandle, + on_crash: oneshot::Sender<()>, + connected: Connected, + password: Arc<String>, + logic: Lgc, + ) -> Self { + Self { + command_recv, + reader_crash, + writer_crash, + read_control_handle, + write_control_handle, + on_crash, + connected, + password, + logic, + } + } + + async fn run(mut self) { + loop { + tokio::select! { + Some(msg) = self.command_recv.recv() => { + match msg { + SupervisorCommand::Disconnect => { + info!("disconnecting"); + self.logic + .handle_disconnect(self.connected.clone()) + .await; + let _ = self.write_control_handle.send(WriteControl::Disconnect).await; + let _ = self.read_control_handle.send(ReadControl::Disconnect).await; + info!("sent disconnect command"); + tokio::select! { + _ = async { tokio::join!( + async { let _ = (&mut self.write_control_handle.handle).await; }, + async { let _ = (&mut self.read_control_handle.handle).await; } + ) } => {}, + // TODO: config timeout + _ = async { tokio::time::sleep(Duration::from_secs(5)) } => { + (&mut self.read_control_handle.handle).abort(); + (&mut self.write_control_handle.handle).abort(); + } + } + info!("disconnected"); + break; + }, + // TODO: Reconnect without aborting, gentle reconnect. + SupervisorCommand::Reconnect(state) => { + // TODO: please omfg + // send abort to read stream, as already done, consider + let (read_state, mut write_state); + match state { + ChildState::Write(receiver) => { + write_state = receiver; + let (send, recv) = oneshot::channel(); + let _ = self.read_control_handle.send(ReadControl::Abort(send)).await; + // TODO: need a tokio select, in case the state arrives from somewhere else + if let Ok(state) = recv.await { + read_state = state; + } else { + break + } + }, + ChildState::Read(read) => { + read_state = read; + let (send, recv) = oneshot::channel(); + let _ = self.write_control_handle.send(WriteControl::Abort(send)).await; + // TODO: need a tokio select, in case the state arrives from somewhere else + if let Ok(state) = recv.await { + write_state = state; + } else { + break + } + }, + } + + let mut jid = self.connected.jid.clone(); + let mut domain = jid.domainpart.clone(); + // TODO: make sure connect_and_login does not modify the jid, but instead returns a jid. or something like that + let connection = luz::connect_and_login(&mut jid, &*self.password, &mut domain).await; + match connection { + Ok(c) => { + let (read, write) = c.split(); + let (send, recv) = oneshot::channel(); + self.writer_crash = recv; + self.write_control_handle = + WriteControlHandle::reconnect(write, send, write_state.stanza_recv); + let (send, recv) = oneshot::channel(); + self.reader_crash = recv; + self.read_control_handle = ReadControlHandle::reconnect( + read, + read_state.tasks, + self.connected.clone(), + self.logic.clone(), + read_state.supervisor_control, + send, + ); + }, + Err(e) => { + // if reconnection failure, respond to all current write messages with lost connection error. the received processes should complete themselves. + write_state.stanza_recv.close(); + while let Some(msg) = write_state.stanza_recv.recv().await { + let _ = msg.respond_to.send(Err(WriteError::LostConnection)); + } + // TODO: is this the correct error? + self.logic.handle_connection_error(ConnectionError::LostConnection).await; + break; + }, + } + }, + } + }, + Ok((write_msg, mut write_state)) = &mut self.writer_crash => { + // consider awaiting/aborting the read and write threads + let (send, recv) = oneshot::channel(); + let _ = self.read_control_handle.send(ReadControl::Abort(send)).await; + let read_state = tokio::select! { + Ok(s) = recv => s, + Ok(s) = &mut self.reader_crash => s, + // in case, just break as irrecoverable + else => break, + }; + + let mut jid = self.connected.jid.clone(); + let mut domain = jid.domainpart.clone(); + // TODO: same here + let connection = luz::connect_and_login(&mut jid, &*self.password, &mut domain).await; + match connection { + Ok(c) => { + let (read, write) = c.split(); + let (send, recv) = oneshot::channel(); + self.writer_crash = recv; + self.write_control_handle = + WriteControlHandle::reconnect_retry(write, send, write_state.stanza_recv, write_msg); + let (send, recv) = oneshot::channel(); + self.reader_crash = recv; + self.read_control_handle = ReadControlHandle::reconnect( + read, + read_state.tasks, + self.connected.clone(), + self.logic.clone(), + read_state.supervisor_control, + send, + ); + }, + Err(e) => { + // if reconnection failure, respond to all current write messages with lost connection error. the received processes should complete themselves. + write_state.stanza_recv.close(); + let _ = write_msg.respond_to.send(Err(WriteError::LostConnection)); + while let Some(msg) = write_state.stanza_recv.recv().await { + let _ = msg.respond_to.send(Err(WriteError::LostConnection)); + } + // TODO: is this the correct error to send? + self.logic.handle_connection_error(ConnectionError::LostConnection).await; + break; + }, + } + }, + Ok(read_state) = &mut self.reader_crash => { + let (send, recv) = oneshot::channel(); + let _ = self.write_control_handle.send(WriteControl::Abort(send)).await; + let (retry_msg, mut write_state) = tokio::select! { + Ok(s) = recv => (None, s), + Ok(s) = &mut self.writer_crash => (Some(s.0), s.1), + // in case, just break as irrecoverable + else => break, + }; + + let mut jid = self.connected.jid.clone(); + let mut domain = jid.domainpart.clone(); + let connection = luz::connect_and_login(&mut jid, &*self.password, &mut domain).await; + match connection { + Ok(c) => { + let (read, write) = c.split(); + let (send, recv) = oneshot::channel(); + self.writer_crash = recv; + if let Some(msg) = retry_msg { + self.write_control_handle = + WriteControlHandle::reconnect_retry(write, send, write_state.stanza_recv, msg); + } else { + self.write_control_handle = WriteControlHandle::reconnect(write, send, write_state.stanza_recv) + } + let (send, recv) = oneshot::channel(); + self.reader_crash = recv; + self.read_control_handle = ReadControlHandle::reconnect( + read, + read_state.tasks, + self.connected.clone(), + self.logic.clone(), + read_state.supervisor_control, + send, + ); + }, + Err(e) => { + // if reconnection failure, respond to all current messages with lost connection error. + write_state.stanza_recv.close(); + if let Some(msg) = retry_msg { + msg.respond_to.send(Err(WriteError::LostConnection)); + } + while let Some(msg) = write_state.stanza_recv.recv().await { + msg.respond_to.send(Err(WriteError::LostConnection)); + } + // TODO: is this the correct error? + self.logic.handle_connection_error(ConnectionError::LostConnection).await; + break; + }, + } + }, + else => break, + } + } + // TODO: maybe don't just on_crash + let _ = self.on_crash.send(()); + } +} + +pub struct SupervisorHandle { + sender: SupervisorSender, + handle: JoinHandle<()>, +} + +impl Deref for SupervisorHandle { + type Target = SupervisorSender; + + fn deref(&self) -> &Self::Target { + &self.sender + } +} + +impl DerefMut for SupervisorHandle { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.sender + } +} + +#[derive(Clone)] +pub struct SupervisorSender { + sender: mpsc::Sender<SupervisorCommand>, +} + +impl Deref for SupervisorSender { + type Target = mpsc::Sender<SupervisorCommand>; + + fn deref(&self) -> &Self::Target { + &self.sender + } +} + +impl DerefMut for SupervisorSender { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.sender + } +} + +impl SupervisorHandle { + pub fn new<Lgc: Logic + Clone + Send + 'static>( + streams: BoundJabberStream<Tls>, + on_crash: oneshot::Sender<()>, + jid: JID, + password: Arc<String>, + logic: Lgc, + ) -> (WriteHandle, Self) { + let (command_send, command_recv) = mpsc::channel(20); + let (writer_crash_send, writer_crash_recv) = oneshot::channel(); + let (reader_crash_send, reader_crash_recv) = oneshot::channel(); + + let (read_stream, write_stream) = streams.split(); + + let (write_handle, write_control_handle) = + WriteControlHandle::new(write_stream, writer_crash_send); + + let connected = Connected { + jid, + write_handle: write_handle.clone(), + }; + + let supervisor_sender = SupervisorSender { + sender: command_send, + }; + + let read_control_handle = ReadControlHandle::new( + read_stream, + connected.clone(), + logic.clone(), + supervisor_sender.clone(), + reader_crash_send, + ); + + let actor = Supervisor::new( + command_recv, + reader_crash_recv, + writer_crash_recv, + read_control_handle, + write_control_handle, + on_crash, + connected, + password, + logic, + ); + + let handle = tokio::spawn(async move { actor.run().await }); + + ( + write_handle, + Self { + sender: supervisor_sender, + handle, + }, + ) + } + + pub fn sender(&self) -> SupervisorSender { + self.sender.clone() + } +} |