diff options
Diffstat (limited to '')
| -rw-r--r-- | lampada/src/connection/mod.rs | 52 | ||||
| -rw-r--r-- | lampada/src/connection/read.rs | 60 | ||||
| -rw-r--r-- | lampada/src/connection/write.rs | 14 | ||||
| -rw-r--r-- | lampada/src/lib.rs | 7 | 
4 files changed, 68 insertions, 65 deletions
| diff --git a/lampada/src/connection/mod.rs b/lampada/src/connection/mod.rs index 1e767b0..ffaa7a7 100644 --- a/lampada/src/connection/mod.rs +++ b/lampada/src/connection/mod.rs @@ -10,7 +10,7 @@ use std::{  use jid::JID;  use luz::{connection::Tls, jabber_stream::bound_stream::BoundJabberStream};  use read::{ReadControl, ReadControlHandle, ReadState}; -use stanza::client::Stanza; +use stanza::{client::Stanza, stream_error::Error as StreamError};  use tokio::{      sync::{mpsc, oneshot, Mutex},      task::{JoinHandle, JoinSet}, @@ -28,7 +28,7 @@ pub(crate) mod write;  pub struct Supervisor<Lgc> {      command_recv: mpsc::Receiver<SupervisorCommand>, -    reader_crash: oneshot::Receiver<ReadState>, +    reader_crash: oneshot::Receiver<(Option<StreamError>, ReadState)>,      writer_crash: oneshot::Receiver<(WriteMessage, WriteState)>,      read_control_handle: ReadControlHandle,      write_control_handle: WriteControlHandle, @@ -43,18 +43,13 @@ pub enum SupervisorCommand {      Disconnect,      // for if there was a stream error, require to reconnect      // couldn't stream errors just cause a crash? lol -    Reconnect(ChildState), -} - -pub enum ChildState { -    Write(WriteState), -    Read(ReadState), +    Reconnect(ReadState),  }  impl<Lgc: Logic + Clone + Send + 'static> Supervisor<Lgc> {      fn new(          command_recv: mpsc::Receiver<SupervisorCommand>, -        reader_crash: oneshot::Receiver<ReadState>, +        reader_crash: oneshot::Receiver<(Option<StreamError>, ReadState)>,          writer_crash: oneshot::Receiver<(WriteMessage, WriteState)>,          read_control_handle: ReadControlHandle,          write_control_handle: WriteControlHandle, @@ -104,33 +99,19 @@ impl<Lgc: Logic + Clone + Send + 'static> Supervisor<Lgc> {                              break;                          },                          // TODO: Reconnect without aborting, gentle reconnect. +                        // the server sent a stream error                          SupervisorCommand::Reconnect(state) => {                              // TODO: please omfg                              // send abort to read stream, as already done, consider                              let (read_state, mut write_state); -                            match state { -                                ChildState::Write(receiver) => { -                                    write_state = receiver; -                                    let (send, recv) = oneshot::channel(); -                                    let _ = self.read_control_handle.send(ReadControl::Abort(send)).await; -                                    // TODO: need a tokio select, in case the state arrives from somewhere else -                                    if let Ok(state) = recv.await { -                                        read_state = state; -                                    } else { -                                        break -                                    } -                                }, -                                ChildState::Read(read) => { -                                    read_state = read; -                                    let (send, recv) = oneshot::channel(); -                                    let _ = self.write_control_handle.send(WriteControl::Abort(send)).await; -                                    // TODO: need a tokio select, in case the state arrives from somewhere else -                                    if let Ok(state) = recv.await { -                                        write_state = state; -                                    } else { -                                        break -                                    } -                                }, +                            read_state = state; +                            let (send, recv) = oneshot::channel(); +                            let _ = self.write_control_handle.send(WriteControl::Abort(None, send)).await; +                            // TODO: need a tokio select, in case the state arrives from somewhere else +                            if let Ok(state) = recv.await { +                                write_state = state; +                            } else { +                                break                              }                              let mut jid = self.connected.jid.clone(); @@ -175,7 +156,8 @@ impl<Lgc: Logic + Clone + Send + 'static> Supervisor<Lgc> {                      let _ = self.read_control_handle.send(ReadControl::Abort(send)).await;                      let read_state = tokio::select! {                          Ok(s) = recv => s, -                        Ok(s) = &mut self.reader_crash => s, +                        // TODO: is this okay +                        Ok(s) = &mut self.reader_crash => s.1,                          // in case, just break as irrecoverable                          else => break,                      }; @@ -215,9 +197,9 @@ impl<Lgc: Logic + Clone + Send + 'static> Supervisor<Lgc> {                          },                      }                  }, -                Ok(read_state) = &mut self.reader_crash => { +                Ok((stream_error, read_state)) = &mut self.reader_crash => {                      let (send, recv) = oneshot::channel(); -                    let _ = self.write_control_handle.send(WriteControl::Abort(send)).await; +                    let _ = self.write_control_handle.send(WriteControl::Abort(stream_error, send)).await;                      let (retry_msg, mut write_state) = tokio::select! {                          Ok(s) = recv => (None, s),                          Ok(s) = &mut self.writer_crash => (Some(s.0), s.1), diff --git a/lampada/src/connection/read.rs b/lampada/src/connection/read.rs index cc69387..640ca8e 100644 --- a/lampada/src/connection/read.rs +++ b/lampada/src/connection/read.rs @@ -9,13 +9,15 @@ use std::{  use luz::{connection::Tls, jabber_stream::bound_stream::BoundJabberReader};  use stanza::client::Stanza; +use stanza::stream::Error as StreamErrorStanza; +use stanza::stream_error::Error as StreamError;  use tokio::{      sync::{mpsc, oneshot, Mutex},      task::{JoinHandle, JoinSet},  };  use tracing::info; -use crate::{Connected, Logic}; +use crate::{Connected, Logic, WriteMessage};  use super::{write::WriteHandle, SupervisorCommand, SupervisorSender}; @@ -36,7 +38,7 @@ pub struct Read<Lgc> {      // control stuff      control_receiver: mpsc::Receiver<ReadControl>, -    on_crash: oneshot::Sender<ReadState>, +    on_crash: oneshot::Sender<(Option<StreamError>, ReadState)>,  }  /// when a crash/abort occurs, this gets sent back to the supervisor, so that the connection session can continue @@ -54,7 +56,7 @@ impl<Lgc> Read<Lgc> {          logic: Lgc,          supervisor_control: SupervisorSender,          control_receiver: mpsc::Receiver<ReadControl>, -        on_crash: oneshot::Sender<ReadState>, +        on_crash: oneshot::Sender<(Option<StreamError>, ReadState)>,      ) -> Self {          let (_send, recv) = oneshot::channel();          Self { @@ -106,34 +108,40 @@ impl<Lgc: Clone + Logic + Send + 'static> Read<Lgc> {                      println!("read stanza");                      match s {                          Ok(s) => { -                            self.tasks.spawn(self.logic.clone().handle_stanza(s, self.connected.clone(), self.supervisor_control.clone())); +                            match s { +                                Stanza::Error(error) => { +                                    self.logic.clone().handle_stream_error(error).await; +                                    self.supervisor_control.send(SupervisorCommand::Reconnect(ReadState { supervisor_control: self.supervisor_control.clone(), tasks: self.tasks })).await; +                                    break; +                                }, +                                _ => { +                                    self.tasks.spawn(self.logic.clone().handle_stanza(s, self.connected.clone())); +                                } +                            };                          },                          Err(e) => {                              println!("error: {:?}", e); -                            // TODO: NEXT write the correct error stanza depending on error, decide whether to reconnect or properly disconnect, depending on if disconnecting is true -                            // match e { -                            //     peanuts::Error::ReadError(error) => todo!(), -                            //     peanuts::Error::Utf8Error(utf8_error) => todo!(), -                            //     peanuts::Error::ParseError(_) => todo!(), -                            //     peanuts::Error::EntityProcessError(_) => todo!(), -                            //     peanuts::Error::InvalidCharRef(_) => todo!(), -                            //     peanuts::Error::DuplicateNameSpaceDeclaration(namespace_declaration) => todo!(), -                            //     peanuts::Error::DuplicateAttribute(_) => todo!(), -                            //     peanuts::Error::UnqualifiedNamespace(_) => todo!(), -                            //     peanuts::Error::MismatchedEndTag(name, name1) => todo!(), -                            //     peanuts::Error::NotInElement(_) => todo!(), -                            //     peanuts::Error::ExtraData(_) => todo!(), -                            //     peanuts::Error::UndeclaredNamespace(_) => todo!(), -                            //     peanuts::Error::IncorrectName(name) => todo!(), -                            //     peanuts::Error::DeserializeError(_) => todo!(), -                            //     peanuts::Error::Deserialize(deserialize_error) => todo!(), -                            //     peanuts::Error::RootElementEnded => todo!(), -                            // }                              // TODO: make sure this only happens when an end tag is received                              if self.disconnecting == true {                                  break;                              } else { -                                let _ = self.on_crash.send(ReadState { supervisor_control: self.supervisor_control, tasks: self.tasks }); +                                let stream_error = match e { +                                    peanuts::Error::ReadError(error) => None, +                                    peanuts::Error::Utf8Error(utf8_error) => Some(StreamError::UnsupportedEncoding), +                                    peanuts::Error::ParseError(_) => Some(StreamError::BadFormat), +                                    peanuts::Error::EntityProcessError(_) => Some(StreamError::RestrictedXml), +                                    peanuts::Error::InvalidCharRef(char_ref_error) => Some(StreamError::UnsupportedEncoding), +                                    peanuts::Error::DuplicateNameSpaceDeclaration(namespace_declaration) => Some(StreamError::NotWellFormed), +                                    peanuts::Error::DuplicateAttribute(_) => Some(StreamError::NotWellFormed), +                                    peanuts::Error::MismatchedEndTag(name, name1) => Some(StreamError::NotWellFormed), +                                    peanuts::Error::NotInElement(_) => Some(StreamError::InvalidXml), +                                    peanuts::Error::ExtraData(_) => None, +                                    peanuts::Error::UndeclaredNamespace(_) => Some(StreamError::InvalidNamespace), +                                    peanuts::Error::Deserialize(deserialize_error) => Some(StreamError::InvalidXml), +                                    peanuts::Error::RootElementEnded => Some(StreamError::InvalidXml), +                                }; + +                                let _ = self.on_crash.send((stream_error, ReadState { supervisor_control: self.supervisor_control, tasks: self.tasks }));                              }                              break;                          }, @@ -183,7 +191,7 @@ impl ReadControlHandle {          connected: Connected,          logic: Lgc,          supervisor_control: SupervisorSender, -        on_crash: oneshot::Sender<ReadState>, +        on_crash: oneshot::Sender<(Option<StreamError>, ReadState)>,      ) -> Self {          let (control_sender, control_receiver) = mpsc::channel(20); @@ -210,7 +218,7 @@ impl ReadControlHandle {          connected: Connected,          logic: Lgc,          supervisor_control: SupervisorSender, -        on_crash: oneshot::Sender<ReadState>, +        on_crash: oneshot::Sender<(Option<StreamError>, ReadState)>,      ) -> Self {          let (control_sender, control_receiver) = mpsc::channel(20); diff --git a/lampada/src/connection/write.rs b/lampada/src/connection/write.rs index 8f0c34b..1070cdf 100644 --- a/lampada/src/connection/write.rs +++ b/lampada/src/connection/write.rs @@ -1,7 +1,9 @@  use std::ops::{Deref, DerefMut};  use luz::{connection::Tls, jabber_stream::bound_stream::BoundJabberWriter}; -use stanza::client::Stanza; +use stanza::{ +    client::Stanza, stream::Error as StreamErrorStanza, stream_error::Error as StreamError, +};  use tokio::{      sync::{mpsc, oneshot},      task::JoinHandle, @@ -34,7 +36,7 @@ pub struct WriteMessage {  pub enum WriteControl {      Disconnect, -    Abort(oneshot::Sender<WriteState>), +    Abort(Option<StreamError>, oneshot::Sender<WriteState>),  }  impl Write { @@ -119,7 +121,13 @@ impl Write {                              break;                          },                          // in case of abort, stream is already fucked, just send the receiver ready for a reconnection at the same resource -                        WriteControl::Abort(sender) => { +                        WriteControl::Abort(error, sender) => { +                            // write stream error message for server if there is one +                            if let Some(error) = error { +                                // TODO: timeouts for writing to stream +                                let _ = self.stream.write(&Stanza::Error(StreamErrorStanza { error, text: None })).await; +                                // don't care about result, if it sends it sends, otherwise stream is restarting anyway +                            }                              let _ = sender.send(WriteState { stanza_recv: self.stanza_receiver });                              break;                          }, diff --git a/lampada/src/lib.rs b/lampada/src/lib.rs index c61c596..a01ba06 100644 --- a/lampada/src/lib.rs +++ b/lampada/src/lib.rs @@ -15,6 +15,7 @@ use stanza::client::{      iq::{self, Iq, IqType},      Stanza,  }; +use stanza::stream::Error as StreamError;  use tokio::{      sync::{mpsc, oneshot, Mutex},      task::JoinSet, @@ -59,12 +60,16 @@ pub trait Logic {          connection: Connected,      ) -> impl std::future::Future<Output = ()> + Send; +    fn handle_stream_error( +        self, +        stream_error: StreamError, +    ) -> impl std::future::Future<Output = ()> + Send; +      /// run to handle an incoming xmpp stanza      fn handle_stanza(          self,          stanza: Stanza,          connection: Connected, -        supervisor: SupervisorSender,      ) -> impl std::future::Future<Output = ()> + std::marker::Send;      /// run to handle a command message when a connection is currently established | 
