aboutsummaryrefslogblamecommitdiffstats
path: root/lampada/src/connection/mod.rs
blob: 1e767b0595e24d66bf44cc451d12fa11be1d3bae (plain) (tree)
1
2
3
4
5
6
7
8
9


                                                                                                 
                         



                           
 
             
                                                                           
                                                      
                           
            
                                 

                                
                  
                                                                                     
 
            

                                         
  



                     
                            







                                                                                               
                          
               




                                                            
                                                     
                          

 


                      

 
                                                           
           






                                                                    
                              
                   

               
                         
                         




                                 
                     
                  


         


                            
                                                         

                                                          
                                                   


                                                                          

                                                                                                   
                                                             

                                                         

                                                                                                     
                                          
                                                       
                                                                                             

                                                                                    

                                 
                                                  

                                  
                                                                              
                                                                

                                                                                   

                                                              
                                                                

                                                                          
                                                                                                          
                                                                                                               





                                                                   
                                                           

                                                                          
                                                                                                            








                                                                                                               
                                                                     
                                                                    
                                                                                                                                           
                                                                                                                  

                                              


                                                                          

                                                                                                            

                                                                          
                                                                                            
                                             



                                                                      
                                             



                                                                                                                                                                                    

                                                                                                
                                                                                                     
                                     
                                                                       
                                                                                                              


                                          

                          
                  
                                                                              

                                                                            

                                                                                          




                                                               
 
                                                             
                                                            
                                      
                                                                                                          




                                                                  

                                                                                                                     

                                                                  
                                                                                    
                                     



                                                              
                                     



                                                                                                                                                                            
                                                            
                                                                                               
                                                                                        
                                                                                             
                             
                                                                       
                                                                                                      


                                  
                  
                                                            
                                                          

                                                                                            




                                                                           
 
                                                             
                                                            
                                                                                                          





                                                                  

                                                                                                                   
                                    
                                                                                                                               


                                                                  
                                                                                    
                                     



                                                              
                                     



                                                                                                                   
                                                            
                                                          
                                                                                     
                             
                                                                                        
                                                                                     
                             
                                                               
                                                                                                      


                                  



                              

                                          









































                                                  
                                                    
                                        

                                      
                              
                   
                              




                                                                        
 
                                                  
















                                                                     


                                    



                                
                                 

                      
                     
                  






                                                                    
                                          








                                              
// TODO: consider if this needs to be handled by a supervisor or could be handled by luz directly

use std::{
    collections::HashMap,
    ops::{Deref, DerefMut},
    sync::Arc,
    time::Duration,
};

use jid::JID;
use luz::{connection::Tls, jabber_stream::bound_stream::BoundJabberStream};
use read::{ReadControl, ReadControlHandle, ReadState};
use stanza::client::Stanza;
use tokio::{
    sync::{mpsc, oneshot, Mutex},
    task::{JoinHandle, JoinSet},
};
use tracing::info;
use write::{WriteControl, WriteControlHandle, WriteHandle, WriteMessage, WriteState};

use crate::{
    error::{ConnectionError, WriteError},
    Connected, Logic,
};

mod read;
pub(crate) mod write;

/// Actor that watches the read and write tasks of one connection session and replaces them
/// with fresh tasks on a new connection when they crash or request a reconnect.
pub struct Supervisor<Lgc> {
    // Commands sent through `SupervisorSender`s (one of which is handed to the read task).
    command_recv: mpsc::Receiver<SupervisorCommand>,
    // Yields the read task's state if it crashes, so it can be handed to its replacement.
    reader_crash: oneshot::Receiver<ReadState>,
    // Yields the write task's state plus the message it was writing when it crashed;
    // that message is retried on the new connection.
    writer_crash: oneshot::Receiver<(WriteMessage, WriteState)>,
    // Control channel + join handle for the current read task.
    read_control_handle: ReadControlHandle,
    // Control channel + join handle for the current write task.
    write_control_handle: WriteControlHandle,
    // Signalled once when the supervisor loop exits.
    on_crash: oneshot::Sender<()>,
    // jid in connected stays the same over the life of the supervisor (the connection session)
    connected: Connected,
    // Used to log in again when reconnecting.
    password: Arc<String>,
    // Application callbacks (disconnect / connection error); cloned into each new read task.
    logic: Lgc,
}

