aboutsummaryrefslogtreecommitdiffstats
path: root/lampada/src/connection/mod.rs
diff options
context:
space:
mode:
Diffstat (limited to 'lampada/src/connection/mod.rs')
-rw-r--r--lampada/src/connection/mod.rs374
1 files changed, 374 insertions, 0 deletions
diff --git a/lampada/src/connection/mod.rs b/lampada/src/connection/mod.rs
new file mode 100644
index 0000000..1e767b0
--- /dev/null
+++ b/lampada/src/connection/mod.rs
@@ -0,0 +1,374 @@
+// TODO: consider if this needs to be handled by a supervisor or could be handled by luz directly
+
+use std::{
+ collections::HashMap,
+ ops::{Deref, DerefMut},
+ sync::Arc,
+ time::Duration,
+};
+
+use jid::JID;
+use luz::{connection::Tls, jabber_stream::bound_stream::BoundJabberStream};
+use read::{ReadControl, ReadControlHandle, ReadState};
+use stanza::client::Stanza;
+use tokio::{
+ sync::{mpsc, oneshot, Mutex},
+ task::{JoinHandle, JoinSet},
+};
+use tracing::info;
+use write::{WriteControl, WriteControlHandle, WriteHandle, WriteMessage, WriteState};
+
+use crate::{
+ error::{ConnectionError, WriteError},
+ Connected, Logic,
+};
+
+mod read;
+pub(crate) mod write;
+
+pub struct Supervisor<Lgc> {
+ command_recv: mpsc::Receiver<SupervisorCommand>,
+ reader_crash: oneshot::Receiver<ReadState>,
+ writer_crash: oneshot::Receiver<(WriteMessage, WriteState)>,
+ read_control_handle: ReadControlHandle,
+ write_control_handle: WriteControlHandle,
+ on_crash: oneshot::Sender<()>,
+ // jid in connected stays the same over the life of the supervisor (the connection session)
+ connected: Connected,
+ password: Arc<String>,
+ logic: Lgc,
+}
+
+pub enum SupervisorCommand {
+ Disconnect,
+ // for if there was a stream error, require to reconnect
+ // couldn't stream errors just cause a crash? lol
+ Reconnect(ChildState),
+}
+
+pub enum ChildState {
+ Write(WriteState),
+ Read(ReadState),
+}
+
+impl<Lgc: Logic + Clone + Send + 'static> Supervisor<Lgc> {
+ fn new(
+ command_recv: mpsc::Receiver<SupervisorCommand>,
+ reader_crash: oneshot::Receiver<ReadState>,
+ writer_crash: oneshot::Receiver<(WriteMessage, WriteState)>,
+ read_control_handle: ReadControlHandle,
+ write_control_handle: WriteControlHandle,
+ on_crash: oneshot::Sender<()>,
+ connected: Connected,
+ password: Arc<String>,
+ logic: Lgc,
+ ) -> Self {
+ Self {
+ command_recv,
+ reader_crash,
+ writer_crash,
+ read_control_handle,
+ write_control_handle,
+ on_crash,
+ connected,
+ password,
+ logic,
+ }
+ }
+
+ async fn run(mut self) {
+ loop {
+ tokio::select! {
+ Some(msg) = self.command_recv.recv() => {
+ match msg {
+ SupervisorCommand::Disconnect => {
+ info!("disconnecting");
+ self.logic
+ .handle_disconnect(self.connected.clone())
+ .await;
+ let _ = self.write_control_handle.send(WriteControl::Disconnect).await;
+ let _ = self.read_control_handle.send(ReadControl::Disconnect).await;
+ info!("sent disconnect command");
+ tokio::select! {
+ _ = async { tokio::join!(
+ async { let _ = (&mut self.write_control_handle.handle).await; },
+ async { let _ = (&mut self.read_control_handle.handle).await; }
+ ) } => {},
+ // TODO: config timeout
+ _ = async { tokio::time::sleep(Duration::from_secs(5)) } => {
+ (&mut self.read_control_handle.handle).abort();
+ (&mut self.write_control_handle.handle).abort();
+ }
+ }
+ info!("disconnected");
+ break;
+ },
+ // TODO: Reconnect without aborting, gentle reconnect.
+ SupervisorCommand::Reconnect(state) => {
+ // TODO: please omfg
+ // send abort to read stream, as already done, consider
+ let (read_state, mut write_state);
+ match state {
+ ChildState::Write(receiver) => {
+ write_state = receiver;
+ let (send, recv) = oneshot::channel();
+ let _ = self.read_control_handle.send(ReadControl::Abort(send)).await;
+ // TODO: need a tokio select, in case the state arrives from somewhere else
+ if let Ok(state) = recv.await {
+ read_state = state;
+ } else {
+ break
+ }
+ },
+ ChildState::Read(read) => {
+ read_state = read;
+ let (send, recv) = oneshot::channel();
+ let _ = self.write_control_handle.send(WriteControl::Abort(send)).await;
+ // TODO: need a tokio select, in case the state arrives from somewhere else
+ if let Ok(state) = recv.await {
+ write_state = state;
+ } else {
+ break
+ }
+ },
+ }
+
+ let mut jid = self.connected.jid.clone();
+ let mut domain = jid.domainpart.clone();
+ // TODO: make sure connect_and_login does not modify the jid, but instead returns a jid. or something like that
+ let connection = luz::connect_and_login(&mut jid, &*self.password, &mut domain).await;
+ match connection {
+ Ok(c) => {
+ let (read, write) = c.split();
+ let (send, recv) = oneshot::channel();
+ self.writer_crash = recv;
+ self.write_control_handle =
+ WriteControlHandle::reconnect(write, send, write_state.stanza_recv);
+ let (send, recv) = oneshot::channel();
+ self.reader_crash = recv;
+ self.read_control_handle = ReadControlHandle::reconnect(
+ read,
+ read_state.tasks,
+ self.connected.clone(),
+ self.logic.clone(),
+ read_state.supervisor_control,
+ send,
+ );
+ },
+ Err(e) => {
+ // if reconnection failure, respond to all current write messages with lost connection error. the received processes should complete themselves.
+ write_state.stanza_recv.close();
+ while let Some(msg) = write_state.stanza_recv.recv().await {
+ let _ = msg.respond_to.send(Err(WriteError::LostConnection));
+ }
+ // TODO: is this the correct error?
+ self.logic.handle_connection_error(ConnectionError::LostConnection).await;
+ break;
+ },
+ }
+ },
+ }
+ },
+ Ok((write_msg, mut write_state)) = &mut self.writer_crash => {
+ // consider awaiting/aborting the read and write threads
+ let (send, recv) = oneshot::channel();
+ let _ = self.read_control_handle.send(ReadControl::Abort(send)).await;
+ let read_state = tokio::select! {
+ Ok(s) = recv => s,
+ Ok(s) = &mut self.reader_crash => s,
+ // in case, just break as irrecoverable
+ else => break,
+ };
+
+ let mut jid = self.connected.jid.clone();
+ let mut domain = jid.domainpart.clone();
+ // TODO: same here
+ let connection = luz::connect_and_login(&mut jid, &*self.password, &mut domain).await;
+ match connection {
+ Ok(c) => {
+ let (read, write) = c.split();
+ let (send, recv) = oneshot::channel();
+ self.writer_crash = recv;
+ self.write_control_handle =
+ WriteControlHandle::reconnect_retry(write, send, write_state.stanza_recv, write_msg);
+ let (send, recv) = oneshot::channel();
+ self.reader_crash = recv;
+ self.read_control_handle = ReadControlHandle::reconnect(
+ read,
+ read_state.tasks,
+ self.connected.clone(),
+ self.logic.clone(),
+ read_state.supervisor_control,
+ send,
+ );
+ },
+ Err(e) => {
+ // if reconnection failure, respond to all current write messages with lost connection error. the received processes should complete themselves.
+ write_state.stanza_recv.close();
+ let _ = write_msg.respond_to.send(Err(WriteError::LostConnection));
+ while let Some(msg) = write_state.stanza_recv.recv().await {
+ let _ = msg.respond_to.send(Err(WriteError::LostConnection));
+ }
+ // TODO: is this the correct error to send?
+ self.logic.handle_connection_error(ConnectionError::LostConnection).await;
+ break;
+ },
+ }
+ },
+ Ok(read_state) = &mut self.reader_crash => {
+ let (send, recv) = oneshot::channel();
+ let _ = self.write_control_handle.send(WriteControl::Abort(send)).await;
+ let (retry_msg, mut write_state) = tokio::select! {
+ Ok(s) = recv => (None, s),
+ Ok(s) = &mut self.writer_crash => (Some(s.0), s.1),
+ // in case, just break as irrecoverable
+ else => break,
+ };
+
+ let mut jid = self.connected.jid.clone();
+ let mut domain = jid.domainpart.clone();
+ let connection = luz::connect_and_login(&mut jid, &*self.password, &mut domain).await;
+ match connection {
+ Ok(c) => {
+ let (read, write) = c.split();
+ let (send, recv) = oneshot::channel();
+ self.writer_crash = recv;
+ if let Some(msg) = retry_msg {
+ self.write_control_handle =
+ WriteControlHandle::reconnect_retry(write, send, write_state.stanza_recv, msg);
+ } else {
+ self.write_control_handle = WriteControlHandle::reconnect(write, send, write_state.stanza_recv)
+ }
+ let (send, recv) = oneshot::channel();
+ self.reader_crash = recv;
+ self.read_control_handle = ReadControlHandle::reconnect(
+ read,
+ read_state.tasks,
+ self.connected.clone(),
+ self.logic.clone(),
+ read_state.supervisor_control,
+ send,
+ );
+ },
+ Err(e) => {
+ // if reconnection failure, respond to all current messages with lost connection error.
+ write_state.stanza_recv.close();
+ if let Some(msg) = retry_msg {
+ msg.respond_to.send(Err(WriteError::LostConnection));
+ }
+ while let Some(msg) = write_state.stanza_recv.recv().await {
+ msg.respond_to.send(Err(WriteError::LostConnection));
+ }
+ // TODO: is this the correct error?
+ self.logic.handle_connection_error(ConnectionError::LostConnection).await;
+ break;
+ },
+ }
+ },
+ else => break,
+ }
+ }
+ // TODO: maybe don't just on_crash
+ let _ = self.on_crash.send(());
+ }
+}
+
+pub struct SupervisorHandle {
+ sender: SupervisorSender,
+ handle: JoinHandle<()>,
+}
+
+impl Deref for SupervisorHandle {
+ type Target = SupervisorSender;
+
+ fn deref(&self) -> &Self::Target {
+ &self.sender
+ }
+}
+
+impl DerefMut for SupervisorHandle {
+ fn deref_mut(&mut self) -> &mut Self::Target {
+ &mut self.sender
+ }
+}
+
+#[derive(Clone)]
+pub struct SupervisorSender {
+ sender: mpsc::Sender<SupervisorCommand>,
+}
+
+impl Deref for SupervisorSender {
+ type Target = mpsc::Sender<SupervisorCommand>;
+
+ fn deref(&self) -> &Self::Target {
+ &self.sender
+ }
+}
+
+impl DerefMut for SupervisorSender {
+ fn deref_mut(&mut self) -> &mut Self::Target {
+ &mut self.sender
+ }
+}
+
+impl SupervisorHandle {
+ pub fn new<Lgc: Logic + Clone + Send + 'static>(
+ streams: BoundJabberStream<Tls>,
+ on_crash: oneshot::Sender<()>,
+ jid: JID,
+ password: Arc<String>,
+ logic: Lgc,
+ ) -> (WriteHandle, Self) {
+ let (command_send, command_recv) = mpsc::channel(20);
+ let (writer_crash_send, writer_crash_recv) = oneshot::channel();
+ let (reader_crash_send, reader_crash_recv) = oneshot::channel();
+
+ let (read_stream, write_stream) = streams.split();
+
+ let (write_handle, write_control_handle) =
+ WriteControlHandle::new(write_stream, writer_crash_send);
+
+ let connected = Connected {
+ jid,
+ write_handle: write_handle.clone(),
+ };
+
+ let supervisor_sender = SupervisorSender {
+ sender: command_send,
+ };
+
+ let read_control_handle = ReadControlHandle::new(
+ read_stream,
+ connected.clone(),
+ logic.clone(),
+ supervisor_sender.clone(),
+ reader_crash_send,
+ );
+
+ let actor = Supervisor::new(
+ command_recv,
+ reader_crash_recv,
+ writer_crash_recv,
+ read_control_handle,
+ write_control_handle,
+ on_crash,
+ connected,
+ password,
+ logic,
+ );
+
+ let handle = tokio::spawn(async move { actor.run().await });
+
+ (
+ write_handle,
+ Self {
+ sender: supervisor_sender,
+ handle,
+ },
+ )
+ }
+
+ pub fn sender(&self) -> SupervisorSender {
+ self.sender.clone()
+ }
+}