aboutsummaryrefslogblamecommitdiffstats
path: root/lampada/src/connection/read.rs
blob: 640ca8e0bbcdf727d953678edbd810edbb6ff212 (plain) (tree)
1
2
3
4
5
6
7
8
9
          
                         
                        
                           
                 
              


                   
                                                                           
                           

                                               
            
                                 

                                
                  
 
                                            
 
                                                                     
 
              
                      
                                   

                                               






                                                                                                                                        
               



                                                  
                                                                






                                                                                                                 

 
                     
           
                                       
                           
                             
                   

                                                      
                                                                    
               
                                               
              
                   

                                      





                               

         
 
 
                                                     
                            


                                                           

                            
                                                                                  

                                                                          
                                                 

                          

                                                             










                                                                                  
                                                                                                                              

                                  

                      


                                                     
                                  









                                                                                                                                                                                           

                                   
                                                       



                                                                                            
















                                                                                                                                                         


                                  
                     


                             

                                         
                                    


     




                                 
 
                      
               
                                      



                                      














                                                  


                        
                                                    
                                       
                             
                   
                                             
                                                                    



                                                                   
                   
                           




                               








                                                                    
                                                          
                                       
                           
                             
                   
                                             
                                                                    



                                                                   
                   
                  




                               








                                                                    
use std::{
    collections::HashMap,
    marker::PhantomData,
    ops::{Deref, DerefMut},
    str::FromStr,
    sync::Arc,
    time::Duration,
};

use luz::{connection::Tls, jabber_stream::bound_stream::BoundJabberReader};
use stanza::client::Stanza;
use stanza::stream::Error as StreamErrorStanza;
use stanza::stream_error::Error as StreamError;
use tokio::{
    sync::{mpsc, oneshot, Mutex},
    task::{JoinHandle, JoinSet},
};
use tracing::info;

use crate::{Connected, Logic, WriteMessage};

use super::{write::WriteHandle, SupervisorCommand, SupervisorSender};

/// read actor
///
/// Reads stanzas from the bound Jabber stream and dispatches them to `Logic`. Created and
/// spawned by [`ReadControlHandle::new`] / [`ReadControlHandle::reconnect`].
pub struct Read<Lgc> {
    // reading half of the bound stream (over TLS)
    stream: BoundJabberReader<Tls>,
    // set once `ReadControl::Disconnect` has been received; a read error after that point ends
    // the actor without reporting a crash through `on_crash`
    disconnecting: bool,
    // yields `Ok(())` when the post-disconnect timer elapses; until a disconnect is requested
    // it is a receiver whose sender has already been dropped, so it never yields `Ok`
    disconnect_timedout: oneshot::Receiver<()>,

    // all the tasks spawned by the current connection session
    // (one per incoming stanza handed to `Logic::handle_stanza`)
    tasks: JoinSet<()>,

    // for handling incoming stanzas
    // the jabber server must be able both to terminate the connection with an error and to
    // ask the client for data (such as supported XEPs)
    connected: Connected,
    logic: Lgc,
    // used to send commands (e.g. `SupervisorCommand::Reconnect`) to the connection supervisor
    supervisor_control: SupervisorSender,

    // control stuff
    // messages from the owning `ReadControlHandle`
    control_receiver: mpsc::Receiver<ReadControl>,
    // one-shot report sent when a read fails outside of a disconnect: the stream error the
    // failure maps to (if any) plus the state needed to continue the session
    on_crash: oneshot::Sender<(Option<StreamError>, ReadState)>,
}

/// when a crash/abort occurs, this gets sent back to the supervisor, so that the connection session can continue
///
/// It is handed over through `on_crash` (read failure), `ReadControl::Abort`, or
/// `SupervisorCommand::Reconnect` (stream error stanza), and can be fed back into
/// `ReadControlHandle::reconnect` to start a new read actor.
pub struct ReadState {
    /// Channel used to send commands to the connection supervisor.
    pub supervisor_control: SupervisorSender,
    /// Stanza-handling tasks spawned during the session, carried over to the next read actor.
    // TODO: when a stream dies, IQ gets from the server should not be replied to on the new stream
    pub tasks: JoinSet<()>,
}

