aboutsummaryrefslogblamecommitdiffstats
path: root/jabber/src/jabber_stream/bound_stream.rs
blob: ca9342107f0250d097b0f72ed0a74a07fbbc69a5 (plain) (tree)
























































































































































                                                                                            
use std::pin::pin;
use std::pin::Pin;
use std::sync::Arc;

use futures::{sink, stream, Sink, Stream};
use peanuts::{Reader, Writer};
use pin_project::pin_project;
use stanza::client::Stanza;
use tokio::io::{AsyncRead, AsyncWrite, ReadHalf, WriteHalf};
use tokio::sync::Mutex;

use crate::Error;

use super::JabberStream;

/// Stanza-level view of a Jabber connection.
///
/// Polling is delegated to the `stream` and `sink` adapters, while the shared
/// `reader` / `writer` handles are kept so that `close_stream` can take the
/// underlying halves back out.
#[pin_project]
pub struct BoundJabberStream<R, W, S>
where
    R: Stream,
    W: Sink<Stanza>,
    S: AsyncWrite + AsyncRead + Unpin + Send,
{
    /// Shared handle to the read half; the slot becomes `None` once
    /// `close_stream` has taken the reader.
    reader: Arc<Mutex<Option<Reader<ReadHalf<S>>>>>,
    /// Shared handle to the write half; the slot becomes `None` once
    /// `close_stream` has taken the writer.
    writer: Arc<Mutex<Option<Writer<WriteHalf<S>>>>>,
    /// Adapter that the `Stream` implementation forwards `poll_next` to.
    stream: R,
    /// Adapter that the `Sink<Stanza>` implementation forwards every call to.
    sink: W,
}

impl<R, W, S> BoundJabberStream<R, W, S>
where
    R: Stream,
    W: Sink<Stanza>,
    S: AsyncWrite + AsyncRead + Unpin + Send,
{
    /// Consumes the bound stream and returns the underlying [`JabberStream`].
    ///
    /// The stanza adapters are dropped before the reader and writer are taken
    /// out of their shared slots, so any read or write that was still in
    /// flight is abandoned.
    ///
    /// # Errors
    ///
    /// Returns [`Error::StreamClosed`] if the reader or the writer has already
    /// been taken out of its slot.
    // TODO: look into biased mutex, to close stream ASAP
    pub async fn close_stream(self) -> Result<JabberStream<S>, Error> {
        let Self {
            reader: reader_slot,
            writer: writer_slot,
            stream,
            sink,
        } = self;

        // An adapter such as the ones built by `to_bound_jabber` may own a
        // pending future that still holds one of the mutex guards (e.g. a read
        // that returned `Pending`). Nothing polls the adapters again once
        // `self` has been consumed, so they must be dropped first; otherwise
        // the `lock().await` calls below could wait forever.
        drop(stream);
        drop(sink);

        let Some(reader) = reader_slot.lock().await.take() else {
            return Err(Error::StreamClosed);
        };
        let Some(writer) = writer_slot.lock().await.take() else {
            return Err(Error::StreamClosed);
        };
        // TODO: writer </stream:stream>
        Ok(JabberStream { reader, writer })
    }
}

/// Marker trait bundling the transport bounds used throughout this file
/// (`AsyncWrite + AsyncRead + Unpin + Send`). It declares no methods, and no
/// implementors or uses are visible in this file.
pub trait JabberStreamTrait: AsyncWrite + AsyncRead + Unpin + Send {}

impl<R, W, S> Sink<Stanza> for BoundJabberStream<R, W, S>
where
    R: Stream,
    W: Sink<Stanza> + Unpin,
    S: AsyncWrite + AsyncRead + Unpin + Send,
{
    type Error = <W as Sink<Stanza>>::Error;

    fn poll_ready(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), Self::Error>> {
        let this = self.project();
        pin!(this.sink).poll_ready(cx)
    }

    fn start_send(self: std::pin::Pin<&mut Self>, item: Stanza) -> Result<(), Self::Error> {
        let this = self.project();
        pin!(this.sink).start_send(item)
    }

    fn poll_flush(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), Self::Error>> {
        let this = self.project();
        pin!(this.sink).poll_flush(cx)
    }

    fn poll_close(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), Self::Error>> {
        let this = self.project();
        pin!(this.sink).poll_close(cx)
    }
}

/// Forwards `Stream` polling to the inner `stream` adapter.
///
/// The inner stream must be `Unpin` so it can be re-pinned from the plain
/// `&mut R` that the projection hands out.
impl<R, W, S> Stream for BoundJabberStream<R, W, S>
where
    R: Stream + Unpin,
    W: Sink<Stanza>,
    S: AsyncWrite + AsyncRead + Unpin + Send,
{
    type Item = <R as Stream>::Item;

    /// Polls the inner stream for its next item.
    fn poll_next(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        let inner = self.project().stream;
        Pin::new(inner).poll_next(cx)
    }
}

impl<S> JabberStream<S>
where
    S: AsyncWrite + AsyncRead + Unpin + Send,
{
    pub fn to_bound_jabber(self) -> BoundJabberStream<impl Stream, impl Sink<Stanza>, S> {
        let reader = Arc::new(Mutex::new(Some(self.reader)));
        let writer = Arc::new(Mutex::new(Some(self.writer)));
        let sink = sink::unfold(writer.clone(), |writer, s: Stanza| async move {
            write(writer, s).await
        });
        let stream = stream::unfold(reader.clone(), |reader| async { read(reader).await });
        BoundJabberStream {
            sink,
            stream,
            writer,
            reader,
        }
    }
}

pub async fn write<W: AsyncWrite + Unpin + Send>(
    writer: Arc<Mutex<Option<Writer<WriteHalf<W>>>>>,
    stanza: Stanza,
) -> Result<Arc<Mutex<Option<Writer<WriteHalf<W>>>>>, Error> {
    {
        if let Some(writer) = writer.lock().await.as_mut() {
            writer.write(&stanza).await?;
        } else {
            return Err(Error::StreamClosed);
        }
    }
    Ok(writer)
}

pub async fn read<R: AsyncRead + Unpin + Send>(
    reader: Arc<Mutex<Option<Reader<ReadHalf<R>>>>>,
) -> Option<(
    Result<Stanza, Error>,
    Arc<Mutex<Option<Reader<ReadHalf<R>>>>>,
)> {
    let stanza: Result<Stanza, Error>;
    {
        if let Some(reader) = reader.lock().await.as_mut() {
            stanza = reader.read().await.map_err(|e| e.into());
        } else {
            stanza = Err(Error::StreamClosed)
        };
    }
    Some((stanza, reader))
}