diff options
Diffstat (limited to '')
| -rw-r--r-- | jabber/src/client.rs | 211 | ||||
| -rw-r--r-- | jabber/src/error.rs | 8 | ||||
| -rw-r--r-- | jabber/src/jabber_stream.rs | 14 | ||||
| -rw-r--r-- | jabber/src/jabber_stream/bound_stream.rs | 139 | 
4 files changed, 286 insertions, 86 deletions
| diff --git a/jabber/src/client.rs b/jabber/src/client.rs index c6cab07..32b8f6e 100644 --- a/jabber/src/client.rs +++ b/jabber/src/client.rs @@ -1,6 +1,12 @@ -use std::{pin::pin, sync::Arc, task::Poll}; +use std::{ +    borrow::Borrow, +    future::Future, +    pin::pin, +    sync::Arc, +    task::{ready, Poll}, +}; -use futures::{Sink, Stream, StreamExt}; +use futures::{FutureExt, Sink, SinkExt, Stream, StreamExt};  use jid::ParseError;  use rsasl::config::SASLConfig;  use stanza::{ @@ -8,9 +14,11 @@ use stanza::{      sasl::Mechanisms,      stream::{Feature, Features},  }; +use tokio::sync::Mutex;  use crate::{      connection::{Tls, Unencrypted}, +    jabber_stream::bound_stream::BoundJabberStream,      Connection, Error, JabberStream, Result, JID,  }; @@ -56,7 +64,7 @@ impl JabberClient {          }      } -    pub(crate) fn inner(self) -> Result<JabberStream<Tls>> { +    pub(crate) fn inner(self) -> Result<BoundJabberStream<Tls>> {          match self.connection {              ConnectionState::Disconnected => return Err(Error::Disconnected),              ConnectionState::Connecting(_connecting) => return Err(Error::Connecting), @@ -64,21 +72,137 @@ impl JabberClient {          }      } -    pub async fn send_stanza(&mut self, stanza: &Stanza) -> Result<()> { -        match &mut self.connection { -            ConnectionState::Disconnected => return Err(Error::Disconnected), -            ConnectionState::Connecting(_connecting) => return Err(Error::Connecting), -            ConnectionState::Connected(jabber_stream) => { -                Ok(jabber_stream.send_stanza(stanza).await?) -            } -        } +    // pub async fn send_stanza(&mut self, stanza: &Stanza) -> Result<()> { +    //     match &mut self.connection { +    //         ConnectionState::Disconnected => return Err(Error::Disconnected), +    //         ConnectionState::Connecting(_connecting) => return Err(Error::Connecting), +    //         ConnectionState::Connected(jabber_stream) => { +    //             Ok(jabber_stream.send_stanza(stanza).await?) +    //         } +    //     } +    // } +} + +impl Sink<Stanza> for JabberClient { +    type Error = Error; + +    fn poll_ready( +        self: std::pin::Pin<&mut Self>, +        cx: &mut std::task::Context<'_>, +    ) -> Poll<std::result::Result<(), Self::Error>> { +        self.get_mut().connection.poll_ready_unpin(cx) +    } + +    fn start_send( +        self: std::pin::Pin<&mut Self>, +        item: Stanza, +    ) -> std::result::Result<(), Self::Error> { +        self.get_mut().connection.start_send_unpin(item) +    } + +    fn poll_flush( +        self: std::pin::Pin<&mut Self>, +        cx: &mut std::task::Context<'_>, +    ) -> Poll<std::result::Result<(), Self::Error>> { +        self.get_mut().connection.poll_flush_unpin(cx) +    } + +    fn poll_close( +        self: std::pin::Pin<&mut Self>, +        cx: &mut std::task::Context<'_>, +    ) -> Poll<std::result::Result<(), Self::Error>> { +        self.get_mut().connection.poll_flush_unpin(cx) +    } +} + +impl Stream for JabberClient { +    type Item = Result<Stanza>; + +    fn poll_next( +        self: std::pin::Pin<&mut Self>, +        cx: &mut std::task::Context<'_>, +    ) -> Poll<Option<Self::Item>> { +        self.get_mut().connection.poll_next_unpin(cx)      }  }  pub enum ConnectionState {      Disconnected,      Connecting(Connecting), -    Connected(JabberStream<Tls>), +    Connected(BoundJabberStream<Tls>), +} + +impl Sink<Stanza> for ConnectionState { +    type Error = Error; + +    fn poll_ready( +        self: std::pin::Pin<&mut Self>, +        cx: &mut std::task::Context<'_>, +    ) -> Poll<std::result::Result<(), Self::Error>> { +        match self.get_mut() { +            ConnectionState::Disconnected => Poll::Ready(Err(Error::Disconnected)), +            ConnectionState::Connecting(_connecting) => Poll::Pending, +            ConnectionState::Connected(bound_jabber_stream) => { +                bound_jabber_stream.poll_ready_unpin(cx) +            } +        } +    } + +    fn start_send( +        self: std::pin::Pin<&mut Self>, +        item: Stanza, +    ) -> std::result::Result<(), Self::Error> { +        match self.get_mut() { +            ConnectionState::Disconnected => Err(Error::Disconnected), +            ConnectionState::Connecting(_connecting) => Err(Error::Connecting), +            ConnectionState::Connected(bound_jabber_stream) => { +                bound_jabber_stream.start_send_unpin(item) +            } +        } +    } + +    fn poll_flush( +        self: std::pin::Pin<&mut Self>, +        cx: &mut std::task::Context<'_>, +    ) -> Poll<std::result::Result<(), Self::Error>> { +        match self.get_mut() { +            ConnectionState::Disconnected => Poll::Ready(Err(Error::Disconnected)), +            ConnectionState::Connecting(_connecting) => Poll::Pending, +            ConnectionState::Connected(bound_jabber_stream) => { +                bound_jabber_stream.poll_flush_unpin(cx) +            } +        } +    } + +    fn poll_close( +        self: std::pin::Pin<&mut Self>, +        cx: &mut std::task::Context<'_>, +    ) -> Poll<std::result::Result<(), Self::Error>> { +        match self.get_mut() { +            ConnectionState::Disconnected => Poll::Ready(Err(Error::Disconnected)), +            ConnectionState::Connecting(_connecting) => Poll::Pending, +            ConnectionState::Connected(bound_jabber_stream) => { +                bound_jabber_stream.poll_close_unpin(cx) +            } +        } +    } +} + +impl Stream for ConnectionState { +    type Item = Result<Stanza>; + +    fn poll_next( +        self: std::pin::Pin<&mut Self>, +        cx: &mut std::task::Context<'_>, +    ) -> Poll<Option<Self::Item>> { +        match self.get_mut() { +            ConnectionState::Disconnected => Poll::Ready(Some(Err(Error::Disconnected))), +            ConnectionState::Connecting(_connecting) => Poll::Pending, +            ConnectionState::Connected(bound_jabber_stream) => { +                bound_jabber_stream.poll_next_unpin(cx) +            } +        } +    }  }  impl ConnectionState { @@ -150,7 +274,9 @@ impl ConnectionState {                          ))                      }                      Connecting::Bind(jabber_stream) => { -                        self = ConnectionState::Connected(jabber_stream.bind(jid).await?) +                        self = ConnectionState::Connected( +                            jabber_stream.bind(jid).await?.to_bound_jabber(), +                        )                      }                  },                  connected => return Ok(connected), @@ -194,11 +320,20 @@ pub enum InsecureConnecting {  #[cfg(test)]  mod tests { -    use std::time::Duration; +    use std::{sync::Arc, time::Duration};      use super::JabberClient; +    use futures::{SinkExt, StreamExt}; +    use stanza::{ +        client::{ +            iq::{Iq, IqType, Query}, +            Stanza, +        }, +        xep_0199::Ping, +    };      use test_log::test; -    use tokio::time::sleep; +    use tokio::{sync::Mutex, time::sleep}; +    use tracing::info;      #[test(tokio::test)]      async fn login() { @@ -206,4 +341,50 @@ mod tests {          client.connect().await.unwrap();          sleep(Duration::from_secs(5)).await      } + +    #[test(tokio::test)] +    async fn ping_parallel() { +        let mut client = JabberClient::new("test@blos.sm", "slayed").unwrap(); +        client.connect().await.unwrap(); +        sleep(Duration::from_secs(5)).await; +        let jid = client.jid.clone(); +        let server = client.server.clone(); +        let mut client = Arc::new(Mutex::new(client)); + +        tokio::join!( +            async { +                let mut client = client.lock().await; +                client +                    .send(Stanza::Iq(Iq { +                        from: Some(jid.clone()), +                        id: "c2s1".to_string(), +                        to: Some(server.clone().try_into().unwrap()), +                        r#type: IqType::Get, +                        lang: None, +                        query: Some(Query::Ping(Ping)), +                        errors: Vec::new(), +                    })) +                    .await; +            }, +            async { +                let mut client = client.lock().await; +                client +                    .send(Stanza::Iq(Iq { +                        from: Some(jid.clone()), +                        id: "c2s2".to_string(), +                        to: Some(server.clone().try_into().unwrap()), +                        r#type: IqType::Get, +                        lang: None, +                        query: Some(Query::Ping(Ping)), +                        errors: Vec::new(), +                    })) +                    .await; +            }, +            async { +                while let Some(stanza) = client.lock().await.next().await { +                    info!("{:#?}", stanza); +                } +            } +        ); +    }  } diff --git a/jabber/src/error.rs b/jabber/src/error.rs index 6671fe6..902061e 100644 --- a/jabber/src/error.rs +++ b/jabber/src/error.rs @@ -5,6 +5,7 @@ use rsasl::mechname::MechanismNameError;  use stanza::client::error::Error as ClientError;  use stanza::sasl::Failure;  use stanza::stream::Error as StreamError; +use tokio::task::JoinError;  #[derive(Debug)]  pub enum Error { @@ -28,6 +29,7 @@ pub enum Error {      MissingError,      Disconnected,      Connecting, +    JoinError(JoinError),  }  #[derive(Debug)] @@ -42,6 +44,12 @@ impl From<rsasl::prelude::SASLError> for Error {      }  } +impl From<JoinError> for Error { +    fn from(e: JoinError) -> Self { +        Self::JoinError(e) +    } +} +  impl From<peanuts::DeserializeError> for Error {      fn from(e: peanuts::DeserializeError) -> Self {          Error::Deserialization(e) diff --git a/jabber/src/jabber_stream.rs b/jabber/src/jabber_stream.rs index d981f8f..89890a8 100644 --- a/jabber/src/jabber_stream.rs +++ b/jabber/src/jabber_stream.rs @@ -27,7 +27,7 @@ pub mod bound_stream;  // open stream (streams started)  pub struct JabberStream<S> {      reader: Reader<ReadHalf<S>>, -    writer: Writer<WriteHalf<S>>, +    pub(crate) writer: Writer<WriteHalf<S>>,  }  impl<S> JabberStream<S> @@ -368,12 +368,12 @@ mod tests {      async fn sink() {          let mut client = JabberClient::new("test@blos.sm", "slayed").unwrap();          client.connect().await.unwrap(); -        let stream = client.inner().unwrap(); -        let sink = sink::unfold(stream, |mut stream, stanza: Stanza| async move { -            stream.writer.write(&stanza).await?; -            Ok::<JabberStream<Tls>, Error>(stream) -        }); -        todo!() +        // let stream = client.inner().unwrap(); +        // let sink = sink::unfold(stream, |mut stream, stanza: Stanza| async move { +        //     stream.writer.write(&stanza).await?; +        //     Ok::<JabberStream<Tls>, Error>(stream) +        // }); +        // todo!()          // let _jabber = Connection::connect_user("test@blos.sm", "slayed".to_string())          //     .await          //     .unwrap() diff --git a/jabber/src/jabber_stream/bound_stream.rs b/jabber/src/jabber_stream/bound_stream.rs index ca93421..c0d67b0 100644 --- a/jabber/src/jabber_stream/bound_stream.rs +++ b/jabber/src/jabber_stream/bound_stream.rs @@ -1,70 +1,71 @@ +use std::future::ready;  use std::pin::pin;  use std::pin::Pin;  use std::sync::Arc; +use std::task::Poll; +use futures::ready; +use futures::FutureExt;  use futures::{sink, stream, Sink, Stream};  use peanuts::{Reader, Writer};  use pin_project::pin_project;  use stanza::client::Stanza;  use tokio::io::{AsyncRead, AsyncWrite, ReadHalf, WriteHalf};  use tokio::sync::Mutex; +use tokio::task::JoinHandle;  use crate::Error;  use super::JabberStream;  #[pin_project] -pub struct BoundJabberStream<R, W, S> +pub struct BoundJabberStream<S>  where -    R: Stream, -    W: Sink<Stanza>,      S: AsyncWrite + AsyncRead + Unpin + Send,  { -    reader: Arc<Mutex<Option<Reader<ReadHalf<S>>>>>, -    writer: Arc<Mutex<Option<Writer<WriteHalf<S>>>>>, -    stream: R, -    sink: W, +    reader: Arc<Mutex<Reader<ReadHalf<S>>>>, +    writer: Arc<Mutex<Writer<WriteHalf<S>>>>, +    write_handle: Option<JoinHandle<Result<(), Error>>>, +    read_handle: Option<JoinHandle<Result<Stanza, Error>>>,  } -impl<R, W, S> BoundJabberStream<R, W, S> +impl<S> BoundJabberStream<S>  where -    R: Stream, -    W: Sink<Stanza>,      S: AsyncWrite + AsyncRead + Unpin + Send,  {      // TODO: look into biased mutex, to close stream ASAP -    pub async fn close_stream(self) -> Result<JabberStream<S>, Error> { -        if let Some(reader) = self.reader.lock().await.take() { -            if let Some(writer) = self.writer.lock().await.take() { -                // TODO: writer </stream:stream> -                return Ok(JabberStream { reader, writer }); -            } -        } -        return Err(Error::StreamClosed); -    } +    // TODO: put into connection +    // pub async fn close_stream(self) -> Result<JabberStream<S>, Error> { +    //     let reader = self.reader.lock().await.into_self(); +    //     let writer = self.writer.lock().await.into_self(); +    //     // TODO: writer </stream:stream> +    //     return Ok(JabberStream { reader, writer }); +    // }  }  pub trait JabberStreamTrait: AsyncWrite + AsyncRead + Unpin + Send {} -impl<R, W, S> Sink<Stanza> for BoundJabberStream<R, W, S> +impl<S> Sink<Stanza> for BoundJabberStream<S>  where -    R: Stream, -    W: Sink<Stanza> + Unpin, -    S: AsyncWrite + AsyncRead + Unpin + Send, +    S: AsyncWrite + AsyncRead + Unpin + Send + 'static,  { -    type Error = <W as Sink<Stanza>>::Error; +    type Error = Error;      fn poll_ready(          self: std::pin::Pin<&mut Self>,          cx: &mut std::task::Context<'_>,      ) -> std::task::Poll<Result<(), Self::Error>> { -        let this = self.project(); -        pin!(this.sink).poll_ready(cx) +        self.poll_flush(cx)      }      fn start_send(self: std::pin::Pin<&mut Self>, item: Stanza) -> Result<(), Self::Error> {          let this = self.project(); -        pin!(this.sink).start_send(item) +        if let Some(_write_handle) = this.write_handle { +            panic!("start_send called without poll_ready") +        } else { +            *this.write_handle = Some(tokio::spawn(write(this.writer.clone(), item))); +            Ok(()) +        }      }      fn poll_flush( @@ -72,32 +73,55 @@ where          cx: &mut std::task::Context<'_>,      ) -> std::task::Poll<Result<(), Self::Error>> {          let this = self.project(); -        pin!(this.sink).poll_flush(cx) +        Poll::Ready(if let Some(join_handle) = this.write_handle.as_mut() { +            match ready!(join_handle.poll_unpin(cx)) { +                Ok(state) => { +                    *this.write_handle = None; +                    state +                } +                Err(err) => { +                    *this.write_handle = None; +                    Err(err.into()) +                } +            } +        } else { +            Ok(()) +        })      }      fn poll_close(          self: std::pin::Pin<&mut Self>,          cx: &mut std::task::Context<'_>,      ) -> std::task::Poll<Result<(), Self::Error>> { -        let this = self.project(); -        pin!(this.sink).poll_close(cx) +        self.poll_flush(cx)      }  } -impl<R, W, S> Stream for BoundJabberStream<R, W, S> +impl<S> Stream for BoundJabberStream<S>  where -    R: Stream + Unpin, -    W: Sink<Stanza>, -    S: AsyncWrite + AsyncRead + Unpin + Send, +    S: AsyncWrite + AsyncRead + Unpin + Send + 'static,  { -    type Item = <R as Stream>::Item; +    type Item = Result<Stanza, Error>;      fn poll_next(          self: Pin<&mut Self>,          cx: &mut std::task::Context<'_>,      ) -> std::task::Poll<Option<Self::Item>> {          let this = self.project(); -        pin!(this.stream).poll_next(cx) + +        loop { +            if let Some(join_handle) = this.read_handle.as_mut() { +                let stanza = ready!(join_handle.poll_unpin(cx)); +                if let Ok(item) = stanza { +                    *this.read_handle = None; +                    return Poll::Ready(Some(item)); +                } else if let Err(err) = stanza { +                    return Poll::Ready(Some(Err(err.into()))); +                } +            } else { +                *this.read_handle = Some(tokio::spawn(read(this.reader.clone()))) +            } +        }      }  } @@ -105,49 +129,36 @@ impl<S> JabberStream<S>  where      S: AsyncWrite + AsyncRead + Unpin + Send,  { -    pub fn to_bound_jabber(self) -> BoundJabberStream<impl Stream, impl Sink<Stanza>, S> { -        let reader = Arc::new(Mutex::new(Some(self.reader))); -        let writer = Arc::new(Mutex::new(Some(self.writer))); -        let sink = sink::unfold(writer.clone(), |writer, s: Stanza| async move { -            write(writer, s).await -        }); -        let stream = stream::unfold(reader.clone(), |reader| async { read(reader).await }); +    pub fn to_bound_jabber(self) -> BoundJabberStream<S> { +        let reader = Arc::new(Mutex::new(self.reader)); +        let writer = Arc::new(Mutex::new(self.writer));          BoundJabberStream { -            sink, -            stream,              writer,              reader, +            write_handle: None, +            read_handle: None,          }      }  }  pub async fn write<W: AsyncWrite + Unpin + Send>( -    writer: Arc<Mutex<Option<Writer<WriteHalf<W>>>>>, +    writer: Arc<Mutex<Writer<WriteHalf<W>>>>,      stanza: Stanza, -) -> Result<Arc<Mutex<Option<Writer<WriteHalf<W>>>>>, Error> { +) -> Result<(), Error> {      { -        if let Some(writer) = writer.lock().await.as_mut() { -            writer.write(&stanza).await?; -        } else { -            return Err(Error::StreamClosed); -        } +        let mut writer = writer.lock().await; +        writer.write(&stanza).await?;      } -    Ok(writer) +    Ok(())  }  pub async fn read<R: AsyncRead + Unpin + Send>( -    reader: Arc<Mutex<Option<Reader<ReadHalf<R>>>>>, -) -> Option<( -    Result<Stanza, Error>, -    Arc<Mutex<Option<Reader<ReadHalf<R>>>>>, -)> { +    reader: Arc<Mutex<Reader<ReadHalf<R>>>>, +) -> Result<Stanza, Error> {      let stanza: Result<Stanza, Error>;      { -        if let Some(reader) = reader.lock().await.as_mut() { -            stanza = reader.read().await.map_err(|e| e.into()); -        } else { -            stanza = Err(Error::StreamClosed) -        }; +        let mut reader = reader.lock().await; +        stanza = reader.read().await.map_err(|e| e.into());      } -    Some((stanza, reader)) +    stanza  } | 