impl<Lgc> Read<Lgc> {
    fn new(
        stream: BoundJabberReader<Tls>,
        tasks: JoinSet<()>,
        connected: Connected,
        logic: Lgc,
        supervisor_control: SupervisorSender,
        control_receiver: mpsc::Receiver<ReadControl>,
        on_crash: oneshot::Sender<(Option<StreamError>, ReadState)>,
    ) -> Self {
        let (_send, recv) = oneshot::channel();
        Self {
            stream,
            disconnecting: false,
            disconnect_timedout: recv,
            tasks,
            connected,
            logic,
            supervisor_control,
            control_receiver,
            on_crash,
        }
    }
}

impl<Lgc: Clone + Logic + Send + 'static> Read<Lgc> {
    async fn run(mut self) {
        println!("started read thread");
        // let stanza = self.stream.read::<Stanza>().await;
        // println!("{:?}", stanza);
        loop {
            tokio::select! {
                // if still haven't received the end tag in time, just kill itself
                // TODO: is this okay??? what if notification thread dies?
                Ok(()) = &mut self.disconnect_timedout => {
                    info!("disconnect_timedout");
                    break;
                }
                Some(msg) = self.control_receiver.recv() => {
                    match msg {
                        // when disconnect received,
                        ReadControl::Disconnect => {
                            let (send, recv) = oneshot::channel();
                            self.disconnect_timedout = recv;
                            self.disconnecting = true;
                            tokio::spawn(async {
                                tokio::time::sleep(Duration::from_secs(10)).await;
                                let _ = send.send(());
                            })
                        },
                        ReadControl::Abort(sender) => {
                            let _ = sender.send(ReadState { supervisor_control: self.supervisor_control, tasks: self.tasks });
                            break;
                        },
                    };
                },
                s = self.stream.read::<Stanza>() => {
                    println!("read stanza");
                    match s {
                        Ok(s) => {
                            match s {
                                Stanza::Error(error) => {
                                    self.logic.clone().handle_stream_error(error).await;
                                    self.supervisor_control.send(SupervisorCommand::Reconnect(ReadState { supervisor_control: self.supervisor_control.clone(), tasks: self.tasks })).await;
                                    break;
                                },
                                _ => {
                                    self.tasks.spawn(self.logic.clone().handle_stanza(s, self.connected.clone()));
                                }
                            };
                        },
                        Err(e) => {
                            println!("error: {:?}", e);
                            // TODO: make sure this only happens when an end tag is received
                            if self.disconnecting == true {
                                break;
                            } else {
                                let stream_error = match e {
                                    peanuts::Error::ReadError(error) => None,
                                    peanuts::Error::Utf8Error(utf8_error) => Some(StreamError::UnsupportedEncoding),
                                    peanuts::Error::ParseError(_) => Some(StreamError::BadFormat),
                                    peanuts::Error::EntityProcessError(_) => Some(StreamError::RestrictedXml),
                                    peanuts::Error::InvalidCharRef(char_ref_error) => Some(StreamError::UnsupportedEncoding),
                                    peanuts::Error::DuplicateNameSpaceDeclaration(namespace_declaration) => Some(StreamError::NotWellFormed),
                                    peanuts::Error::DuplicateAttribute(_) => Some(StreamError::NotWellFormed),
                                    peanuts::Error::MismatchedEndTag(name, name1) => Some(StreamError::NotWellFormed),
                                    peanuts::Error::NotInElement(_) => Some(StreamError::InvalidXml),
                                    peanuts::Error::ExtraData(_) => None,
                                    peanuts::Error::UndeclaredNamespace(_) => Some(StreamError::InvalidNamespace),
                                    peanuts::Error::Deserialize(deserialize_error) => Some(StreamError::InvalidXml),
                                    peanuts::Error::RootElementEnded => Some(StreamError::InvalidXml),
                                };

                                let _ = self.on_crash.send((stream_error, ReadState { supervisor_control: self.supervisor_control, tasks: self.tasks }));
                            }
                            break;
                        },
                    }
                },
                else => break
            }
        }
        println!("stopping read thread");
        self.logic.on_abort().await;
    }
}

