diff options
Diffstat (limited to 'jabber/src/jabber_stream/bound_stream.rs')
-rw-r--r-- | jabber/src/jabber_stream/bound_stream.rs | 153 |
1 files changed, 153 insertions, 0 deletions
diff --git a/jabber/src/jabber_stream/bound_stream.rs b/jabber/src/jabber_stream/bound_stream.rs new file mode 100644 index 0000000..ca93421 --- /dev/null +++ b/jabber/src/jabber_stream/bound_stream.rs @@ -0,0 +1,153 @@ +use std::pin::pin; +use std::pin::Pin; +use std::sync::Arc; + +use futures::{sink, stream, Sink, Stream}; +use peanuts::{Reader, Writer}; +use pin_project::pin_project; +use stanza::client::Stanza; +use tokio::io::{AsyncRead, AsyncWrite, ReadHalf, WriteHalf}; +use tokio::sync::Mutex; + +use crate::Error; + +use super::JabberStream; + +#[pin_project] +pub struct BoundJabberStream<R, W, S> +where + R: Stream, + W: Sink<Stanza>, + S: AsyncWrite + AsyncRead + Unpin + Send, +{ + reader: Arc<Mutex<Option<Reader<ReadHalf<S>>>>>, + writer: Arc<Mutex<Option<Writer<WriteHalf<S>>>>>, + stream: R, + sink: W, +} + +impl<R, W, S> BoundJabberStream<R, W, S> +where + R: Stream, + W: Sink<Stanza>, + S: AsyncWrite + AsyncRead + Unpin + Send, +{ + // TODO: look into biased mutex, to close stream ASAP + pub async fn close_stream(self) -> Result<JabberStream<S>, Error> { + if let Some(reader) = self.reader.lock().await.take() { + if let Some(writer) = self.writer.lock().await.take() { + // TODO: writer </stream:stream> + return Ok(JabberStream { reader, writer }); + } + } + return Err(Error::StreamClosed); + } +} + +pub trait JabberStreamTrait: AsyncWrite + AsyncRead + Unpin + Send {} + +impl<R, W, S> Sink<Stanza> for BoundJabberStream<R, W, S> +where + R: Stream, + W: Sink<Stanza> + Unpin, + S: AsyncWrite + AsyncRead + Unpin + Send, +{ + type Error = <W as Sink<Stanza>>::Error; + + fn poll_ready( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll<Result<(), Self::Error>> { + let this = self.project(); + pin!(this.sink).poll_ready(cx) + } + + fn start_send(self: std::pin::Pin<&mut Self>, item: Stanza) -> Result<(), Self::Error> { + let this = self.project(); + pin!(this.sink).start_send(item) + } + + fn poll_flush( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll<Result<(), Self::Error>> { + let this = self.project(); + pin!(this.sink).poll_flush(cx) + } + + fn poll_close( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll<Result<(), Self::Error>> { + let this = self.project(); + pin!(this.sink).poll_close(cx) + } +} + +impl<R, W, S> Stream for BoundJabberStream<R, W, S> +where + R: Stream + Unpin, + W: Sink<Stanza>, + S: AsyncWrite + AsyncRead + Unpin + Send, +{ + type Item = <R as Stream>::Item; + + fn poll_next( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll<Option<Self::Item>> { + let this = self.project(); + pin!(this.stream).poll_next(cx) + } +} + +impl<S> JabberStream<S> +where + S: AsyncWrite + AsyncRead + Unpin + Send, +{ + pub fn to_bound_jabber(self) -> BoundJabberStream<impl Stream, impl Sink<Stanza>, S> { + let reader = Arc::new(Mutex::new(Some(self.reader))); + let writer = Arc::new(Mutex::new(Some(self.writer))); + let sink = sink::unfold(writer.clone(), |writer, s: Stanza| async move { + write(writer, s).await + }); + let stream = stream::unfold(reader.clone(), |reader| async { read(reader).await }); + BoundJabberStream { + sink, + stream, + writer, + reader, + } + } +} + +pub async fn write<W: AsyncWrite + Unpin + Send>( + writer: Arc<Mutex<Option<Writer<WriteHalf<W>>>>>, + stanza: Stanza, +) -> Result<Arc<Mutex<Option<Writer<WriteHalf<W>>>>>, Error> { + { + if let Some(writer) = writer.lock().await.as_mut() { + writer.write(&stanza).await?; + } else { + return Err(Error::StreamClosed); + } + } + Ok(writer) +} + +pub async fn read<R: AsyncRead + Unpin + Send>( + reader: Arc<Mutex<Option<Reader<ReadHalf<R>>>>>, +) -> Option<( + Result<Stanza, Error>, + Arc<Mutex<Option<Reader<ReadHalf<R>>>>>, +)> { + let stanza: Result<Stanza, Error>; + { + if let Some(reader) = reader.lock().await.as_mut() { + stanza = reader.read().await.map_err(|e| e.into()); + } else { + stanza = Err(Error::StreamClosed) + }; + } + Some((stanza, reader)) +} |