aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorLibravatar cel 🌸 <cel@bunny.garden>2025-01-12 21:19:07 +0000
committerLibravatar cel 🌸 <cel@bunny.garden>2025-01-12 21:19:07 +0000
commite6c97ab82880ad4cd12b05bc1c8f2a0a3413735c (patch)
tree372426b3286bd9dca98b328536153df61cf8a74c
parent0e5f09b2bd05690f3d28f7076629031fcc2cc6e6 (diff)
downloadluz-e6c97ab82880ad4cd12b05bc1c8f2a0a3413735c.tar.gz
luz-e6c97ab82880ad4cd12b05bc1c8f2a0a3413735c.tar.bz2
luz-e6c97ab82880ad4cd12b05bc1c8f2a0a3413735c.zip
implement stream splitting and closing
-rw-r--r--jabber/Cargo.toml12
-rw-r--r--jabber/src/client.rs283
-rw-r--r--jabber/src/jabber_stream.rs118
-rw-r--r--jabber/src/jabber_stream/bound_stream.rs181
-rw-r--r--stanza/src/client/mod.rs1
5 files changed, 242 insertions, 353 deletions
diff --git a/jabber/Cargo.toml b/jabber/Cargo.toml
index 68dddd9..d070838 100644
--- a/jabber/Cargo.toml
+++ b/jabber/Cargo.toml
@@ -12,7 +12,12 @@ async-trait = "0.1.68"
lazy_static = "1.4.0"
nanoid = "0.4.0"
# TODO: remove unneeded features
-rsasl = { version = "2.0.1", path = "../../rsasl", default_features = false, features = ["provider_base64", "plain", "config_builder", "scram-sha-1"] }
+rsasl = { version = "2.0.1", default_features = false, features = [
+ "provider_base64",
+ "plain",
+ "config_builder",
+ "scram-sha-1",
+] }
tokio = { version = "1.28", features = ["full"] }
tokio-native-tls = "0.3.1"
tracing = "0.1.40"
@@ -29,4 +34,7 @@ pin-project = "1.1.7"
[dev-dependencies]
test-log = { version = "0.2", features = ["trace"] }
env_logger = "*"
-tracing-subscriber = {version = "0.3", default-features = false, features = ["env-filter", "fmt"]}
+tracing-subscriber = { version = "0.3", default-features = false, features = [
+ "env-filter",
+ "fmt",
+] }
diff --git a/jabber/src/client.rs b/jabber/src/client.rs
index 2e59d98..9d32682 100644
--- a/jabber/src/client.rs
+++ b/jabber/src/client.rs
@@ -18,13 +18,13 @@ use tokio::sync::Mutex;
use crate::{
connection::{Tls, Unencrypted},
- jabber_stream::bound_stream::BoundJabberStream,
+ jabber_stream::bound_stream::{BoundJabberReader, BoundJabberStream},
Connection, Error, JabberStream, Result, JID,
};
// feed it client stanzas, receive client stanzas
pub struct JabberClient {
- connection: ConnectionState,
+ connection: Option<BoundJabberStream<Tls>>,
jid: JID,
// TODO: have reconnection be handled by another part, so creds don't need to be stored in object
password: Arc<SASLConfig>,
@@ -43,7 +43,7 @@ impl JabberClient {
password.to_string(),
)?;
Ok(JabberClient {
- connection: ConnectionState::Disconnected,
+ connection: None,
jid: jid.clone(),
password: sasl_config,
server: jid.domainpart,
@@ -56,25 +56,19 @@ impl JabberClient {
pub async fn connect(&mut self) -> Result<()> {
match &self.connection {
- ConnectionState::Disconnected => {
- // TODO: actually set the self.connection as it is connecting, make more asynchronous (mutex while connecting?)
- // perhaps use take_mut?
- self.connection = ConnectionState::Disconnected
- .connect(&mut self.jid, self.password.clone(), &mut self.server)
- .await?;
+ Some(_) => Ok(()),
+ None => {
+ self.connection = Some(
+ connect_and_login(&mut self.jid, self.password.clone(), &mut self.server)
+ .await?,
+ );
Ok(())
}
- ConnectionState::Connecting(_connecting) => Err(Error::AlreadyConnecting),
- ConnectionState::Connected(_jabber_stream) => Ok(()),
}
}
- pub(crate) fn inner(self) -> Result<BoundJabberStream<Tls>> {
- match self.connection {
- ConnectionState::Disconnected => return Err(Error::Disconnected),
- ConnectionState::Connecting(_connecting) => return Err(Error::Connecting),
- ConnectionState::Connected(jabber_stream) => return Ok(jabber_stream),
- }
+ pub(crate) fn into_inner(self) -> Result<BoundJabberStream<Tls>> {
+ self.connection.ok_or(Error::Disconnected)
}
// pub async fn send_stanza(&mut self, stanza: &Stanza) -> Result<()> {
@@ -88,203 +82,59 @@ impl JabberClient {
// }
}
-impl Sink<Stanza> for JabberClient {
- type Error = Error;
-
- fn poll_ready(
- self: std::pin::Pin<&mut Self>,
- cx: &mut std::task::Context<'_>,
- ) -> Poll<std::result::Result<(), Self::Error>> {
- self.get_mut().connection.poll_ready_unpin(cx)
- }
-
- fn start_send(
- self: std::pin::Pin<&mut Self>,
- item: Stanza,
- ) -> std::result::Result<(), Self::Error> {
- self.get_mut().connection.start_send_unpin(item)
- }
-
- fn poll_flush(
- self: std::pin::Pin<&mut Self>,
- cx: &mut std::task::Context<'_>,
- ) -> Poll<std::result::Result<(), Self::Error>> {
- self.get_mut().connection.poll_flush_unpin(cx)
- }
-
- fn poll_close(
- self: std::pin::Pin<&mut Self>,
- cx: &mut std::task::Context<'_>,
- ) -> Poll<std::result::Result<(), Self::Error>> {
- self.get_mut().connection.poll_flush_unpin(cx)
- }
-}
-
-impl Stream for JabberClient {
- type Item = Result<Stanza>;
-
- fn poll_next(
- self: std::pin::Pin<&mut Self>,
- cx: &mut std::task::Context<'_>,
- ) -> Poll<Option<Self::Item>> {
- self.get_mut().connection.poll_next_unpin(cx)
- }
-}
-
-pub enum ConnectionState {
- Disconnected,
- Connecting(Connecting),
- Connected(BoundJabberStream<Tls>),
-}
-
-impl Sink<Stanza> for ConnectionState {
- type Error = Error;
-
- fn poll_ready(
- self: std::pin::Pin<&mut Self>,
- cx: &mut std::task::Context<'_>,
- ) -> Poll<std::result::Result<(), Self::Error>> {
- match self.get_mut() {
- ConnectionState::Disconnected => Poll::Ready(Err(Error::Disconnected)),
- ConnectionState::Connecting(_connecting) => Poll::Pending,
- ConnectionState::Connected(bound_jabber_stream) => {
- bound_jabber_stream.poll_ready_unpin(cx)
+pub async fn connect_and_login(
+ jid: &mut JID,
+ auth: Arc<SASLConfig>,
+ server: &mut String,
+) -> Result<BoundJabberStream<Tls>> {
+ let mut conn_state = Connecting::start(&server).await?;
+ loop {
+ match conn_state {
+ Connecting::InsecureConnectionEstablised(tcp_stream) => {
+ conn_state = Connecting::InsecureStreamStarted(
+ JabberStream::start_stream(tcp_stream, server).await?,
+ )
}
- }
- }
-
- fn start_send(
- self: std::pin::Pin<&mut Self>,
- item: Stanza,
- ) -> std::result::Result<(), Self::Error> {
- match self.get_mut() {
- ConnectionState::Disconnected => Err(Error::Disconnected),
- ConnectionState::Connecting(_connecting) => Err(Error::Connecting),
- ConnectionState::Connected(bound_jabber_stream) => {
- bound_jabber_stream.start_send_unpin(item)
+ Connecting::InsecureStreamStarted(jabber_stream) => {
+ conn_state = Connecting::InsecureGotFeatures(jabber_stream.get_features().await?)
}
- }
- }
-
- fn poll_flush(
- self: std::pin::Pin<&mut Self>,
- cx: &mut std::task::Context<'_>,
- ) -> Poll<std::result::Result<(), Self::Error>> {
- match self.get_mut() {
- ConnectionState::Disconnected => Poll::Ready(Err(Error::Disconnected)),
- ConnectionState::Connecting(_connecting) => Poll::Pending,
- ConnectionState::Connected(bound_jabber_stream) => {
- bound_jabber_stream.poll_flush_unpin(cx)
+ Connecting::InsecureGotFeatures((features, jabber_stream)) => {
+ match features.negotiate().ok_or(Error::Negotiation)? {
+ Feature::StartTls(_start_tls) => {
+ conn_state = Connecting::StartTls(jabber_stream)
+ }
+ // TODO: better error
+ _ => return Err(Error::TlsRequired),
+ }
}
- }
- }
-
- fn poll_close(
- self: std::pin::Pin<&mut Self>,
- cx: &mut std::task::Context<'_>,
- ) -> Poll<std::result::Result<(), Self::Error>> {
- match self.get_mut() {
- ConnectionState::Disconnected => Poll::Ready(Err(Error::Disconnected)),
- ConnectionState::Connecting(_connecting) => Poll::Pending,
- ConnectionState::Connected(bound_jabber_stream) => {
- bound_jabber_stream.poll_close_unpin(cx)
+ Connecting::StartTls(jabber_stream) => {
+ conn_state =
+ Connecting::ConnectionEstablished(jabber_stream.starttls(&server).await?)
}
- }
- }
-}
-
-impl Stream for ConnectionState {
- type Item = Result<Stanza>;
-
- fn poll_next(
- self: std::pin::Pin<&mut Self>,
- cx: &mut std::task::Context<'_>,
- ) -> Poll<Option<Self::Item>> {
- match self.get_mut() {
- ConnectionState::Disconnected => Poll::Ready(Some(Err(Error::Disconnected))),
- ConnectionState::Connecting(_connecting) => Poll::Pending,
- ConnectionState::Connected(bound_jabber_stream) => {
- bound_jabber_stream.poll_next_unpin(cx)
+ Connecting::ConnectionEstablished(tls_stream) => {
+ conn_state =
+ Connecting::StreamStarted(JabberStream::start_stream(tls_stream, server).await?)
}
- }
- }
-}
-
-impl ConnectionState {
- pub async fn connect(
- mut self,
- jid: &mut JID,
- auth: Arc<SASLConfig>,
- server: &mut String,
- ) -> Result<Self> {
- loop {
- match self {
- ConnectionState::Disconnected => {
- self = ConnectionState::Connecting(Connecting::start(&server).await?);
- }
- ConnectionState::Connecting(connecting) => match connecting {
- Connecting::InsecureConnectionEstablised(tcp_stream) => {
- self = ConnectionState::Connecting(Connecting::InsecureStreamStarted(
- JabberStream::start_stream(tcp_stream, server).await?,
- ))
- }
- Connecting::InsecureStreamStarted(jabber_stream) => {
- self = ConnectionState::Connecting(Connecting::InsecureGotFeatures(
- jabber_stream.get_features().await?,
- ))
- }
- Connecting::InsecureGotFeatures((features, jabber_stream)) => {
- match features.negotiate().ok_or(Error::Negotiation)? {
- Feature::StartTls(_start_tls) => {
- self =
- ConnectionState::Connecting(Connecting::StartTls(jabber_stream))
- }
- // TODO: better error
- _ => return Err(Error::TlsRequired),
- }
- }
- Connecting::StartTls(jabber_stream) => {
- self = ConnectionState::Connecting(Connecting::ConnectionEstablished(
- jabber_stream.starttls(&server).await?,
- ))
- }
- Connecting::ConnectionEstablished(tls_stream) => {
- self = ConnectionState::Connecting(Connecting::StreamStarted(
- JabberStream::start_stream(tls_stream, server).await?,
- ))
- }
- Connecting::StreamStarted(jabber_stream) => {
- self = ConnectionState::Connecting(Connecting::GotFeatures(
- jabber_stream.get_features().await?,
- ))
- }
- Connecting::GotFeatures((features, jabber_stream)) => {
- match features.negotiate().ok_or(Error::Negotiation)? {
- Feature::StartTls(_start_tls) => return Err(Error::AlreadyTls),
- Feature::Sasl(mechanisms) => {
- self = ConnectionState::Connecting(Connecting::Sasl(
- mechanisms,
- jabber_stream,
- ))
- }
- Feature::Bind => {
- self = ConnectionState::Connecting(Connecting::Bind(jabber_stream))
- }
- Feature::Unknown => return Err(Error::Unsupported),
- }
- }
- Connecting::Sasl(mechanisms, jabber_stream) => {
- self = ConnectionState::Connecting(Connecting::ConnectionEstablished(
- jabber_stream.sasl(mechanisms, auth.clone()).await?,
- ))
- }
- Connecting::Bind(jabber_stream) => {
- self = ConnectionState::Connected(
- jabber_stream.bind(jid).await?.to_bound_jabber(),
- )
+ Connecting::StreamStarted(jabber_stream) => {
+ conn_state = Connecting::GotFeatures(jabber_stream.get_features().await?)
+ }
+ Connecting::GotFeatures((features, jabber_stream)) => {
+ match features.negotiate().ok_or(Error::Negotiation)? {
+ Feature::StartTls(_start_tls) => return Err(Error::AlreadyTls),
+ Feature::Sasl(mechanisms) => {
+ conn_state = Connecting::Sasl(mechanisms, jabber_stream)
}
- },
- connected => return Ok(connected),
+ Feature::Bind => conn_state = Connecting::Bind(jabber_stream),
+ Feature::Unknown => return Err(Error::Unsupported),
+ }
+ }
+ Connecting::Sasl(mechanisms, jabber_stream) => {
+ conn_state = Connecting::ConnectionEstablished(
+ jabber_stream.sasl(mechanisms, auth.clone()).await?,
+ )
+ }
+ Connecting::Bind(jabber_stream) => {
+ return Ok(jabber_stream.bind(jid).await?.to_bound_jabber());
}
}
}
@@ -354,12 +204,12 @@ mod tests {
sleep(Duration::from_secs(5)).await;
let jid = client.jid.clone();
let server = client.server.clone();
- let (mut write, mut read) = client.split();
+ let (mut read, mut write) = client.into_inner().unwrap().split();
tokio::join!(
async {
write
- .send(Stanza::Iq(Iq {
+ .write(&Stanza::Iq(Iq {
from: Some(jid.clone()),
id: "c2s1".to_string(),
to: Some(server.clone().try_into().unwrap()),
@@ -368,9 +218,10 @@ mod tests {
query: Some(Query::Ping(Ping)),
errors: Vec::new(),
}))
- .await;
+ .await
+ .unwrap();
write
- .send(Stanza::Iq(Iq {
+ .write(&Stanza::Iq(Iq {
from: Some(jid.clone()),
id: "c2s2".to_string(),
to: Some(server.clone().try_into().unwrap()),
@@ -379,11 +230,13 @@ mod tests {
query: Some(Query::Ping(Ping)),
errors: Vec::new(),
}))
- .await;
+ .await
+ .unwrap();
},
async {
- while let Some(stanza) = read.next().await {
- info!("{:#?}", stanza);
+ for _ in 0..2 {
+ let stanza = read.read::<Stanza>().await.unwrap();
+ info!("ping reply: {:#?}", stanza);
}
}
);
diff --git a/jabber/src/jabber_stream.rs b/jabber/src/jabber_stream.rs
index 89890a8..384e6e4 100644
--- a/jabber/src/jabber_stream.rs
+++ b/jabber/src/jabber_stream.rs
@@ -26,8 +26,103 @@ pub mod bound_stream;
// open stream (streams started)
pub struct JabberStream<S> {
- reader: Reader<ReadHalf<S>>,
- pub(crate) writer: Writer<WriteHalf<S>>,
+ reader: JabberReader<S>,
+ writer: JabberWriter<S>,
+}
+
+impl<S> JabberStream<S> {
+ fn split(self) -> (JabberReader<S>, JabberWriter<S>) {
+ let reader = self.reader;
+ let writer = self.writer;
+ (reader, writer)
+ }
+}
+
+pub struct JabberReader<S>(Reader<ReadHalf<S>>);
+
+impl<S> JabberReader<S> {
+ // TODO: consider taking a readhalf and creating peanuts::Reader here, only one inner
+ fn new(reader: Reader<ReadHalf<S>>) -> Self {
+ Self(reader)
+ }
+
+ fn unsplit(self, writer: JabberWriter<S>) -> JabberStream<S> {
+ JabberStream {
+ reader: self,
+ writer,
+ }
+ }
+
+ fn into_inner(self) -> Reader<ReadHalf<S>> {
+ self.0
+ }
+}
+
+impl<S> JabberReader<S>
+where
+ S: AsyncRead + Unpin,
+{
+ pub async fn try_close(&mut self) -> Result<()> {
+ self.read_end_tag().await?;
+ Ok(())
+ }
+}
+
+impl<S> std::ops::Deref for JabberReader<S> {
+ type Target = Reader<ReadHalf<S>>;
+
+ fn deref(&self) -> &Self::Target {
+ &self.0
+ }
+}
+
+impl<S> std::ops::DerefMut for JabberReader<S> {
+ fn deref_mut(&mut self) -> &mut Self::Target {
+ &mut self.0
+ }
+}
+
+pub struct JabberWriter<S>(Writer<WriteHalf<S>>);
+
+impl<S> JabberWriter<S> {
+ fn new(writer: Writer<WriteHalf<S>>) -> Self {
+ Self(writer)
+ }
+
+ fn unsplit(self, reader: JabberReader<S>) -> JabberStream<S> {
+ JabberStream {
+ reader,
+ writer: self,
+ }
+ }
+
+ fn into_inner(self) -> Writer<WriteHalf<S>> {
+ self.0
+ }
+}
+
+impl<S> JabberWriter<S>
+where
+ S: AsyncWrite + Unpin + Send,
+{
+ pub async fn try_close(&mut self) -> Result<()> {
+ self.write_end().await?;
+ Ok(())
+ }
+}
+
+impl<S> std::ops::Deref for JabberWriter<S> {
+ type Target = Writer<WriteHalf<S>>;
+
+ fn deref(&self) -> &Self::Target {
+ &self.0
+ }
+}
+
+impl<S> std::ops::DerefMut for JabberWriter<S> {
+ fn deref_mut(&mut self) -> &mut Self::Target {
+ &mut self.0
+ }
}
impl<S> JabberStream<S>
@@ -119,8 +214,8 @@ where
}
}
}
- let writer = self.writer.into_inner();
- let reader = self.reader.into_inner();
+ let writer = self.writer.into_inner().into_inner();
+ let reader = self.reader.into_inner().into_inner();
let stream = reader.unsplit(writer);
Ok(stream)
}
@@ -223,8 +318,8 @@ where
pub async fn start_stream(connection: S, server: &mut String) -> Result<Self> {
// client to server
let (reader, writer) = tokio::io::split(connection);
- let mut reader = Reader::new(reader);
- let mut writer = Writer::new(writer);
+ let mut reader = JabberReader::new(Reader::new(reader));
+ let mut writer = JabberWriter::new(Writer::new(writer));
// declaration
writer.write_declaration(XML_VERSION).await?;
@@ -262,7 +357,10 @@ where
}
pub fn into_inner(self) -> S {
- self.reader.into_inner().unsplit(self.writer.into_inner())
+ self.reader
+ .into_inner()
+ .into_inner()
+ .unsplit(self.writer.into_inner().into_inner())
}
pub async fn send_stanza(&mut self, stanza: &Stanza) -> Result<()> {
@@ -280,7 +378,11 @@ impl JabberStream<Unencrypted> {
let proceed: Proceed = self.reader.read().await?;
debug!("got proceed: {:?}", proceed);
let connector = TlsConnector::new().unwrap();
- let stream = self.reader.into_inner().unsplit(self.writer.into_inner());
+ let stream = self
+ .reader
+ .into_inner()
+ .into_inner()
+ .unsplit(self.writer.into_inner().into_inner());
if let Ok(tls_stream) = tokio_native_tls::TlsConnector::from(connector)
.connect(domain.as_ref(), stream)
.await
diff --git a/jabber/src/jabber_stream/bound_stream.rs b/jabber/src/jabber_stream/bound_stream.rs
index 627158a..51a1763 100644
--- a/jabber/src/jabber_stream/bound_stream.rs
+++ b/jabber/src/jabber_stream/bound_stream.rs
@@ -1,165 +1,90 @@
-use std::future::ready;
-use std::pin::pin;
-use std::pin::Pin;
-use std::sync::Arc;
-use std::task::Poll;
-
-use futures::ready;
-use futures::FutureExt;
-use futures::{sink, stream, Sink, Stream};
+use std::ops::{Deref, DerefMut};
+
use peanuts::{Reader, Writer};
-use pin_project::pin_project;
-use stanza::client::Stanza;
use tokio::io::{AsyncRead, AsyncWrite, ReadHalf, WriteHalf};
-use tokio::sync::Mutex;
-use tokio::task::JoinHandle;
use crate::Error;
-use super::JabberStream;
+use super::{JabberReader, JabberStream, JabberWriter};
+
+pub struct BoundJabberStream<S>(JabberStream<S>);
-#[pin_project]
-pub struct BoundJabberStream<S>
+impl<S> Deref for BoundJabberStream<S>
where
S: AsyncWrite + AsyncRead + Unpin + Send,
{
- reader: Arc<Mutex<Reader<ReadHalf<S>>>>,
- writer: Arc<Mutex<Writer<WriteHalf<S>>>>,
- write_handle: Option<JoinHandle<Result<(), Error>>>,
- read_handle: Option<JoinHandle<Result<Stanza, Error>>>,
+ type Target = JabberStream<S>;
+
+ fn deref(&self) -> &Self::Target {
+ &self.0
+ }
}
-impl<S> BoundJabberStream<S>
+impl<S> DerefMut for BoundJabberStream<S>
where
S: AsyncWrite + AsyncRead + Unpin + Send,
{
- // TODO: look into biased mutex, to close stream ASAP
- // TODO: put into connection
- // pub async fn close_stream(self) -> Result<JabberStream<S>, Error> {
- // let reader = self.reader.lock().await.into_self();
- // let writer = self.writer.lock().await.into_self();
- // // TODO: writer </stream:stream>
- // return Ok(JabberStream { reader, writer });
- // }
+ fn deref_mut(&mut self) -> &mut Self::Target {
+ &mut self.0
+ }
}
-pub trait JabberStreamTrait: AsyncWrite + AsyncRead + Unpin + Send {}
+impl<S> BoundJabberStream<S> {
+ pub fn split(self) -> (BoundJabberReader<S>, BoundJabberWriter<S>) {
+ let (reader, writer) = self.0.split();
+ (BoundJabberReader(reader), BoundJabberWriter(writer))
+ }
+}
-impl<S> Sink<Stanza> for BoundJabberStream<S>
-where
- S: AsyncWrite + AsyncRead + Unpin + Send + 'static,
-{
- type Error = Error;
+pub struct BoundJabberReader<S>(JabberReader<S>);
- fn poll_ready(
- self: std::pin::Pin<&mut Self>,
- cx: &mut std::task::Context<'_>,
- ) -> std::task::Poll<Result<(), Self::Error>> {
- self.poll_flush(cx)
+impl<S> BoundJabberReader<S> {
+ pub fn unsplit(self, writer: BoundJabberWriter<S>) -> BoundJabberStream<S> {
+ BoundJabberStream(self.0.unsplit(writer.0))
}
+}
- fn start_send(self: std::pin::Pin<&mut Self>, item: Stanza) -> Result<(), Self::Error> {
- let this = self.project();
- if let Some(_write_handle) = this.write_handle {
- panic!("start_send called without poll_ready")
- } else {
- // TODO: switch to buffer of one rather than thread spawning and joining
- *this.write_handle = Some(tokio::spawn(write(this.writer.clone(), item)));
- Ok(())
- }
- }
+impl<S> std::ops::Deref for BoundJabberReader<S> {
+ type Target = JabberReader<S>;
- fn poll_flush(
- self: std::pin::Pin<&mut Self>,
- cx: &mut std::task::Context<'_>,
- ) -> std::task::Poll<Result<(), Self::Error>> {
- let this = self.project();
- Poll::Ready(if let Some(join_handle) = this.write_handle.as_mut() {
- match ready!(join_handle.poll_unpin(cx)) {
- Ok(state) => {
- *this.write_handle = None;
- state
- }
- Err(err) => {
- *this.write_handle = None;
- Err(err.into())
- }
- }
- } else {
- Ok(())
- })
+ fn deref(&self) -> &Self::Target {
+ &self.0
}
+}
- fn poll_close(
- self: std::pin::Pin<&mut Self>,
- cx: &mut std::task::Context<'_>,
- ) -> std::task::Poll<Result<(), Self::Error>> {
- self.poll_flush(cx)
+impl<S> std::ops::DerefMut for BoundJabberReader<S> {
+ fn deref_mut(&mut self) -> &mut Self::Target {
+ &mut self.0
}
}
-impl<S> Stream for BoundJabberStream<S>
-where
- S: AsyncWrite + AsyncRead + Unpin + Send + 'static,
-{
- type Item = Result<Stanza, Error>;
-
- fn poll_next(
- self: Pin<&mut Self>,
- cx: &mut std::task::Context<'_>,
- ) -> std::task::Poll<Option<Self::Item>> {
- let this = self.project();
-
- loop {
- if let Some(join_handle) = this.read_handle.as_mut() {
- let stanza = ready!(join_handle.poll_unpin(cx));
- if let Ok(item) = stanza {
- *this.read_handle = None;
- return Poll::Ready(Some(item));
- } else if let Err(err) = stanza {
- return Poll::Ready(Some(Err(err.into())));
- }
- } else {
- *this.read_handle = Some(tokio::spawn(read(this.reader.clone())))
- }
- }
+pub struct BoundJabberWriter<S>(JabberWriter<S>);
+
+impl<S> BoundJabberWriter<S> {
+ pub fn unsplit(self, reader: BoundJabberReader<S>) -> BoundJabberStream<S> {
+ BoundJabberStream(self.0.unsplit(reader.0))
}
}
-impl<S> JabberStream<S>
-where
- S: AsyncWrite + AsyncRead + Unpin + Send,
-{
- pub fn to_bound_jabber(self) -> BoundJabberStream<S> {
- let reader = Arc::new(Mutex::new(self.reader));
- let writer = Arc::new(Mutex::new(self.writer));
- BoundJabberStream {
- writer,
- reader,
- write_handle: None,
- read_handle: None,
- }
+impl<S> std::ops::Deref for BoundJabberWriter<S> {
+ type Target = JabberWriter<S>;
+
+ fn deref(&self) -> &Self::Target {
+ &self.0
}
}
-pub async fn write<W: AsyncWrite + Unpin + Send>(
- writer: Arc<Mutex<Writer<WriteHalf<W>>>>,
- stanza: Stanza,
-) -> Result<(), Error> {
- {
- let mut writer = writer.lock().await;
- writer.write(&stanza).await?;
+impl<S> std::ops::DerefMut for BoundJabberWriter<S> {
+ fn deref_mut(&mut self) -> &mut Self::Target {
+ &mut self.0
}
- Ok(())
}
-pub async fn read<R: AsyncRead + Unpin + Send>(
- reader: Arc<Mutex<Reader<ReadHalf<R>>>>,
-) -> Result<Stanza, Error> {
- let stanza: Result<Stanza, Error>;
- {
- let mut reader = reader.lock().await;
- stanza = reader.read().await.map_err(|e| e.into());
+impl<S> JabberStream<S>
+where
+ S: AsyncWrite + AsyncRead + Unpin + Send,
+{
+ pub fn to_bound_jabber(self) -> BoundJabberStream<S> {
+ BoundJabberStream(self)
}
- stanza
}
diff --git a/stanza/src/client/mod.rs b/stanza/src/client/mod.rs
index e9c336e..11ba616 100644
--- a/stanza/src/client/mod.rs
+++ b/stanza/src/client/mod.rs
@@ -15,6 +15,7 @@ pub mod presence;
pub const XMLNS: &str = "jabber:client";
+/// TODO: End tag
#[derive(Debug)]
pub enum Stanza {
Message(Message),