aboutsummaryrefslogtreecommitdiffstats
path: root/jabber/src/jabber_stream/bound_stream.rs
diff options
context:
space:
mode:
Diffstat (limited to 'jabber/src/jabber_stream/bound_stream.rs')
-rw-r--r--jabber/src/jabber_stream/bound_stream.rs153
1 files changed, 153 insertions, 0 deletions
diff --git a/jabber/src/jabber_stream/bound_stream.rs b/jabber/src/jabber_stream/bound_stream.rs
new file mode 100644
index 0000000..ca93421
--- /dev/null
+++ b/jabber/src/jabber_stream/bound_stream.rs
@@ -0,0 +1,153 @@
+use std::pin::pin;
+use std::pin::Pin;
+use std::sync::Arc;
+
+use futures::{sink, stream, Sink, Stream};
+use peanuts::{Reader, Writer};
+use pin_project::pin_project;
+use stanza::client::Stanza;
+use tokio::io::{AsyncRead, AsyncWrite, ReadHalf, WriteHalf};
+use tokio::sync::Mutex;
+
+use crate::Error;
+
+use super::JabberStream;
+
+#[pin_project]
+pub struct BoundJabberStream<R, W, S>
+where
+ R: Stream,
+ W: Sink<Stanza>,
+ S: AsyncWrite + AsyncRead + Unpin + Send,
+{
+ reader: Arc<Mutex<Option<Reader<ReadHalf<S>>>>>,
+ writer: Arc<Mutex<Option<Writer<WriteHalf<S>>>>>,
+ stream: R,
+ sink: W,
+}
+
+impl<R, W, S> BoundJabberStream<R, W, S>
+where
+ R: Stream,
+ W: Sink<Stanza>,
+ S: AsyncWrite + AsyncRead + Unpin + Send,
+{
+ // TODO: look into biased mutex, to close stream ASAP
+ pub async fn close_stream(self) -> Result<JabberStream<S>, Error> {
+ if let Some(reader) = self.reader.lock().await.take() {
+ if let Some(writer) = self.writer.lock().await.take() {
+ // TODO: writer </stream:stream>
+ return Ok(JabberStream { reader, writer });
+ }
+ }
+ return Err(Error::StreamClosed);
+ }
+}
+
+pub trait JabberStreamTrait: AsyncWrite + AsyncRead + Unpin + Send {}
+
+impl<R, W, S> Sink<Stanza> for BoundJabberStream<R, W, S>
+where
+ R: Stream,
+ W: Sink<Stanza> + Unpin,
+ S: AsyncWrite + AsyncRead + Unpin + Send,
+{
+ type Error = <W as Sink<Stanza>>::Error;
+
+ fn poll_ready(
+ self: std::pin::Pin<&mut Self>,
+ cx: &mut std::task::Context<'_>,
+ ) -> std::task::Poll<Result<(), Self::Error>> {
+ let this = self.project();
+ pin!(this.sink).poll_ready(cx)
+ }
+
+ fn start_send(self: std::pin::Pin<&mut Self>, item: Stanza) -> Result<(), Self::Error> {
+ let this = self.project();
+ pin!(this.sink).start_send(item)
+ }
+
+ fn poll_flush(
+ self: std::pin::Pin<&mut Self>,
+ cx: &mut std::task::Context<'_>,
+ ) -> std::task::Poll<Result<(), Self::Error>> {
+ let this = self.project();
+ pin!(this.sink).poll_flush(cx)
+ }
+
+ fn poll_close(
+ self: std::pin::Pin<&mut Self>,
+ cx: &mut std::task::Context<'_>,
+ ) -> std::task::Poll<Result<(), Self::Error>> {
+ let this = self.project();
+ pin!(this.sink).poll_close(cx)
+ }
+}
+
+impl<R, W, S> Stream for BoundJabberStream<R, W, S>
+where
+ R: Stream + Unpin,
+ W: Sink<Stanza>,
+ S: AsyncWrite + AsyncRead + Unpin + Send,
+{
+ type Item = <R as Stream>::Item;
+
+ fn poll_next(
+ self: Pin<&mut Self>,
+ cx: &mut std::task::Context<'_>,
+ ) -> std::task::Poll<Option<Self::Item>> {
+ let this = self.project();
+ pin!(this.stream).poll_next(cx)
+ }
+}
+
+impl<S> JabberStream<S>
+where
+ S: AsyncWrite + AsyncRead + Unpin + Send,
+{
+ pub fn to_bound_jabber(self) -> BoundJabberStream<impl Stream, impl Sink<Stanza>, S> {
+ let reader = Arc::new(Mutex::new(Some(self.reader)));
+ let writer = Arc::new(Mutex::new(Some(self.writer)));
+ let sink = sink::unfold(writer.clone(), |writer, s: Stanza| async move {
+ write(writer, s).await
+ });
+ let stream = stream::unfold(reader.clone(), |reader| async { read(reader).await });
+ BoundJabberStream {
+ sink,
+ stream,
+ writer,
+ reader,
+ }
+ }
+}
+
+pub async fn write<W: AsyncWrite + Unpin + Send>(
+ writer: Arc<Mutex<Option<Writer<WriteHalf<W>>>>>,
+ stanza: Stanza,
+) -> Result<Arc<Mutex<Option<Writer<WriteHalf<W>>>>>, Error> {
+ {
+ if let Some(writer) = writer.lock().await.as_mut() {
+ writer.write(&stanza).await?;
+ } else {
+ return Err(Error::StreamClosed);
+ }
+ }
+ Ok(writer)
+}
+
+pub async fn read<R: AsyncRead + Unpin + Send>(
+ reader: Arc<Mutex<Option<Reader<ReadHalf<R>>>>>,
+) -> Option<(
+ Result<Stanza, Error>,
+ Arc<Mutex<Option<Reader<ReadHalf<R>>>>>,
+)> {
+ let stanza: Result<Stanza, Error>;
+ {
+ if let Some(reader) = reader.lock().await.as_mut() {
+ stanza = reader.read().await.map_err(|e| e.into());
+ } else {
+ stanza = Err(Error::StreamClosed)
+ };
+ }
+ Some((stanza, reader))
+}