aboutsummaryrefslogblamecommitdiffstats
path: root/lampada/src/connection/write.rs
blob: 8f0c34b6809683b9c8cfbeccf7d209551e2cef05 (plain) (tree)
1
2
3
4
5
6
7
8
9

                                
                                                                           





                           
                             
 
                                                                                                                                                                        
                  


                                      
                                                  

                    
                                                   





                                                                                                                                                              

 
                
                         
                       
                                                            

 
                       
               
                                       



            
                                       

                                                       
                                                              

               
                   

                             
                     






                                                                              







                                                                                                       





                                                              













                                                                     





                                                             
                                                                                                                   

                                                                                                
                                                                                      


                                                                                  

                                                                                                                                             
                                                                                                         
                                                                                                     
                                                                                                             
                                             

                                                  
                                                                                                            
                                              
                                                                                       


                                          
                                                                            


                                     
                                                                  

                                  
                                                                                                                                           
                                                        
                                                                                                  







                                                                      
                                                                  
                                                                                                                   
                                                                                                                    


                                      
                                                                           


                              
                                                                


                         










                                       
                  
                                                                         





                                              
                                                   
                        
                                                            


     















                                                  














                                                  




                                       
                                                              



                                                                   
                                                                                    












                                                                    

                                       
                                                              
                                                      
                                


                                                                   
                                                                                    







                                                                                         

                                       
                                                              



                                                                   
                                                                                    







                                                                    
use std::ops::{Deref, DerefMut};

use luz::{connection::Tls, jabber_stream::bound_stream::BoundJabberWriter};
use stanza::client::Stanza;
use tokio::{
    sync::{mpsc, oneshot},
    task::JoinHandle,
};

use crate::error::WriteError;

/// Actor that writes the stanzas it receives on its queue to a bound jabber stream.
///
/// Every queued [`WriteMessage`] is answered on its `respond_to` channel. If a write fails with
/// `peanuts::Error::ReadError`, the actor treats the connection as lost: it sends the unanswered
/// message together with its stanza queue back to the supervisor through `on_crash` and stops,
/// so the supervisor can spawn a new stream. Other write errors are only reported to the
/// requester and the actor keeps running.
pub struct Write {
    /// stream writer that stanzas are written to
    stream: BoundJabberWriter<Tls>,

    /// connection session write queue
    stanza_receiver: mpsc::Receiver<WriteMessage>,

    // control channels shared with the supervisor
    /// disconnect/abort requests from the supervisor
    control_receiver: mpsc::Receiver<WriteControl>,
    /// used once, when the connection is lost, to hand the failed message and the queue back
    on_crash: oneshot::Sender<(WriteMessage, WriteState)>,
}

/// State handed back to the supervisor when the write actor crashes or is aborted, so that the
/// connection session can continue on a new stream.
///
/// After a crash it is sent through `on_crash` paired with the [`WriteMessage`] whose write
/// failed; after [`WriteControl::Abort`] it is sent on its own.
pub struct WriteState {
    /// the session's stanza write queue, so a new write actor can take it over
    pub stanza_recv: mpsc::Receiver<WriteMessage>,
}

/// A request to write one stanza, plus the channel on which the outcome of the write is reported.
#[derive(Debug)]
pub struct WriteMessage {
    /// the stanza to write
    pub stanza: Stanza,
    /// receives `Ok(())` on success or the `WriteError` describing the failure; if the connection
    /// is lost, the write actor hands the whole message to its supervisor instead of answering here
    pub respond_to: oneshot::Sender<Result<(), WriteError>>,
}

/// Control messages the supervisor can send to a running write actor.
pub enum WriteControl {
    /// Stop accepting new stanzas, write every stanza still queued, then close the stream.
    Disconnect,
    /// Stop immediately without writing anything further, and send the stanza queue back through
    /// the given sender so a reconnection at the same resource can reuse it.
    Abort(oneshot::Sender<WriteState>),
}

