diff options
Diffstat (limited to 'jabber/src/jabber_stream/bound_stream.rs')
-rw-r--r-- | jabber/src/jabber_stream/bound_stream.rs | 181 |
1 files changed, 53 insertions, 128 deletions
diff --git a/jabber/src/jabber_stream/bound_stream.rs b/jabber/src/jabber_stream/bound_stream.rs index 627158a..51a1763 100644 --- a/jabber/src/jabber_stream/bound_stream.rs +++ b/jabber/src/jabber_stream/bound_stream.rs @@ -1,165 +1,90 @@ -use std::future::ready; -use std::pin::pin; -use std::pin::Pin; -use std::sync::Arc; -use std::task::Poll; - -use futures::ready; -use futures::FutureExt; -use futures::{sink, stream, Sink, Stream}; +use std::ops::{Deref, DerefMut}; + use peanuts::{Reader, Writer}; -use pin_project::pin_project; -use stanza::client::Stanza; use tokio::io::{AsyncRead, AsyncWrite, ReadHalf, WriteHalf}; -use tokio::sync::Mutex; -use tokio::task::JoinHandle; use crate::Error; -use super::JabberStream; +use super::{JabberReader, JabberStream, JabberWriter}; + +pub struct BoundJabberStream<S>(JabberStream<S>); -#[pin_project] -pub struct BoundJabberStream<S> +impl<S> Deref for BoundJabberStream<S> where S: AsyncWrite + AsyncRead + Unpin + Send, { - reader: Arc<Mutex<Reader<ReadHalf<S>>>>, - writer: Arc<Mutex<Writer<WriteHalf<S>>>>, - write_handle: Option<JoinHandle<Result<(), Error>>>, - read_handle: Option<JoinHandle<Result<Stanza, Error>>>, + type Target = JabberStream<S>; + + fn deref(&self) -> &Self::Target { + &self.0 + } } -impl<S> BoundJabberStream<S> +impl<S> DerefMut for BoundJabberStream<S> where S: AsyncWrite + AsyncRead + Unpin + Send, { - // TODO: look into biased mutex, to close stream ASAP - // TODO: put into connection - // pub async fn close_stream(self) -> Result<JabberStream<S>, Error> { - // let reader = self.reader.lock().await.into_self(); - // let writer = self.writer.lock().await.into_self(); - // // TODO: writer </stream:stream> - // return Ok(JabberStream { reader, writer }); - // } + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } } -pub trait JabberStreamTrait: AsyncWrite + AsyncRead + Unpin + Send {} +impl<S> BoundJabberStream<S> { + pub fn split(self) -> (BoundJabberReader<S>, BoundJabberWriter<S>) { + let (reader, writer) = self.0.split(); + (BoundJabberReader(reader), BoundJabberWriter(writer)) + } +} -impl<S> Sink<Stanza> for BoundJabberStream<S> -where - S: AsyncWrite + AsyncRead + Unpin + Send + 'static, -{ - type Error = Error; +pub struct BoundJabberReader<S>(JabberReader<S>); - fn poll_ready( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll<Result<(), Self::Error>> { - self.poll_flush(cx) +impl<S> BoundJabberReader<S> { + pub fn unsplit(self, writer: BoundJabberWriter<S>) -> BoundJabberStream<S> { + BoundJabberStream(self.0.unsplit(writer.0)) } +} - fn start_send(self: std::pin::Pin<&mut Self>, item: Stanza) -> Result<(), Self::Error> { - let this = self.project(); - if let Some(_write_handle) = this.write_handle { - panic!("start_send called without poll_ready") - } else { - // TODO: switch to buffer of one rather than thread spawning and joining - *this.write_handle = Some(tokio::spawn(write(this.writer.clone(), item))); - Ok(()) - } - } +impl<S> std::ops::Deref for BoundJabberReader<S> { + type Target = JabberReader<S>; - fn poll_flush( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll<Result<(), Self::Error>> { - let this = self.project(); - Poll::Ready(if let Some(join_handle) = this.write_handle.as_mut() { - match ready!(join_handle.poll_unpin(cx)) { - Ok(state) => { - *this.write_handle = None; - state - } - Err(err) => { - *this.write_handle = None; - Err(err.into()) - } - } - } else { - Ok(()) - }) + fn deref(&self) -> &Self::Target { + &self.0 } +} - fn poll_close( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll<Result<(), Self::Error>> { - self.poll_flush(cx) +impl<S> std::ops::DerefMut for BoundJabberReader<S> { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 } } -impl<S> Stream for BoundJabberStream<S> -where - S: AsyncWrite + AsyncRead + Unpin + Send + 'static, -{ - type Item = Result<Stanza, Error>; - - fn poll_next( - self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll<Option<Self::Item>> { - let this = self.project(); - - loop { - if let Some(join_handle) = this.read_handle.as_mut() { - let stanza = ready!(join_handle.poll_unpin(cx)); - if let Ok(item) = stanza { - *this.read_handle = None; - return Poll::Ready(Some(item)); - } else if let Err(err) = stanza { - return Poll::Ready(Some(Err(err.into()))); - } - } else { - *this.read_handle = Some(tokio::spawn(read(this.reader.clone()))) - } - } +pub struct BoundJabberWriter<S>(JabberWriter<S>); + +impl<S> BoundJabberWriter<S> { + pub fn unsplit(self, reader: BoundJabberReader<S>) -> BoundJabberStream<S> { + BoundJabberStream(self.0.unsplit(reader.0)) } } -impl<S> JabberStream<S> -where - S: AsyncWrite + AsyncRead + Unpin + Send, -{ - pub fn to_bound_jabber(self) -> BoundJabberStream<S> { - let reader = Arc::new(Mutex::new(self.reader)); - let writer = Arc::new(Mutex::new(self.writer)); - BoundJabberStream { - writer, - reader, - write_handle: None, - read_handle: None, - } +impl<S> std::ops::Deref for BoundJabberWriter<S> { + type Target = JabberWriter<S>; + + fn deref(&self) -> &Self::Target { + &self.0 } } -pub async fn write<W: AsyncWrite + Unpin + Send>( - writer: Arc<Mutex<Writer<WriteHalf<W>>>>, - stanza: Stanza, -) -> Result<(), Error> { - { - let mut writer = writer.lock().await; - writer.write(&stanza).await?; +impl<S> std::ops::DerefMut for BoundJabberWriter<S> { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 } - Ok(()) } -pub async fn read<R: AsyncRead + Unpin + Send>( - reader: Arc<Mutex<Reader<ReadHalf<R>>>>, -) -> Result<Stanza, Error> { - let stanza: Result<Stanza, Error>; - { - let mut reader = reader.lock().await; - stanza = reader.read().await.map_err(|e| e.into()); +impl<S> JabberStream<S> +where + S: AsyncWrite + AsyncRead + Unpin + Send, +{ + pub fn to_bound_jabber(self) -> BoundJabberStream<S> { + BoundJabberStream(self) } - stanza } |