// what do stanza processes do?
// - update ui
// - access database
// - disconnect proper, reconnect
// - respond to server requests

/// Control messages accepted by the read actor, sent through [`ReadControlHandle`].
pub enum ReadControl {
    /// Begin a graceful disconnect: arm a 10 second timer after which the actor stops if the
    /// stream has not ended by itself, and treat subsequent read errors as the end of the
    /// stream rather than as a crash.
    Disconnect,
    /// Stop the actor as soon as the message is handled, handing the session state back
    /// through the given sender.
    Abort(oneshot::Sender<ReadState>),
}

/// Handle to a spawned read actor.
///
/// Dereferences to the [`ReadControl`] channel sender; the actor's task handle is exposed
/// within the crate.
pub struct ReadControlHandle {
    // sending half of the actor's control channel
    sender: mpsc::Sender<ReadControl>,
    // handle of the tokio task running the actor's event loop
    pub(crate) handle: JoinHandle<()>,
}

impl Deref for ReadControlHandle {
    type Target = mpsc::Sender<ReadControl>;

    fn deref(&self) -> &Self::Target {
        &self.sender
    }
}

impl DerefMut for ReadControlHandle {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.sender
    }
}

impl ReadControlHandle {
    /// Capacity of the bounded channel that carries [`ReadControl`] messages to the actor.
    const CONTROL_CHANNEL_CAPACITY: usize = 20;

    /// Spawns a read actor for a brand-new connection session, starting with an empty set of
    /// stanza-handling tasks.
    ///
    /// Equivalent to [`ReadControlHandle::reconnect`] with `JoinSet::new()` as `tasks`.
    pub fn new<Lgc: Clone + Logic + Send + 'static>(
        stream: BoundJabberReader<Tls>,
        connected: Connected,
        logic: Lgc,
        supervisor_control: SupervisorSender,
        on_crash: oneshot::Sender<(Option<StreamError>, ReadState)>,
    ) -> Self {
        Self::reconnect(
            stream,
            JoinSet::new(),
            connected,
            logic,
            supervisor_control,
            on_crash,
        )
    }

    /// Spawns a read actor that continues an existing session on a new stream.
    ///
    /// * `stream` – reading half of the newly bound stream.
    /// * `tasks` – stanza-handling tasks from the previous actor (for instance the `tasks` of a
    ///   [`ReadState`]), handed over so the new actor keeps owning them.
    /// * `connected`, `logic` – passed to `Logic` for every incoming stanza.
    /// * `supervisor_control` – channel for commands to the connection supervisor.
    /// * `on_crash` – receives the mapped stream error and session state if a read fails
    ///   outside of a disconnect.
    ///
    /// Returns a handle that dereferences to the actor's control sender.
    pub fn reconnect<Lgc: Clone + Logic + Send + 'static>(
        stream: BoundJabberReader<Tls>,
        tasks: JoinSet<()>,
        connected: Connected,
        logic: Lgc,
        supervisor_control: SupervisorSender,
        on_crash: oneshot::Sender<(Option<StreamError>, ReadState)>,
    ) -> Self {
        let (control_sender, control_receiver) = mpsc::channel(Self::CONTROL_CHANNEL_CAPACITY);

        let actor = Read::new(
            stream,
            tasks,
            connected,
            logic,
            supervisor_control,
            control_receiver,
            on_crash,
        );
        let handle = tokio::spawn(actor.run());

        Self {
            sender: control_sender,
            handle,
        }
    }
}