diff options
| author | 2025-02-11 10:54:16 +0000 | |
|---|---|---|
| committer | 2025-02-11 10:54:16 +0000 | |
| commit | 36348285317f6e073581479821564ddf825777c7 (patch) | |
| tree | 8169857e18c68af8dbb1be5dc1f79a047ae2b9e8 | |
| parent | 1ed6317272fe819e7e12b1be6fcff62d409c8f03 (diff) | |
| download | luz-36348285317f6e073581479821564ddf825777c7.tar.gz luz-36348285317f6e073581479821564ddf825777c7.tar.bz2 luz-36348285317f6e073581479821564ddf825777c7.zip | |
add iq hashmap for iq requests
Diffstat (limited to '')
| -rw-r--r-- | luz/src/connection/mod.rs | 16 | ||||
| -rw-r--r-- | luz/src/connection/read.rs | 26 | ||||
| -rw-r--r-- | luz/src/lib.rs | 9 | 
3 files changed, 42 insertions, 9 deletions
| diff --git a/luz/src/connection/mod.rs b/luz/src/connection/mod.rs index 85cf7cc..f8cf18b 100644 --- a/luz/src/connection/mod.rs +++ b/luz/src/connection/mod.rs @@ -1,6 +1,7 @@  // TODO: consider if this needs to be handled by a supervisor or could be handled by luz directly  use std::{ +    collections::HashMap,      ops::{Deref, DerefMut},      sync::Arc,      time::Duration, @@ -10,6 +11,7 @@ use jabber::{connection::Tls, jabber_stream::bound_stream::BoundJabberStream};  use jid::JID;  use read::{ReadControl, ReadControlHandle};  use sqlx::SqlitePool; +use stanza::client::Stanza;  use tokio::{      sync::{mpsc, oneshot, Mutex},      task::{JoinHandle, JoinSet}, @@ -30,6 +32,7 @@ pub struct Supervisor {          tokio::task::JoinSet<()>,          mpsc::Sender<SupervisorCommand>,          WriteHandle, +        Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, Error>>>>>,      )>,      sender: mpsc::Sender<UpdateMessage>,      writer_handle: WriteControlHandle, @@ -56,6 +59,7 @@ impl Supervisor {              JoinSet<()>,              mpsc::Sender<SupervisorCommand>,              WriteHandle, +            Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, Error>>>>>,          )>,          sender: mpsc::Sender<UpdateMessage>,          writer_handle: WriteControlHandle, @@ -108,7 +112,7 @@ impl Supervisor {                      // consider awaiting/aborting the read and write threads                      let (send, recv) = oneshot::channel();                      let _ = self.reader_handle.send(ReadControl::Abort(send)).await; -                    let (db, update_sender, tasks, supervisor_command, write_sender) = tokio::select! { +                    let (db, update_sender, tasks, supervisor_command, write_sender, pending_iqs) = tokio::select! {                          Ok(s) = recv => s,                          Ok(s) = &mut self.reader_crash => s,                          // in case, just break as irrecoverable @@ -134,7 +138,8 @@ impl Supervisor {                                  update_sender,                                  supervisor_command,                                  write_sender, -                                tasks +                                tasks, +                                pending_iqs,                              );                          },                          Err(e) => { @@ -149,7 +154,7 @@ impl Supervisor {                          },                      }                  }, -                Ok((db, update_sender, tasks, supervisor_control, write_handle)) = &mut self.reader_crash => { +                Ok((db, update_sender, tasks, supervisor_control, write_handle, pending_iqs)) = &mut self.reader_crash => {                      let (send, recv) = oneshot::channel();                      let _ = self.writer_handle.send(WriteControl::Abort(send)).await;                      let (retry_msg, mut write_receiver) = tokio::select! { @@ -182,7 +187,8 @@ impl Supervisor {                                  update_sender,                                  supervisor_control,                                  write_handle, -                                tasks +                                tasks, +                                pending_iqs,                              );                          },                          Err(e) => { @@ -252,6 +258,7 @@ impl SupervisorHandle {          on_shutdown: oneshot::Sender<()>,          jid: Arc<Mutex<JID>>,          password: Arc<String>, +        pending_iqs: Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, Error>>>>>,      ) -> (WriteHandle, Self) {          let (command_sender, command_receiver) = mpsc::channel(20);          let (writer_error_sender, writer_error_receiver) = oneshot::channel(); @@ -267,6 +274,7 @@ impl SupervisorHandle {              update_sender.clone(),              command_sender.clone(),              write_handle.clone(), +            pending_iqs,          );          let actor = Supervisor::new( diff --git a/luz/src/connection/read.rs b/luz/src/connection/read.rs index edc6cdb..c1e37b4 100644 --- a/luz/src/connection/read.rs +++ b/luz/src/connection/read.rs @@ -1,5 +1,7 @@  use std::{ +    collections::HashMap,      ops::{Deref, DerefMut}, +    sync::Arc,      time::Duration,  }; @@ -7,7 +9,7 @@ use jabber::{connection::Tls, jabber_stream::bound_stream::BoundJabberReader};  use sqlx::SqlitePool;  use stanza::client::Stanza;  use tokio::{ -    sync::{mpsc, oneshot}, +    sync::{mpsc, oneshot, Mutex},      task::{JoinHandle, JoinSet},  }; @@ -28,6 +30,7 @@ pub struct Read {          JoinSet<()>,          mpsc::Sender<SupervisorCommand>,          WriteHandle, +        Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, Error>>>>>,      )>,      db: SqlitePool,      update_sender: mpsc::Sender<UpdateMessage>, @@ -36,6 +39,8 @@ pub struct Read {      tasks: JoinSet<()>,      disconnecting: bool,      disconnect_timedout: oneshot::Receiver<()>, +    // TODO: use proper stanza ids +    pending_iqs: Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, Error>>>>>,  }  impl Read { @@ -48,6 +53,7 @@ impl Read {              JoinSet<()>,              mpsc::Sender<SupervisorCommand>,              WriteHandle, +            Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, Error>>>>>,          )>,          db: SqlitePool,          update_sender: mpsc::Sender<UpdateMessage>, @@ -55,6 +61,7 @@ impl Read {          supervisor_control: mpsc::Sender<SupervisorCommand>,          write_handle: WriteHandle,          tasks: JoinSet<()>, +        pending_iqs: Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, Error>>>>>,      ) -> Self {          let (send, recv) = oneshot::channel();          Self { @@ -68,6 +75,7 @@ impl Read {              tasks,              disconnecting: false,              disconnect_timedout: recv, +            pending_iqs,          }      } @@ -91,7 +99,7 @@ impl Read {                              })                          },                          ReadControl::Abort(sender) => { -                            let _ = sender.send((self.db, self.update_sender, self.tasks, self.supervisor_control, self.write_handle)); +                            let _ = sender.send((self.db, self.update_sender, self.tasks, self.supervisor_control, self.write_handle, self.pending_iqs));                              break;                          },                      }; @@ -126,7 +134,7 @@ impl Read {                                  break;                              } else {                                  // AAAAAAAAAAAAAAAAAAAAA i should really just have this stored in the supervisor and not gaf bout passing these references around -                                let _ = self.on_crash.send((self.db, self.update_sender, self.tasks, self.supervisor_control, self.write_handle)); +                                let _ = self.on_crash.send((self.db, self.update_sender, self.tasks, self.supervisor_control, self.write_handle, self.pending_iqs));                              }                              break;                          }, @@ -134,6 +142,11 @@ impl Read {                  },                  else => break              } +            // when it aborts, must clear iq map no matter what +            let mut iqs = self.pending_iqs.lock().await; +            for (_id, sender) in iqs.drain() { +                let _ = sender.send(Err(Error::LostConnection)); +            }          }      }  } @@ -162,6 +175,7 @@ pub enum ReadControl {              JoinSet<()>,              mpsc::Sender<SupervisorCommand>,              WriteHandle, +            Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, Error>>>>>,          )>,      ),  } @@ -194,11 +208,13 @@ impl ReadControlHandle {              JoinSet<()>,              mpsc::Sender<SupervisorCommand>,              WriteHandle, +            Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, Error>>>>>,          )>,          db: SqlitePool,          sender: mpsc::Sender<UpdateMessage>,          supervisor_control: mpsc::Sender<SupervisorCommand>,          jabber_write: WriteHandle, +        pending_iqs: Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, Error>>>>>,      ) -> Self {          let (control_sender, control_receiver) = mpsc::channel(20); @@ -211,6 +227,7 @@ impl ReadControlHandle {              supervisor_control,              jabber_write,              JoinSet::new(), +            pending_iqs,          );          let handle = tokio::spawn(async move { actor.run().await }); @@ -228,12 +245,14 @@ impl ReadControlHandle {              JoinSet<()>,              mpsc::Sender<SupervisorCommand>,              WriteHandle, +            Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, Error>>>>>,          )>,          db: SqlitePool,          sender: mpsc::Sender<UpdateMessage>,          supervisor_control: mpsc::Sender<SupervisorCommand>,          jabber_write: WriteHandle,          tasks: JoinSet<()>, +        pending_iqs: Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, Error>>>>>,      ) -> Self {          let (control_sender, control_receiver) = mpsc::channel(20); @@ -246,6 +265,7 @@ impl ReadControlHandle {              supervisor_control,              jabber_write,              tasks, +            pending_iqs,          );          let handle = tokio::spawn(async move { actor.run().await }); diff --git a/luz/src/lib.rs b/luz/src/lib.rs index 333d8eb..9d8ea66 100644 --- a/luz/src/lib.rs +++ b/luz/src/lib.rs @@ -1,9 +1,9 @@ -use std::sync::Arc; +use std::{collections::HashMap, sync::Arc};  use connection::SupervisorSender;  use jabber::JID;  use sqlx::SqlitePool; -use stanza::roster; +use stanza::{client::Stanza, roster};  use tokio::{      sync::{mpsc, oneshot, Mutex},      task::JoinSet, @@ -22,6 +22,7 @@ pub struct Luz {      // TODO: use a dyn passwordprovider trait to avoid storing password in memory      password: Arc<String>,      connected: Arc<Mutex<Option<(WriteHandle, SupervisorHandle)>>>, +    pending_iqs: Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, Error>>>>>,      db: SqlitePool,      sender: mpsc::Sender<UpdateMessage>,      /// if connection was shut down due to e.g. server shutdown, supervisor must be able to mark client as disconnected @@ -50,6 +51,7 @@ impl Luz {              sender,              tasks: JoinSet::new(),              connection_supervisor_shutdown, +            pending_iqs: Arc::new(Mutex::new(HashMap::new())),          }      } @@ -87,6 +89,7 @@ impl Luz {                                                  shutdown_send,                                                  self.jid.clone(),                                                  self.password.clone(), +                                                self.pending_iqs.clone(),                                              );                                              self.connection_supervisor_shutdown = shutdown_recv;                                              *self.connected.lock().await = Some((writer, supervisor)); @@ -121,6 +124,7 @@ impl Luz {                                      self.db.clone(),                                      self.sender.clone(),                                      // TODO: iq hashmap +                                    self.pending_iqs.clone()                                  )),                                  None => self.tasks.spawn(msg.handle_offline(                                      self.jid.clone(), @@ -155,6 +159,7 @@ impl CommandMessage {          jid: Arc<Mutex<JID>>,          db: SqlitePool,          sender: mpsc::Sender<UpdateMessage>, +        pending_iqs: Arc<Mutex<HashMap<String, oneshot::Sender<Result<Stanza, Error>>>>>,      ) {          todo!()      } | 
