aboutsummaryrefslogblamecommitdiffstats
path: root/jabber/src/jabber_stream/bound_stream.rs
blob: 627158a3ef2b30167097fef6c5fa615f162932c2 (plain) (tree)
1
2
3
4
5
6
7
8
                       


                   
                    
 

                       





                                                            
                            





                        
                               
     

                                             



                                                           

 
                            
     


                                                         






                                                                          



                                                                     
                                             
     
                                                       
 
                       




                                                   
                           



                                                                                            


                                                          
                                                                                    


                                                                                      






                                                   













                                                                           





                                                   
                           


     
                                       
     
                                                       
 
                                      





                                              













                                                                                 






                                             


                                                          
                           

                   

                               




                                                 
                                             
                   
                        
     

                                             
     
          


                                               

                                            

                                      

                                                           
     
          
 
use std::future::ready;
use std::pin::pin;
use std::pin::Pin;
use std::sync::Arc;
use std::task::Poll;

use futures::ready;
use futures::FutureExt;
use futures::{sink, stream, Sink, Stream};
use peanuts::{Reader, Writer};
use pin_project::pin_project;
use stanza::client::Stanza;
use tokio::io::{AsyncRead, AsyncWrite, ReadHalf, WriteHalf};
use tokio::sync::Mutex;
use tokio::task::JoinHandle;

use crate::Error;

use super::JabberStream;

/// A [`JabberStream`] whose read and write halves are each wrapped in a shared
/// `Arc<Mutex<_>>`, exposed as a [`Stream`] of incoming stanzas and a [`Sink`]
/// for outgoing ones.
///
/// Every read and write runs on a spawned tokio task. At most one task per
/// direction is tracked at a time, via the `*_handle` fields.
#[pin_project]
pub struct BoundJabberStream<S>
where
    S: AsyncWrite + AsyncRead + Unpin + Send,
{
    /// Read half; a clone of this `Arc` is moved into each spawned `read` task.
    reader: Arc<Mutex<Reader<ReadHalf<S>>>>,
    /// Write half; a clone of this `Arc` is moved into each spawned `write` task.
    writer: Arc<Mutex<Writer<WriteHalf<S>>>>,
    /// Task performing the current outgoing write, if one has not yet been joined.
    write_handle: Option<JoinHandle<Result<(), Error>>>,
    /// Task performing the current incoming read, if one has not yet been joined.
    read_handle: Option<JoinHandle<Result<Stanza, Error>>>,
}

/// Inherent methods for [`BoundJabberStream`]; currently none are defined.
impl<S> BoundJabberStream<S>
where
    S: AsyncWrite + AsyncRead + Unpin + Send,
{
    // TODO: look into biased mutex, to close stream ASAP
    // TODO: put into connection
    // NOTE(review): spawned `read`/`write` tasks hold clones of the `reader`/
    // `writer` `Arc`s, so any reclaiming of the halves (as sketched below) would
    // first have to wait for or cancel those tasks — TODO confirm approach.
    // pub async fn close_stream(self) -> Result<JabberStream<S>, Error> {
    //     let reader = self.reader.lock().await.into_self();
    //     let writer = self.writer.lock().await.into_self();
    //     // TODO: writer </stream:stream>
    //     return Ok(JabberStream { reader, writer });
    // }
}

/// Marker trait bundling the transport bounds used throughout this module:
/// async read, async write, `Send`, and `Unpin`.
///
/// NOTE(review): no implementations or blanket impl are visible in this file,
/// so types presumably opt in explicitly — verify against the rest of the crate.
pub trait JabberStreamTrait: AsyncRead + AsyncWrite + Send + Unpin {}

impl<S> Sink<Stanza> for BoundJabberStream<S>
where
    S: AsyncWrite + AsyncRead + Unpin + Send + 'static,
{
    type Error = Error;

    fn poll_ready(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), Self::Error>> {
        self.poll_flush(cx)
    }

    fn start_send(self: std::pin::Pin<&mut Self>, item: Stanza) -> Result<(), Self::Error> {
        let this = self.project();
        if let Some(_write_handle) = this.write_handle {
            panic!("start_send called without poll_ready")
        } else {
            // TODO: switch to buffer of one rather than thread spawning and joining
            *this.write_handle = Some(tokio::spawn(write(this.writer.clone(), item)));
            Ok(())
        }
    }

    fn poll_flush(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), Self::Error>> {
        let this = self.project();
        Poll::Ready(if let Some(join_handle) = this.write_handle.as_mut() {
            match ready!(join_handle.poll_unpin(cx)) {
                Ok(state) => {
                    *this.write_handle = None;
                    state
                }
                Err(err) => {
                    *this.write_handle = None;
                    Err(err.into())
                }
            }
        } else {
            Ok(())
        })
    }

    fn poll_close(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), Self::Error>> {
        self.poll_flush(cx)
    }
}

impl<S> Stream for BoundJabberStream<S>
where
    S: AsyncWrite + AsyncRead + Unpin + Send + 'static,
{
    type Item = Result<Stanza, Error>;

    /// Yields the next incoming stanza.
    ///
    /// If no read is in flight, a `read` task is spawned on the tokio runtime
    /// and polled straight away. A failed read, or a read task that could not
    /// be joined, is yielded as `Some(Err(_))`. This stream never yields `None`,
    /// so the poll after an error starts a fresh read.
    fn poll_next(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        let this = self.project();

        loop {
            match this.read_handle.as_mut() {
                Some(join_handle) => {
                    let joined = ready!(join_handle.poll_unpin(cx));
                    // The task has completed, successfully or not. Clear the handle in
                    // every case: polling a finished `JoinHandle` again panics, which the
                    // previous join-error path did on the next call.
                    *this.read_handle = None;
                    return Poll::Ready(Some(match joined {
                        Ok(item) => item,
                        Err(err) => Err(err.into()),
                    }));
                }
                None => *this.read_handle = Some(tokio::spawn(read(this.reader.clone()))),
            }
        }
    }
}

impl<S> JabberStream<S>
where
    S: AsyncWrite + AsyncRead + Unpin + Send,
{
    /// Consumes this stream and wraps its reader and writer in shared
    /// `Arc<Mutex<_>>`s, producing a [`BoundJabberStream`] with no read or
    /// write task in flight.
    ///
    /// NOTE(review): by Rust API naming conventions a `self`-consuming
    /// conversion is usually named `into_*`. The name is kept so callers keep working.
    pub fn to_bound_jabber(self) -> BoundJabberStream<S> {
        BoundJabberStream {
            reader: Arc::new(Mutex::new(self.reader)),
            writer: Arc::new(Mutex::new(self.writer)),
            write_handle: None,
            read_handle: None,
        }
    }
}

/// Writes a single stanza through the shared writer.
///
/// The writer's lock is held only while that one write is awaited.
///
/// # Errors
///
/// Returns the writer's error, converted into [`Error`], if the write fails.
pub async fn write<W: AsyncWrite + Unpin + Send>(
    writer: Arc<Mutex<Writer<WriteHalf<W>>>>,
    stanza: Stanza,
) -> Result<(), Error> {
    let mut guard = writer.lock().await;
    let outcome = guard.write(&stanza).await;
    // Release the writer before inspecting the result.
    drop(guard);
    outcome?;
    Ok(())
}

pub async fn read<R: AsyncRead + Unpin + Send>(
    reader: Arc<Mutex<Reader<ReadHalf<R>>>>,
) -> Result<Stanza, Error> {
    let stanza: Result<Stanza, Error>;
    {
        let mut reader = reader.lock().await;
        stanza = reader.read().await.map_err(|e| e.into());
    }
    stanza
}