aboutsummaryrefslogtreecommitdiffstats
path: root/jabber/src/jabber_stream/bound_stream.rs
diff options
context:
space:
mode:
Diffstat (limited to 'jabber/src/jabber_stream/bound_stream.rs')
-rw-r--r--jabber/src/jabber_stream/bound_stream.rs139
1 files changed, 75 insertions, 64 deletions
diff --git a/jabber/src/jabber_stream/bound_stream.rs b/jabber/src/jabber_stream/bound_stream.rs
index ca93421..c0d67b0 100644
--- a/jabber/src/jabber_stream/bound_stream.rs
+++ b/jabber/src/jabber_stream/bound_stream.rs
@@ -1,70 +1,71 @@
+use std::future::ready;
use std::pin::pin;
use std::pin::Pin;
use std::sync::Arc;
+use std::task::Poll;
+use futures::ready;
+use futures::FutureExt;
use futures::{sink, stream, Sink, Stream};
use peanuts::{Reader, Writer};
use pin_project::pin_project;
use stanza::client::Stanza;
use tokio::io::{AsyncRead, AsyncWrite, ReadHalf, WriteHalf};
use tokio::sync::Mutex;
+use tokio::task::JoinHandle;
use crate::Error;
use super::JabberStream;
#[pin_project]
-pub struct BoundJabberStream<R, W, S>
+pub struct BoundJabberStream<S>
where
- R: Stream,
- W: Sink<Stanza>,
S: AsyncWrite + AsyncRead + Unpin + Send,
{
- reader: Arc<Mutex<Option<Reader<ReadHalf<S>>>>>,
- writer: Arc<Mutex<Option<Writer<WriteHalf<S>>>>>,
- stream: R,
- sink: W,
+ reader: Arc<Mutex<Reader<ReadHalf<S>>>>,
+ writer: Arc<Mutex<Writer<WriteHalf<S>>>>,
+ write_handle: Option<JoinHandle<Result<(), Error>>>,
+ read_handle: Option<JoinHandle<Result<Stanza, Error>>>,
}
-impl<R, W, S> BoundJabberStream<R, W, S>
+impl<S> BoundJabberStream<S>
where
- R: Stream,
- W: Sink<Stanza>,
S: AsyncWrite + AsyncRead + Unpin + Send,
{
// TODO: look into biased mutex, to close stream ASAP
- pub async fn close_stream(self) -> Result<JabberStream<S>, Error> {
- if let Some(reader) = self.reader.lock().await.take() {
- if let Some(writer) = self.writer.lock().await.take() {
- // TODO: writer </stream:stream>
- return Ok(JabberStream { reader, writer });
- }
- }
- return Err(Error::StreamClosed);
- }
+ // TODO: put into connection
+ // pub async fn close_stream(self) -> Result<JabberStream<S>, Error> {
+ // let reader = self.reader.lock().await.into_self();
+ // let writer = self.writer.lock().await.into_self();
+ // // TODO: writer </stream:stream>
+ // return Ok(JabberStream { reader, writer });
+ // }
}
pub trait JabberStreamTrait: AsyncWrite + AsyncRead + Unpin + Send {}
-impl<R, W, S> Sink<Stanza> for BoundJabberStream<R, W, S>
+impl<S> Sink<Stanza> for BoundJabberStream<S>
where
- R: Stream,
- W: Sink<Stanza> + Unpin,
- S: AsyncWrite + AsyncRead + Unpin + Send,
+ S: AsyncWrite + AsyncRead + Unpin + Send + 'static,
{
- type Error = <W as Sink<Stanza>>::Error;
+ type Error = Error;
fn poll_ready(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), Self::Error>> {
- let this = self.project();
- pin!(this.sink).poll_ready(cx)
+ self.poll_flush(cx)
}
fn start_send(self: std::pin::Pin<&mut Self>, item: Stanza) -> Result<(), Self::Error> {
let this = self.project();
- pin!(this.sink).start_send(item)
+ if let Some(_write_handle) = this.write_handle {
+ panic!("start_send called without poll_ready")
+ } else {
+ *this.write_handle = Some(tokio::spawn(write(this.writer.clone(), item)));
+ Ok(())
+ }
}
fn poll_flush(
@@ -72,32 +73,55 @@ where
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), Self::Error>> {
let this = self.project();
- pin!(this.sink).poll_flush(cx)
+ Poll::Ready(if let Some(join_handle) = this.write_handle.as_mut() {
+ match ready!(join_handle.poll_unpin(cx)) {
+ Ok(state) => {
+ *this.write_handle = None;
+ state
+ }
+ Err(err) => {
+ *this.write_handle = None;
+ Err(err.into())
+ }
+ }
+ } else {
+ Ok(())
+ })
}
fn poll_close(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), Self::Error>> {
- let this = self.project();
- pin!(this.sink).poll_close(cx)
+ self.poll_flush(cx)
}
}
-impl<R, W, S> Stream for BoundJabberStream<R, W, S>
+impl<S> Stream for BoundJabberStream<S>
where
- R: Stream + Unpin,
- W: Sink<Stanza>,
- S: AsyncWrite + AsyncRead + Unpin + Send,
+ S: AsyncWrite + AsyncRead + Unpin + Send + 'static,
{
- type Item = <R as Stream>::Item;
+ type Item = Result<Stanza, Error>;
fn poll_next(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Self::Item>> {
let this = self.project();
- pin!(this.stream).poll_next(cx)
+
+ loop {
+ if let Some(join_handle) = this.read_handle.as_mut() {
+ let stanza = ready!(join_handle.poll_unpin(cx));
+ if let Ok(item) = stanza {
+ *this.read_handle = None;
+ return Poll::Ready(Some(item));
+ } else if let Err(err) = stanza {
+ return Poll::Ready(Some(Err(err.into())));
+ }
+ } else {
+ *this.read_handle = Some(tokio::spawn(read(this.reader.clone())))
+ }
+ }
}
}
@@ -105,49 +129,36 @@ impl<S> JabberStream<S>
where
S: AsyncWrite + AsyncRead + Unpin + Send,
{
- pub fn to_bound_jabber(self) -> BoundJabberStream<impl Stream, impl Sink<Stanza>, S> {
- let reader = Arc::new(Mutex::new(Some(self.reader)));
- let writer = Arc::new(Mutex::new(Some(self.writer)));
- let sink = sink::unfold(writer.clone(), |writer, s: Stanza| async move {
- write(writer, s).await
- });
- let stream = stream::unfold(reader.clone(), |reader| async { read(reader).await });
+ pub fn to_bound_jabber(self) -> BoundJabberStream<S> {
+ let reader = Arc::new(Mutex::new(self.reader));
+ let writer = Arc::new(Mutex::new(self.writer));
BoundJabberStream {
- sink,
- stream,
writer,
reader,
+ write_handle: None,
+ read_handle: None,
}
}
}
pub async fn write<W: AsyncWrite + Unpin + Send>(
- writer: Arc<Mutex<Option<Writer<WriteHalf<W>>>>>,
+ writer: Arc<Mutex<Writer<WriteHalf<W>>>>,
stanza: Stanza,
-) -> Result<Arc<Mutex<Option<Writer<WriteHalf<W>>>>>, Error> {
+) -> Result<(), Error> {
{
- if let Some(writer) = writer.lock().await.as_mut() {
- writer.write(&stanza).await?;
- } else {
- return Err(Error::StreamClosed);
- }
+ let mut writer = writer.lock().await;
+ writer.write(&stanza).await?;
}
- Ok(writer)
+ Ok(())
}
pub async fn read<R: AsyncRead + Unpin + Send>(
- reader: Arc<Mutex<Option<Reader<ReadHalf<R>>>>>,
-) -> Option<(
- Result<Stanza, Error>,
- Arc<Mutex<Option<Reader<ReadHalf<R>>>>>,
-)> {
+ reader: Arc<Mutex<Reader<ReadHalf<R>>>>,
+) -> Result<Stanza, Error> {
let stanza: Result<Stanza, Error>;
{
- if let Some(reader) = reader.lock().await.as_mut() {
- stanza = reader.read().await.map_err(|e| e.into());
- } else {
- stanza = Err(Error::StreamClosed)
- };
+ let mut reader = reader.lock().await;
+ stanza = reader.read().await.map_err(|e| e.into());
}
- Some((stanza, reader))
+ stanza
}