impl Write {
    fn new(
        stream: BoundJabberWriter<Tls>,
        stanza_receiver: mpsc::Receiver<WriteMessage>,
        control_receiver: mpsc::Receiver<WriteControl>,
        on_crash: oneshot::Sender<(WriteMessage, WriteState)>,
    ) -> Self {
        Self {
            stream,
            stanza_receiver,
            control_receiver,
            on_crash,
        }
    }

    async fn write(&mut self, stanza: &Stanza) -> Result<(), peanuts::Error> {
        Ok(self.stream.write(stanza).await?)
    }

    async fn run_reconnected(mut self, retry_msg: WriteMessage) {
        // try to retry sending the message that failed to send previously
        let result = self.stream.write(&retry_msg.stanza).await;
        match result {
            Err(e) => match &e {
                peanuts::Error::ReadError(_error) => {
                    // make sure message is not lost from error, supervisor handles retry and reporting
                    // TODO: upon reconnect, make sure we are not stuck in a reconnection loop
                    let _ = self.on_crash.send((
                        retry_msg,
                        WriteState {
                            stanza_recv: self.stanza_receiver,
                        },
                    ));
                    return;
                }
                _ => {
                    let _ = retry_msg.respond_to.send(Err(e.into()));
                }
            },
            _ => {
                let _ = retry_msg.respond_to.send(Ok(()));
            }
        }
        // return to normal loop
        self.run().await
    }

    async fn run(mut self) {
        loop {
            tokio::select! {
                Some(msg) = self.control_receiver.recv() => {
                    match msg {
                        WriteControl::Disconnect => {
                            // close the stanza_receiver channel and drain out all of the remaining stanzas to send
                            self.stanza_receiver.close();
                            // TODO: put this in some kind of function to avoid code duplication
                            while let Some(msg) =  self.stanza_receiver.recv().await {
                                let result = self.stream.write(&msg.stanza).await;
                                match result {
                                    Err(e) => match &e {
                                        peanuts::Error::ReadError(_error) => {
                                            // if connection lost during disconnection, just send lost connection error to the write requests
                                            let _ = msg.respond_to.send(Err(WriteError::LostConnection));
                                            while let Some(msg) = self.stanza_receiver.recv().await {
                                                let _ = msg.respond_to.send(Err(WriteError::LostConnection));
                                            }
                                            break;
                                        }
                                        // otherwise complete sending all the stanzas currently in the queue
                                        _ => {
                                            let _ = msg.respond_to.send(Err(e.into()));
                                        }
                                    },
                                    _ => {
                                        let _ = msg.respond_to.send(Ok(()));
                                    }
                                }
                            }
                            let _ = self.stream.try_close().await;
                            break;
                        },
                        // in case of abort, stream is already fucked, just send the receiver ready for a reconnection at the same resource
                        WriteControl::Abort(sender) => {
                            let _ = sender.send(WriteState { stanza_recv: self.stanza_receiver });
                            break;
                        },
                    }
                },
                Some(msg) = self.stanza_receiver.recv() => {
                    let result = self.stream.write(&msg.stanza).await;
                    match result {
                        Err(e) => match &e {
                            peanuts::Error::ReadError(_error) => {
                                // make sure message is not lost from error, supervisor handles retry and reporting
                                let _ = self.on_crash.send((msg, WriteState { stanza_recv: self.stanza_receiver }));
                                break;
                            }
                            _ => {
                                let _ = msg.respond_to.send(Err(e.into()));
                            }
                        },
                        _ => {
                            let _ = msg.respond_to.send(Ok(()));
                        }
                    }
                },
                else => break,
            }
        }
    }
}

/// Cloneable handle used to queue stanzas for the write actor.
#[derive(Clone)]
pub struct WriteHandle {
    /// sending side of the write actor's stanza queue
    sender: mpsc::Sender<WriteMessage>,
}