/// Commands handled by the supervisor loop.
pub enum SupervisorCommand {
    /// Gracefully stop both the read and write tasks and end the supervisor.
    Disconnect,
    /// Replace the connection (e.g. after a stream error), carrying over the state of the
    /// child task that sent this command so it can be handed to its replacement.
    // NOTE(review): stream errors could perhaps be reported as ordinary crashes instead.
    Reconnect(ChildState),
}

/// State surrendered by a child task so it can be moved onto a new connection.
pub enum ChildState {
    /// The write task's state, including its receiver of pending write messages.
    Write(WriteState),
    /// The read task's state.
    Read(ReadState),
}

impl<Lgc: Logic + Clone + Send + 'static> Supervisor<Lgc> {
    fn new(
        command_recv: mpsc::Receiver<SupervisorCommand>,
        reader_crash: oneshot::Receiver<ReadState>,
        writer_crash: oneshot::Receiver<(WriteMessage, WriteState)>,
        read_control_handle: ReadControlHandle,
        write_control_handle: WriteControlHandle,
        on_crash: oneshot::Sender<()>,
        connected: Connected,
        password: Arc<String>,
        logic: Lgc,
    ) -> Self {
        Self {
            command_recv,
            reader_crash,
            writer_crash,
            read_control_handle,
            write_control_handle,
            on_crash,
            connected,
            password,
            logic,
        }
    }

    async fn run(mut self) {
        loop {
            tokio::select! {
                Some(msg) = self.command_recv.recv() => {
                    match msg {
                        SupervisorCommand::Disconnect => {
                            info!("disconnecting");
                            self.logic
                                .handle_disconnect(self.connected.clone())
                                .await;
                            let _ = self.write_control_handle.send(WriteControl::Disconnect).await;
                            let _ = self.read_control_handle.send(ReadControl::Disconnect).await;
                            info!("sent disconnect command");
                            tokio::select! {
                                _ = async { tokio::join!(
                                    async { let _ = (&mut self.write_control_handle.handle).await; },
                                    async { let _ = (&mut self.read_control_handle.handle).await; }
                                ) } => {},
                                // TODO: config timeout
                                _ = async { tokio::time::sleep(Duration::from_secs(5)) } => {
                                    (&mut self.read_control_handle.handle).abort();
                                    (&mut self.write_control_handle.handle).abort();
                                }
                            }
                            info!("disconnected");
                            break;
                        },
                        // TODO: Reconnect without aborting, gentle reconnect.
                        SupervisorCommand::Reconnect(state) => {
                            // TODO: please omfg
                            // send abort to read stream, as already done, consider
                            let (read_state, mut write_state);
                            match state {
                                ChildState::Write(receiver) => {
                                    write_state = receiver;
                                    let (send, recv) = oneshot::channel();
                                    let _ = self.read_control_handle.send(ReadControl::Abort(send)).await;
                                    // TODO: need a tokio select, in case the state arrives from somewhere else
                                    if let Ok(state) = recv.await {
                                        read_state = state;
                                    } else {
                                        break
                                    }
                                },
                                ChildState::Read(read) => {
                                    read_state = read;
                                    let (send, recv) = oneshot::channel();
                                    let _ = self.write_control_handle.send(WriteControl::Abort(send)).await;
                                    // TODO: need a tokio select, in case the state arrives from somewhere else
                                    if let Ok(state) = recv.await {
                                        write_state = state;
                                    } else {
                                        break
                                    }
                                },
                            }

                            let mut jid = self.connected.jid.clone();
                            let mut domain = jid.domainpart.clone();
                            // TODO: make sure connect_and_login does not modify the jid, but instead returns a jid. or something like that
                            let connection = luz::connect_and_login(&mut jid, &*self.password, &mut domain).await;
                            match connection {
                                Ok(c) => {
                                    let (read, write) = c.split();
                                    let (send, recv) = oneshot::channel();
                                    self.writer_crash = recv;
                                    self.write_control_handle =
                                        WriteControlHandle::reconnect(write, send, write_state.stanza_recv);
                                    let (send, recv) = oneshot::channel();
                                    self.reader_crash = recv;
                                    self.read_control_handle = ReadControlHandle::reconnect(
                                        read,
                                        read_state.tasks,
                                        self.connected.clone(),
                                        self.logic.clone(),
                                        read_state.supervisor_control,
                                        send,
                                    );
                                },
                                Err(e) => {
                                    // if reconnection failure, respond to all current write messages with lost connection error. the received processes should complete themselves.
                                    write_state.stanza_recv.close();
                                    while let Some(msg) = write_state.stanza_recv.recv().await {
                                        let _ = msg.respond_to.send(Err(WriteError::LostConnection));
                                    }
                                    // TODO: is this the correct error?
                                    self.logic.handle_connection_error(ConnectionError::LostConnection).await;
                                    break;
                                },
                            }
                        },
                    }
                },
                Ok((write_msg, mut write_state)) = &mut self.writer_crash => {
                    // consider awaiting/aborting the read and write threads
                    let (send, recv) = oneshot::channel();
                    let _ = self.read_control_handle.send(ReadControl::Abort(send)).await;
                    let read_state = tokio::select! {
                        Ok(s) = recv => s,
                        Ok(s) = &mut self.reader_crash => s,
                        // in case, just break as irrecoverable
                        else => break,
                    };

                    let mut jid = self.connected.jid.clone();
                    let mut domain = jid.domainpart.clone();
                    // TODO: same here
                    let connection = luz::connect_and_login(&mut jid, &*self.password, &mut domain).await;
                    match connection {
                        Ok(c) => {
                            let (read, write) = c.split();
                            let (send, recv) = oneshot::channel();
                            self.writer_crash = recv;
                            self.write_control_handle =
                                WriteControlHandle::reconnect_retry(write, send, write_state.stanza_recv, write_msg);
                            let (send, recv) = oneshot::channel();
                            self.reader_crash = recv;
                            self.read_control_handle = ReadControlHandle::reconnect(
                                read,
                                read_state.tasks,
                                self.connected.clone(),
                                self.logic.clone(),
                                read_state.supervisor_control,
                                send,
                            );
                        },
                        Err(e) => {
                            // if reconnection failure, respond to all current write messages with lost connection error. the received processes should complete themselves.
                            write_state.stanza_recv.close();
                            let _ = write_msg.respond_to.send(Err(WriteError::LostConnection));
                            while let Some(msg) = write_state.stanza_recv.recv().await {
                                let _ = msg.respond_to.send(Err(WriteError::LostConnection));
                            }
                            // TODO: is this the correct error to send?
                            self.logic.handle_connection_error(ConnectionError::LostConnection).await;
                            break;
                        },
                    }
                },
                Ok(read_state) = &mut self.reader_crash => {
                    let (send, recv) = oneshot::channel();
                    let _ = self.write_control_handle.send(WriteControl::Abort(send)).await;
                    let (retry_msg, mut write_state) = tokio::select! {
                        Ok(s) = recv => (None, s),
                        Ok(s) = &mut self.writer_crash => (Some(s.0), s.1),
                        // in case, just break as irrecoverable
                        else => break,
                    };

                    let mut jid = self.connected.jid.clone();
                    let mut domain = jid.domainpart.clone();
                    let connection = luz::connect_and_login(&mut jid, &*self.password, &mut domain).await;
                    match connection {
                        Ok(c) => {
                            let (read, write) = c.split();
                            let (send, recv) = oneshot::channel();
                            self.writer_crash = recv;
                            if let Some(msg) = retry_msg {
                                self.write_control_handle =
                                    WriteControlHandle::reconnect_retry(write, send, write_state.stanza_recv, msg);
                            } else {
                                self.write_control_handle = WriteControlHandle::reconnect(write, send, write_state.stanza_recv)
                            }
                            let (send, recv) = oneshot::channel();
                            self.reader_crash = recv;
                            self.read_control_handle = ReadControlHandle::reconnect(
                                read,
                                read_state.tasks,
                                self.connected.clone(),
                                self.logic.clone(),
                                read_state.supervisor_control,
                                send,
                            );
                        },
                        Err(e) => {
                            // if reconnection failure, respond to all current messages with lost connection error.
                            write_state.stanza_recv.close();
                            if let Some(msg) = retry_msg {
                                msg.respond_to.send(Err(WriteError::LostConnection));
                            }
                            while let Some(msg) = write_state.stanza_recv.recv().await {
                                msg.respond_to.send(Err(WriteError::LostConnection));
                            }
                            // TODO: is this the correct error?
                            self.logic.handle_connection_error(ConnectionError::LostConnection).await;
                            break;
                        },
                    }
                },
                else => break,
            }
        }
        // TODO: maybe don't just on_crash
        let _ = self.on_crash.send(());
    }
}

