From 6385e43e8ca467e53c6a705a932016c5af75c3a2 Mon Sep 17 00:00:00 2001 From: cel 🌸 Date: Sun, 22 Dec 2024 18:58:28 +0000 Subject: implement sink and stream with tokio::spawn --- jabber/src/jabber_stream/bound_stream.rs | 139 +++++++++++++++++-------------- 1 file changed, 75 insertions(+), 64 deletions(-) (limited to 'jabber/src/jabber_stream') diff --git a/jabber/src/jabber_stream/bound_stream.rs b/jabber/src/jabber_stream/bound_stream.rs index ca93421..c0d67b0 100644 --- a/jabber/src/jabber_stream/bound_stream.rs +++ b/jabber/src/jabber_stream/bound_stream.rs @@ -1,70 +1,71 @@ +use std::future::ready; use std::pin::pin; use std::pin::Pin; use std::sync::Arc; +use std::task::Poll; +use futures::ready; +use futures::FutureExt; use futures::{sink, stream, Sink, Stream}; use peanuts::{Reader, Writer}; use pin_project::pin_project; use stanza::client::Stanza; use tokio::io::{AsyncRead, AsyncWrite, ReadHalf, WriteHalf}; use tokio::sync::Mutex; +use tokio::task::JoinHandle; use crate::Error; use super::JabberStream; #[pin_project] -pub struct BoundJabberStream +pub struct BoundJabberStream where - R: Stream, - W: Sink, S: AsyncWrite + AsyncRead + Unpin + Send, { - reader: Arc>>>>, - writer: Arc>>>>, - stream: R, - sink: W, + reader: Arc>>>, + writer: Arc>>>, + write_handle: Option>>, + read_handle: Option>>, } -impl BoundJabberStream +impl BoundJabberStream where - R: Stream, - W: Sink, S: AsyncWrite + AsyncRead + Unpin + Send, { // TODO: look into biased mutex, to close stream ASAP - pub async fn close_stream(self) -> Result, Error> { - if let Some(reader) = self.reader.lock().await.take() { - if let Some(writer) = self.writer.lock().await.take() { - // TODO: writer - return Ok(JabberStream { reader, writer }); - } - } - return Err(Error::StreamClosed); - } + // TODO: put into connection + // pub async fn close_stream(self) -> Result, Error> { + // let reader = self.reader.lock().await.into_self(); + // let writer = self.writer.lock().await.into_self(); + // // TODO: writer + // return Ok(JabberStream { reader, writer }); + // } } pub trait JabberStreamTrait: AsyncWrite + AsyncRead + Unpin + Send {} -impl Sink for BoundJabberStream +impl Sink for BoundJabberStream where - R: Stream, - W: Sink + Unpin, - S: AsyncWrite + AsyncRead + Unpin + Send, + S: AsyncWrite + AsyncRead + Unpin + Send + 'static, { - type Error = >::Error; + type Error = Error; fn poll_ready( self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> std::task::Poll> { - let this = self.project(); - pin!(this.sink).poll_ready(cx) + self.poll_flush(cx) } fn start_send(self: std::pin::Pin<&mut Self>, item: Stanza) -> Result<(), Self::Error> { let this = self.project(); - pin!(this.sink).start_send(item) + if let Some(_write_handle) = this.write_handle { + panic!("start_send called without poll_ready") + } else { + *this.write_handle = Some(tokio::spawn(write(this.writer.clone(), item))); + Ok(()) + } } fn poll_flush( @@ -72,32 +73,55 @@ where cx: &mut std::task::Context<'_>, ) -> std::task::Poll> { let this = self.project(); - pin!(this.sink).poll_flush(cx) + Poll::Ready(if let Some(join_handle) = this.write_handle.as_mut() { + match ready!(join_handle.poll_unpin(cx)) { + Ok(state) => { + *this.write_handle = None; + state + } + Err(err) => { + *this.write_handle = None; + Err(err.into()) + } + } + } else { + Ok(()) + }) } fn poll_close( self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> std::task::Poll> { - let this = self.project(); - pin!(this.sink).poll_close(cx) + self.poll_flush(cx) } } -impl Stream for BoundJabberStream +impl Stream for BoundJabberStream where - R: Stream + Unpin, - W: Sink, - S: AsyncWrite + AsyncRead + Unpin + Send, + S: AsyncWrite + AsyncRead + Unpin + Send + 'static, { - type Item = ::Item; + type Item = Result; fn poll_next( self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> std::task::Poll> { let this = self.project(); - pin!(this.stream).poll_next(cx) + + loop { + if let Some(join_handle) = this.read_handle.as_mut() { + let stanza = ready!(join_handle.poll_unpin(cx)); + if let Ok(item) = stanza { + *this.read_handle = None; + return Poll::Ready(Some(item)); + } else if let Err(err) = stanza { + return Poll::Ready(Some(Err(err.into()))); + } + } else { + *this.read_handle = Some(tokio::spawn(read(this.reader.clone()))) + } + } } } @@ -105,49 +129,36 @@ impl JabberStream where S: AsyncWrite + AsyncRead + Unpin + Send, { - pub fn to_bound_jabber(self) -> BoundJabberStream, S> { - let reader = Arc::new(Mutex::new(Some(self.reader))); - let writer = Arc::new(Mutex::new(Some(self.writer))); - let sink = sink::unfold(writer.clone(), |writer, s: Stanza| async move { - write(writer, s).await - }); - let stream = stream::unfold(reader.clone(), |reader| async { read(reader).await }); + pub fn to_bound_jabber(self) -> BoundJabberStream { + let reader = Arc::new(Mutex::new(self.reader)); + let writer = Arc::new(Mutex::new(self.writer)); BoundJabberStream { - sink, - stream, writer, reader, + write_handle: None, + read_handle: None, } } } pub async fn write( - writer: Arc>>>>, + writer: Arc>>>, stanza: Stanza, -) -> Result>>>>, Error> { +) -> Result<(), Error> { { - if let Some(writer) = writer.lock().await.as_mut() { - writer.write(&stanza).await?; - } else { - return Err(Error::StreamClosed); - } + let mut writer = writer.lock().await; + writer.write(&stanza).await?; } - Ok(writer) + Ok(()) } pub async fn read( - reader: Arc>>>>, -) -> Option<( - Result, - Arc>>>>, -)> { + reader: Arc>>>, +) -> Result { let stanza: Result; { - if let Some(reader) = reader.lock().await.as_mut() { - stanza = reader.read().await.map_err(|e| e.into()); - } else { - stanza = Err(Error::StreamClosed) - }; + let mut reader = reader.lock().await; + stanza = reader.read().await.map_err(|e| e.into()); } - Some((stanza, reader)) + stanza } -- cgit