diff options
author | 2024-12-22 18:58:28 +0000 | |
---|---|---|
committer | 2024-12-22 18:58:28 +0000 | |
commit | 6385e43e8ca467e53c6a705a932016c5af75c3a2 (patch) | |
tree | f63fb7bd9a349f24b093ba4dd037c6ce7789f5ee /jabber/src/client.rs | |
parent | 595d165479b8b12e456f39205d8433b822b07487 (diff) | |
download | luz-6385e43e8ca467e53c6a705a932016c5af75c3a2.tar.gz luz-6385e43e8ca467e53c6a705a932016c5af75c3a2.tar.bz2 luz-6385e43e8ca467e53c6a705a932016c5af75c3a2.zip |
implement sink and stream with tokio::spawn
Diffstat (limited to 'jabber/src/client.rs')
-rw-r--r-- | jabber/src/client.rs | 211 |
1 files changed, 196 insertions, 15 deletions
diff --git a/jabber/src/client.rs b/jabber/src/client.rs index c6cab07..32b8f6e 100644 --- a/jabber/src/client.rs +++ b/jabber/src/client.rs @@ -1,6 +1,12 @@ -use std::{pin::pin, sync::Arc, task::Poll}; +use std::{ + borrow::Borrow, + future::Future, + pin::pin, + sync::Arc, + task::{ready, Poll}, +}; -use futures::{Sink, Stream, StreamExt}; +use futures::{FutureExt, Sink, SinkExt, Stream, StreamExt}; use jid::ParseError; use rsasl::config::SASLConfig; use stanza::{ @@ -8,9 +14,11 @@ use stanza::{ sasl::Mechanisms, stream::{Feature, Features}, }; +use tokio::sync::Mutex; use crate::{ connection::{Tls, Unencrypted}, + jabber_stream::bound_stream::BoundJabberStream, Connection, Error, JabberStream, Result, JID, }; @@ -56,7 +64,7 @@ impl JabberClient { } } - pub(crate) fn inner(self) -> Result<JabberStream<Tls>> { + pub(crate) fn inner(self) -> Result<BoundJabberStream<Tls>> { match self.connection { ConnectionState::Disconnected => return Err(Error::Disconnected), ConnectionState::Connecting(_connecting) => return Err(Error::Connecting), @@ -64,21 +72,137 @@ impl JabberClient { } } - pub async fn send_stanza(&mut self, stanza: &Stanza) -> Result<()> { - match &mut self.connection { - ConnectionState::Disconnected => return Err(Error::Disconnected), - ConnectionState::Connecting(_connecting) => return Err(Error::Connecting), - ConnectionState::Connected(jabber_stream) => { - Ok(jabber_stream.send_stanza(stanza).await?) - } - } + // pub async fn send_stanza(&mut self, stanza: &Stanza) -> Result<()> { + // match &mut self.connection { + // ConnectionState::Disconnected => return Err(Error::Disconnected), + // ConnectionState::Connecting(_connecting) => return Err(Error::Connecting), + // ConnectionState::Connected(jabber_stream) => { + // Ok(jabber_stream.send_stanza(stanza).await?) + // } + // } + // } +} + +impl Sink<Stanza> for JabberClient { + type Error = Error; + + fn poll_ready( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll<std::result::Result<(), Self::Error>> { + self.get_mut().connection.poll_ready_unpin(cx) + } + + fn start_send( + self: std::pin::Pin<&mut Self>, + item: Stanza, + ) -> std::result::Result<(), Self::Error> { + self.get_mut().connection.start_send_unpin(item) + } + + fn poll_flush( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll<std::result::Result<(), Self::Error>> { + self.get_mut().connection.poll_flush_unpin(cx) + } + + fn poll_close( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll<std::result::Result<(), Self::Error>> { + self.get_mut().connection.poll_flush_unpin(cx) + } +} + +impl Stream for JabberClient { + type Item = Result<Stanza>; + + fn poll_next( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll<Option<Self::Item>> { + self.get_mut().connection.poll_next_unpin(cx) } } pub enum ConnectionState { Disconnected, Connecting(Connecting), - Connected(JabberStream<Tls>), + Connected(BoundJabberStream<Tls>), +} + +impl Sink<Stanza> for ConnectionState { + type Error = Error; + + fn poll_ready( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll<std::result::Result<(), Self::Error>> { + match self.get_mut() { + ConnectionState::Disconnected => Poll::Ready(Err(Error::Disconnected)), + ConnectionState::Connecting(_connecting) => Poll::Pending, + ConnectionState::Connected(bound_jabber_stream) => { + bound_jabber_stream.poll_ready_unpin(cx) + } + } + } + + fn start_send( + self: std::pin::Pin<&mut Self>, + item: Stanza, + ) -> std::result::Result<(), Self::Error> { + match self.get_mut() { + ConnectionState::Disconnected => Err(Error::Disconnected), + ConnectionState::Connecting(_connecting) => Err(Error::Connecting), + ConnectionState::Connected(bound_jabber_stream) => { + bound_jabber_stream.start_send_unpin(item) + } + } + } + + fn poll_flush( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll<std::result::Result<(), Self::Error>> { + match self.get_mut() { + ConnectionState::Disconnected => Poll::Ready(Err(Error::Disconnected)), + ConnectionState::Connecting(_connecting) => Poll::Pending, + ConnectionState::Connected(bound_jabber_stream) => { + bound_jabber_stream.poll_flush_unpin(cx) + } + } + } + + fn poll_close( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll<std::result::Result<(), Self::Error>> { + match self.get_mut() { + ConnectionState::Disconnected => Poll::Ready(Err(Error::Disconnected)), + ConnectionState::Connecting(_connecting) => Poll::Pending, + ConnectionState::Connected(bound_jabber_stream) => { + bound_jabber_stream.poll_close_unpin(cx) + } + } + } +} + +impl Stream for ConnectionState { + type Item = Result<Stanza>; + + fn poll_next( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll<Option<Self::Item>> { + match self.get_mut() { + ConnectionState::Disconnected => Poll::Ready(Some(Err(Error::Disconnected))), + ConnectionState::Connecting(_connecting) => Poll::Pending, + ConnectionState::Connected(bound_jabber_stream) => { + bound_jabber_stream.poll_next_unpin(cx) + } + } + } } impl ConnectionState { @@ -150,7 +274,9 @@ impl ConnectionState { )) } Connecting::Bind(jabber_stream) => { - self = ConnectionState::Connected(jabber_stream.bind(jid).await?) + self = ConnectionState::Connected( + jabber_stream.bind(jid).await?.to_bound_jabber(), + ) } }, connected => return Ok(connected), @@ -194,11 +320,20 @@ pub enum InsecureConnecting { #[cfg(test)] mod tests { - use std::time::Duration; + use std::{sync::Arc, time::Duration}; use super::JabberClient; + use futures::{SinkExt, StreamExt}; + use stanza::{ + client::{ + iq::{Iq, IqType, Query}, + Stanza, + }, + xep_0199::Ping, + }; use test_log::test; - use tokio::time::sleep; + use tokio::{sync::Mutex, time::sleep}; + use tracing::info; #[test(tokio::test)] async fn login() { @@ -206,4 +341,50 @@ mod tests { client.connect().await.unwrap(); sleep(Duration::from_secs(5)).await } + + #[test(tokio::test)] + async fn ping_parallel() { + let mut client = JabberClient::new("test@blos.sm", "slayed").unwrap(); + client.connect().await.unwrap(); + sleep(Duration::from_secs(5)).await; + let jid = client.jid.clone(); + let server = client.server.clone(); + let mut client = Arc::new(Mutex::new(client)); + + tokio::join!( + async { + let mut client = client.lock().await; + client + .send(Stanza::Iq(Iq { + from: Some(jid.clone()), + id: "c2s1".to_string(), + to: Some(server.clone().try_into().unwrap()), + r#type: IqType::Get, + lang: None, + query: Some(Query::Ping(Ping)), + errors: Vec::new(), + })) + .await; + }, + async { + let mut client = client.lock().await; + client + .send(Stanza::Iq(Iq { + from: Some(jid.clone()), + id: "c2s2".to_string(), + to: Some(server.clone().try_into().unwrap()), + r#type: IqType::Get, + lang: None, + query: Some(Query::Ping(Ping)), + errors: Vec::new(), + })) + .await; + }, + async { + while let Some(stanza) = client.lock().await.next().await { + info!("{:#?}", stanza); + } + } + ); + } } |