/// Owner-side handle to a spawned supervisor: a command sender plus the supervisor task's
/// join handle. Derefs to [`SupervisorSender`] so commands can be sent directly.
pub struct SupervisorHandle {
    sender: SupervisorSender,
    // Join handle of the spawned supervisor task.
    handle: JoinHandle<()>,
}

impl Deref for SupervisorHandle {
    type Target = SupervisorSender;

    // Expose the command sender; the join handle stays private.
    fn deref(&self) -> &Self::Target {
        let Self { sender, .. } = self;
        sender
    }
}

impl DerefMut for SupervisorHandle {
    // Mutable access to the command sender only.
    fn deref_mut(&mut self) -> &mut Self::Target {
        let Self { sender, .. } = self;
        sender
    }
}

/// Cloneable sender of [`SupervisorCommand`]s. Derefs to the underlying `mpsc::Sender`.
#[derive(Clone)]
pub struct SupervisorSender {
    sender: mpsc::Sender<SupervisorCommand>,
}

impl Deref for SupervisorSender {
    type Target = mpsc::Sender<SupervisorCommand>;

    // Forward to the wrapped channel sender.
    fn deref(&self) -> &Self::Target {
        let Self { sender } = self;
        sender
    }
}

impl DerefMut for SupervisorSender {
    // Mutable access to the wrapped channel sender.
    fn deref_mut(&mut self) -> &mut Self::Target {
        let Self { sender } = self;
        sender
    }
}

