diff options
author | 2024-12-22 18:58:28 +0000 | |
---|---|---|
committer | 2024-12-22 18:58:28 +0000 | |
commit | 6385e43e8ca467e53c6a705a932016c5af75c3a2 (patch) | |
tree | f63fb7bd9a349f24b093ba4dd037c6ce7789f5ee /jabber | |
parent | 595d165479b8b12e456f39205d8433b822b07487 (diff) | |
download | luz-6385e43e8ca467e53c6a705a932016c5af75c3a2.tar.gz luz-6385e43e8ca467e53c6a705a932016c5af75c3a2.tar.bz2 luz-6385e43e8ca467e53c6a705a932016c5af75c3a2.zip |
implement sink and stream with tokio::spawn
Diffstat (limited to 'jabber')
-rw-r--r-- | jabber/src/client.rs | 211 | ||||
-rw-r--r-- | jabber/src/error.rs | 8 | ||||
-rw-r--r-- | jabber/src/jabber_stream.rs | 14 | ||||
-rw-r--r-- | jabber/src/jabber_stream/bound_stream.rs | 139 |
4 files changed, 286 insertions, 86 deletions
diff --git a/jabber/src/client.rs b/jabber/src/client.rs index c6cab07..32b8f6e 100644 --- a/jabber/src/client.rs +++ b/jabber/src/client.rs @@ -1,6 +1,12 @@ -use std::{pin::pin, sync::Arc, task::Poll}; +use std::{ + borrow::Borrow, + future::Future, + pin::pin, + sync::Arc, + task::{ready, Poll}, +}; -use futures::{Sink, Stream, StreamExt}; +use futures::{FutureExt, Sink, SinkExt, Stream, StreamExt}; use jid::ParseError; use rsasl::config::SASLConfig; use stanza::{ @@ -8,9 +14,11 @@ use stanza::{ sasl::Mechanisms, stream::{Feature, Features}, }; +use tokio::sync::Mutex; use crate::{ connection::{Tls, Unencrypted}, + jabber_stream::bound_stream::BoundJabberStream, Connection, Error, JabberStream, Result, JID, }; @@ -56,7 +64,7 @@ impl JabberClient { } } - pub(crate) fn inner(self) -> Result<JabberStream<Tls>> { + pub(crate) fn inner(self) -> Result<BoundJabberStream<Tls>> { match self.connection { ConnectionState::Disconnected => return Err(Error::Disconnected), ConnectionState::Connecting(_connecting) => return Err(Error::Connecting), @@ -64,21 +72,137 @@ impl JabberClient { } } - pub async fn send_stanza(&mut self, stanza: &Stanza) -> Result<()> { - match &mut self.connection { - ConnectionState::Disconnected => return Err(Error::Disconnected), - ConnectionState::Connecting(_connecting) => return Err(Error::Connecting), - ConnectionState::Connected(jabber_stream) => { - Ok(jabber_stream.send_stanza(stanza).await?) - } - } + // pub async fn send_stanza(&mut self, stanza: &Stanza) -> Result<()> { + // match &mut self.connection { + // ConnectionState::Disconnected => return Err(Error::Disconnected), + // ConnectionState::Connecting(_connecting) => return Err(Error::Connecting), + // ConnectionState::Connected(jabber_stream) => { + // Ok(jabber_stream.send_stanza(stanza).await?) + // } + // } + // } +} + +impl Sink<Stanza> for JabberClient { + type Error = Error; + + fn poll_ready( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll<std::result::Result<(), Self::Error>> { + self.get_mut().connection.poll_ready_unpin(cx) + } + + fn start_send( + self: std::pin::Pin<&mut Self>, + item: Stanza, + ) -> std::result::Result<(), Self::Error> { + self.get_mut().connection.start_send_unpin(item) + } + + fn poll_flush( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll<std::result::Result<(), Self::Error>> { + self.get_mut().connection.poll_flush_unpin(cx) + } + + fn poll_close( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll<std::result::Result<(), Self::Error>> { + self.get_mut().connection.poll_flush_unpin(cx) + } +} + +impl Stream for JabberClient { + type Item = Result<Stanza>; + + fn poll_next( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll<Option<Self::Item>> { + self.get_mut().connection.poll_next_unpin(cx) } } pub enum ConnectionState { Disconnected, Connecting(Connecting), - Connected(JabberStream<Tls>), + Connected(BoundJabberStream<Tls>), +} + +impl Sink<Stanza> for ConnectionState { + type Error = Error; + + fn poll_ready( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll<std::result::Result<(), Self::Error>> { + match self.get_mut() { + ConnectionState::Disconnected => Poll::Ready(Err(Error::Disconnected)), + ConnectionState::Connecting(_connecting) => Poll::Pending, + ConnectionState::Connected(bound_jabber_stream) => { + bound_jabber_stream.poll_ready_unpin(cx) + } + } + } + + fn start_send( + self: std::pin::Pin<&mut Self>, + item: Stanza, + ) -> std::result::Result<(), Self::Error> { + match self.get_mut() { + ConnectionState::Disconnected => Err(Error::Disconnected), + ConnectionState::Connecting(_connecting) => Err(Error::Connecting), + ConnectionState::Connected(bound_jabber_stream) => { + bound_jabber_stream.start_send_unpin(item) + } + } + } + + fn poll_flush( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll<std::result::Result<(), Self::Error>> { + match self.get_mut() { + ConnectionState::Disconnected => Poll::Ready(Err(Error::Disconnected)), + ConnectionState::Connecting(_connecting) => Poll::Pending, + ConnectionState::Connected(bound_jabber_stream) => { + bound_jabber_stream.poll_flush_unpin(cx) + } + } + } + + fn poll_close( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll<std::result::Result<(), Self::Error>> { + match self.get_mut() { + ConnectionState::Disconnected => Poll::Ready(Err(Error::Disconnected)), + ConnectionState::Connecting(_connecting) => Poll::Pending, + ConnectionState::Connected(bound_jabber_stream) => { + bound_jabber_stream.poll_close_unpin(cx) + } + } + } +} + +impl Stream for ConnectionState { + type Item = Result<Stanza>; + + fn poll_next( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll<Option<Self::Item>> { + match self.get_mut() { + ConnectionState::Disconnected => Poll::Ready(Some(Err(Error::Disconnected))), + ConnectionState::Connecting(_connecting) => Poll::Pending, + ConnectionState::Connected(bound_jabber_stream) => { + bound_jabber_stream.poll_next_unpin(cx) + } + } + } } impl ConnectionState { @@ -150,7 +274,9 @@ impl ConnectionState { )) } Connecting::Bind(jabber_stream) => { - self = ConnectionState::Connected(jabber_stream.bind(jid).await?) + self = ConnectionState::Connected( + jabber_stream.bind(jid).await?.to_bound_jabber(), + ) } }, connected => return Ok(connected), @@ -194,11 +320,20 @@ pub enum InsecureConnecting { #[cfg(test)] mod tests { - use std::time::Duration; + use std::{sync::Arc, time::Duration}; use super::JabberClient; + use futures::{SinkExt, StreamExt}; + use stanza::{ + client::{ + iq::{Iq, IqType, Query}, + Stanza, + }, + xep_0199::Ping, + }; use test_log::test; - use tokio::time::sleep; + use tokio::{sync::Mutex, time::sleep}; + use tracing::info; #[test(tokio::test)] async fn login() { @@ -206,4 +341,50 @@ mod tests { client.connect().await.unwrap(); sleep(Duration::from_secs(5)).await } + + #[test(tokio::test)] + async fn ping_parallel() { + let mut client = JabberClient::new("test@blos.sm", "slayed").unwrap(); + client.connect().await.unwrap(); + sleep(Duration::from_secs(5)).await; + let jid = client.jid.clone(); + let server = client.server.clone(); + let mut client = Arc::new(Mutex::new(client)); + + tokio::join!( + async { + let mut client = client.lock().await; + client + .send(Stanza::Iq(Iq { + from: Some(jid.clone()), + id: "c2s1".to_string(), + to: Some(server.clone().try_into().unwrap()), + r#type: IqType::Get, + lang: None, + query: Some(Query::Ping(Ping)), + errors: Vec::new(), + })) + .await; + }, + async { + let mut client = client.lock().await; + client + .send(Stanza::Iq(Iq { + from: Some(jid.clone()), + id: "c2s2".to_string(), + to: Some(server.clone().try_into().unwrap()), + r#type: IqType::Get, + lang: None, + query: Some(Query::Ping(Ping)), + errors: Vec::new(), + })) + .await; + }, + async { + while let Some(stanza) = client.lock().await.next().await { + info!("{:#?}", stanza); + } + } + ); + } } diff --git a/jabber/src/error.rs b/jabber/src/error.rs index 6671fe6..902061e 100644 --- a/jabber/src/error.rs +++ b/jabber/src/error.rs @@ -5,6 +5,7 @@ use rsasl::mechname::MechanismNameError; use stanza::client::error::Error as ClientError; use stanza::sasl::Failure; use stanza::stream::Error as StreamError; +use tokio::task::JoinError; #[derive(Debug)] pub enum Error { @@ -28,6 +29,7 @@ pub enum Error { MissingError, Disconnected, Connecting, + JoinError(JoinError), } #[derive(Debug)] @@ -42,6 +44,12 @@ impl From<rsasl::prelude::SASLError> for Error { } } +impl From<JoinError> for Error { + fn from(e: JoinError) -> Self { + Self::JoinError(e) + } +} + impl From<peanuts::DeserializeError> for Error { fn from(e: peanuts::DeserializeError) -> Self { Error::Deserialization(e) diff --git a/jabber/src/jabber_stream.rs b/jabber/src/jabber_stream.rs index d981f8f..89890a8 100644 --- a/jabber/src/jabber_stream.rs +++ b/jabber/src/jabber_stream.rs @@ -27,7 +27,7 @@ pub mod bound_stream; // open stream (streams started) pub struct JabberStream<S> { reader: Reader<ReadHalf<S>>, - writer: Writer<WriteHalf<S>>, + pub(crate) writer: Writer<WriteHalf<S>>, } impl<S> JabberStream<S> @@ -368,12 +368,12 @@ mod tests { async fn sink() { let mut client = JabberClient::new("test@blos.sm", "slayed").unwrap(); client.connect().await.unwrap(); - let stream = client.inner().unwrap(); - let sink = sink::unfold(stream, |mut stream, stanza: Stanza| async move { - stream.writer.write(&stanza).await?; - Ok::<JabberStream<Tls>, Error>(stream) - }); - todo!() + // let stream = client.inner().unwrap(); + // let sink = sink::unfold(stream, |mut stream, stanza: Stanza| async move { + // stream.writer.write(&stanza).await?; + // Ok::<JabberStream<Tls>, Error>(stream) + // }); + // todo!() // let _jabber = Connection::connect_user("test@blos.sm", "slayed".to_string()) // .await // .unwrap() diff --git a/jabber/src/jabber_stream/bound_stream.rs b/jabber/src/jabber_stream/bound_stream.rs index ca93421..c0d67b0 100644 --- a/jabber/src/jabber_stream/bound_stream.rs +++ b/jabber/src/jabber_stream/bound_stream.rs @@ -1,70 +1,71 @@ +use std::future::ready; use std::pin::pin; use std::pin::Pin; use std::sync::Arc; +use std::task::Poll; +use futures::ready; +use futures::FutureExt; use futures::{sink, stream, Sink, Stream}; use peanuts::{Reader, Writer}; use pin_project::pin_project; use stanza::client::Stanza; use tokio::io::{AsyncRead, AsyncWrite, ReadHalf, WriteHalf}; use tokio::sync::Mutex; +use tokio::task::JoinHandle; use crate::Error; use super::JabberStream; #[pin_project] -pub struct BoundJabberStream<R, W, S> +pub struct BoundJabberStream<S> where - R: Stream, - W: Sink<Stanza>, S: AsyncWrite + AsyncRead + Unpin + Send, { - reader: Arc<Mutex<Option<Reader<ReadHalf<S>>>>>, - writer: Arc<Mutex<Option<Writer<WriteHalf<S>>>>>, - stream: R, - sink: W, + reader: Arc<Mutex<Reader<ReadHalf<S>>>>, + writer: Arc<Mutex<Writer<WriteHalf<S>>>>, + write_handle: Option<JoinHandle<Result<(), Error>>>, + read_handle: Option<JoinHandle<Result<Stanza, Error>>>, } -impl<R, W, S> BoundJabberStream<R, W, S> +impl<S> BoundJabberStream<S> where - R: Stream, - W: Sink<Stanza>, S: AsyncWrite + AsyncRead + Unpin + Send, { // TODO: look into biased mutex, to close stream ASAP - pub async fn close_stream(self) -> Result<JabberStream<S>, Error> { - if let Some(reader) = self.reader.lock().await.take() { - if let Some(writer) = self.writer.lock().await.take() { - // TODO: writer </stream:stream> - return Ok(JabberStream { reader, writer }); - } - } - return Err(Error::StreamClosed); - } + // TODO: put into connection + // pub async fn close_stream(self) -> Result<JabberStream<S>, Error> { + // let reader = self.reader.lock().await.into_self(); + // let writer = self.writer.lock().await.into_self(); + // // TODO: writer </stream:stream> + // return Ok(JabberStream { reader, writer }); + // } } pub trait JabberStreamTrait: AsyncWrite + AsyncRead + Unpin + Send {} -impl<R, W, S> Sink<Stanza> for BoundJabberStream<R, W, S> +impl<S> Sink<Stanza> for BoundJabberStream<S> where - R: Stream, - W: Sink<Stanza> + Unpin, - S: AsyncWrite + AsyncRead + Unpin + Send, + S: AsyncWrite + AsyncRead + Unpin + Send + 'static, { - type Error = <W as Sink<Stanza>>::Error; + type Error = Error; fn poll_ready( self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> std::task::Poll<Result<(), Self::Error>> { - let this = self.project(); - pin!(this.sink).poll_ready(cx) + self.poll_flush(cx) } fn start_send(self: std::pin::Pin<&mut Self>, item: Stanza) -> Result<(), Self::Error> { let this = self.project(); - pin!(this.sink).start_send(item) + if let Some(_write_handle) = this.write_handle { + panic!("start_send called without poll_ready") + } else { + *this.write_handle = Some(tokio::spawn(write(this.writer.clone(), item))); + Ok(()) + } } fn poll_flush( @@ -72,32 +73,55 @@ where cx: &mut std::task::Context<'_>, ) -> std::task::Poll<Result<(), Self::Error>> { let this = self.project(); - pin!(this.sink).poll_flush(cx) + Poll::Ready(if let Some(join_handle) = this.write_handle.as_mut() { + match ready!(join_handle.poll_unpin(cx)) { + Ok(state) => { + *this.write_handle = None; + state + } + Err(err) => { + *this.write_handle = None; + Err(err.into()) + } + } + } else { + Ok(()) + }) } fn poll_close( self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> std::task::Poll<Result<(), Self::Error>> { - let this = self.project(); - pin!(this.sink).poll_close(cx) + self.poll_flush(cx) } } -impl<R, W, S> Stream for BoundJabberStream<R, W, S> +impl<S> Stream for BoundJabberStream<S> where - R: Stream + Unpin, - W: Sink<Stanza>, - S: AsyncWrite + AsyncRead + Unpin + Send, + S: AsyncWrite + AsyncRead + Unpin + Send + 'static, { - type Item = <R as Stream>::Item; + type Item = Result<Stanza, Error>; fn poll_next( self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> std::task::Poll<Option<Self::Item>> { let this = self.project(); - pin!(this.stream).poll_next(cx) + + loop { + if let Some(join_handle) = this.read_handle.as_mut() { + let stanza = ready!(join_handle.poll_unpin(cx)); + if let Ok(item) = stanza { + *this.read_handle = None; + return Poll::Ready(Some(item)); + } else if let Err(err) = stanza { + return Poll::Ready(Some(Err(err.into()))); + } + } else { + *this.read_handle = Some(tokio::spawn(read(this.reader.clone()))) + } + } } } @@ -105,49 +129,36 @@ impl<S> JabberStream<S> where S: AsyncWrite + AsyncRead + Unpin + Send, { - pub fn to_bound_jabber(self) -> BoundJabberStream<impl Stream, impl Sink<Stanza>, S> { - let reader = Arc::new(Mutex::new(Some(self.reader))); - let writer = Arc::new(Mutex::new(Some(self.writer))); - let sink = sink::unfold(writer.clone(), |writer, s: Stanza| async move { - write(writer, s).await - }); - let stream = stream::unfold(reader.clone(), |reader| async { read(reader).await }); + pub fn to_bound_jabber(self) -> BoundJabberStream<S> { + let reader = Arc::new(Mutex::new(self.reader)); + let writer = Arc::new(Mutex::new(self.writer)); BoundJabberStream { - sink, - stream, writer, reader, + write_handle: None, + read_handle: None, } } } pub async fn write<W: AsyncWrite + Unpin + Send>( - writer: Arc<Mutex<Option<Writer<WriteHalf<W>>>>>, + writer: Arc<Mutex<Writer<WriteHalf<W>>>>, stanza: Stanza, -) -> Result<Arc<Mutex<Option<Writer<WriteHalf<W>>>>>, Error> { +) -> Result<(), Error> { { - if let Some(writer) = writer.lock().await.as_mut() { - writer.write(&stanza).await?; - } else { - return Err(Error::StreamClosed); - } + let mut writer = writer.lock().await; + writer.write(&stanza).await?; } - Ok(writer) + Ok(()) } pub async fn read<R: AsyncRead + Unpin + Send>( - reader: Arc<Mutex<Option<Reader<ReadHalf<R>>>>>, -) -> Option<( - Result<Stanza, Error>, - Arc<Mutex<Option<Reader<ReadHalf<R>>>>>, -)> { + reader: Arc<Mutex<Reader<ReadHalf<R>>>>, +) -> Result<Stanza, Error> { let stanza: Result<Stanza, Error>; { - if let Some(reader) = reader.lock().await.as_mut() { - stanza = reader.read().await.map_err(|e| e.into()); - } else { - stanza = Err(Error::StreamClosed) - }; + let mut reader = reader.lock().await; + stanza = reader.read().await.map_err(|e| e.into()); } - Some((stanza, reader)) + stanza } |