impl WriteHandle {
    /// Queues `stanza` for the write actor and waits for it to report the result of the write.
    ///
    /// # Errors
    ///
    /// Returns `WriteError::Actor` if the stanza queue is closed or the response channel is
    /// dropped without an answer; otherwise returns whatever the actor reported.
    pub async fn write(&self, stanza: Stanza) -> Result<(), WriteError> {
        let (respond_to, response) = oneshot::channel();
        let request = WriteMessage { stanza, respond_to };
        if let Err(e) = self.sender.send(request).await {
            return Err(WriteError::Actor(e.into()));
        }
        // TODO: timeout
        match response.await {
            Ok(outcome) => outcome,
            Err(e) => Err(WriteError::Actor(e.into())),
        }
    }
}

// Lets callers use the `mpsc::Sender` API (e.g. `send`) directly on a `WriteHandle`.
impl Deref for WriteHandle {
    type Target = mpsc::Sender<WriteMessage>;

    fn deref(&self) -> &Self::Target {
        &self.sender
    }
}

// Mutable counterpart of the `Deref` impl above, for `&mut self` sender methods.
impl DerefMut for WriteHandle {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.sender
    }
}

/// Handle used to control a spawned write actor and to await its task.
pub struct WriteControlHandle {
    /// sending side of the actor's control channel
    sender: mpsc::Sender<WriteControl>,
    /// join handle of the spawned actor task
    pub(crate) handle: JoinHandle<()>,
}

// Lets callers send `WriteControl` messages by using `mpsc::Sender` methods directly on the handle.
impl Deref for WriteControlHandle {
    type Target = mpsc::Sender<WriteControl>;

    fn deref(&self) -> &Self::Target {
        &self.sender
    }
}

// Mutable counterpart of the `Deref` impl above, for `&mut self` sender methods.
impl DerefMut for WriteControlHandle {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.sender
    }
}

impl WriteControlHandle {
    /// Spawns a write actor on `stream` with a fresh stanza queue.
    ///
    /// Returns the [`WriteHandle`] for queueing stanzas and the control handle for the actor;
    /// `on_crash` receives the failed message and the queue if the connection is lost.
    pub fn new(
        stream: BoundJabberWriter<Tls>,
        on_crash: oneshot::Sender<(WriteMessage, WriteState)>,
    ) -> (WriteHandle, Self) {
        let (stanza_sender, stanza_receiver) = mpsc::channel(20);
        let control = Self::spawn_actor(stream, on_crash, stanza_receiver, None);
        (
            WriteHandle {
                sender: stanza_sender,
            },
            control,
        )
    }

    /// Spawns a write actor on a new `stream` that takes over an existing stanza queue and first
    /// retries `retry_msg`, the message whose write failed on the previous stream.
    pub fn reconnect_retry(
        stream: BoundJabberWriter<Tls>,
        on_crash: oneshot::Sender<(WriteMessage, WriteState)>,
        stanza_receiver: mpsc::Receiver<WriteMessage>,
        retry_msg: WriteMessage,
    ) -> Self {
        Self::spawn_actor(stream, on_crash, stanza_receiver, Some(retry_msg))
    }

    /// Spawns a write actor on a new `stream` that takes over an existing stanza queue.
    pub fn reconnect(
        stream: BoundJabberWriter<Tls>,
        on_crash: oneshot::Sender<(WriteMessage, WriteState)>,
        stanza_receiver: mpsc::Receiver<WriteMessage>,
    ) -> Self {
        Self::spawn_actor(stream, on_crash, stanza_receiver, None)
    }

    /// Shared constructor: creates the control channel, spawns the actor task (retrying
    /// `retry_msg` first when one is given) and wraps both in a handle.
    fn spawn_actor(
        stream: BoundJabberWriter<Tls>,
        on_crash: oneshot::Sender<(WriteMessage, WriteState)>,
        stanza_receiver: mpsc::Receiver<WriteMessage>,
        retry_msg: Option<WriteMessage>,
    ) -> Self {
        let (sender, control_receiver) = mpsc::channel(20);
        let actor = Write::new(stream, stanza_receiver, control_receiver, on_crash);
        let handle = match retry_msg {
            Some(retry_msg) => tokio::spawn(async move { actor.run_reconnected(retry_msg).await }),
            None => tokio::spawn(async move { actor.run().await }),
        };
        Self { sender, handle }
    }
}