impl SupervisorHandle {
    pub fn new<Lgc: Logic + Clone + Send + 'static>(
        streams: BoundJabberStream<Tls>,
        on_crash: oneshot::Sender<()>,
        jid: JID,
        password: Arc<String>,
        logic: Lgc,
    ) -> (WriteHandle, Self) {
        let (command_send, command_recv) = mpsc::channel(20);
        let (writer_crash_send, writer_crash_recv) = oneshot::channel();
        let (reader_crash_send, reader_crash_recv) = oneshot::channel();

        let (read_stream, write_stream) = streams.split();

        let (write_handle, write_control_handle) =
            WriteControlHandle::new(write_stream, writer_crash_send);

        let connected = Connected {
            jid,
            write_handle: write_handle.clone(),
        };

        let supervisor_sender = SupervisorSender {
            sender: command_send,
        };

        let read_control_handle = ReadControlHandle::new(
            read_stream,
            connected.clone(),
            logic.clone(),
            supervisor_sender.clone(),
            reader_crash_send,
        );

        let actor = Supervisor::new(
            command_recv,
            reader_crash_recv,
            writer_crash_recv,
            read_control_handle,
            write_control_handle,
            on_crash,
            connected,
            password,
            logic,
        );

        let handle = tokio::spawn(async move { actor.run().await });

        (
            write_handle,
            Self {
                sender: supervisor_sender,
                handle,
            },
        )
    }

    pub fn sender(&self) -> SupervisorSender {
        self.sender.clone()
    }
}