aboutsummaryrefslogtreecommitdiffstats
path: root/jabber/src/jabber_stream.rs
diff options
context:
space:
mode:
authorLibravatar cel 🌸 <cel@bunny.garden>2025-01-12 21:19:07 +0000
committerLibravatar cel 🌸 <cel@bunny.garden>2025-01-12 21:19:07 +0000
commite6c97ab82880ad4cd12b05bc1c8f2a0a3413735c (patch)
tree372426b3286bd9dca98b328536153df61cf8a74c /jabber/src/jabber_stream.rs
parent0e5f09b2bd05690f3d28f7076629031fcc2cc6e6 (diff)
downloadluz-e6c97ab82880ad4cd12b05bc1c8f2a0a3413735c.tar.gz
luz-e6c97ab82880ad4cd12b05bc1c8f2a0a3413735c.tar.bz2
luz-e6c97ab82880ad4cd12b05bc1c8f2a0a3413735c.zip
implement stream splitting and closing
Diffstat (limited to 'jabber/src/jabber_stream.rs')
-rw-r--r--jabber/src/jabber_stream.rs118
1 files changed, 110 insertions, 8 deletions
diff --git a/jabber/src/jabber_stream.rs b/jabber/src/jabber_stream.rs
index 89890a8..384e6e4 100644
--- a/jabber/src/jabber_stream.rs
+++ b/jabber/src/jabber_stream.rs
@@ -26,8 +26,103 @@ pub mod bound_stream;
// open stream (streams started)
pub struct JabberStream<S> {
- reader: Reader<ReadHalf<S>>,
- pub(crate) writer: Writer<WriteHalf<S>>,
+ reader: JabberReader<S>,
+ writer: JabberWriter<S>,
+}
+
+impl<S> JabberStream<S> {
+ fn split(self) -> (JabberReader<S>, JabberWriter<S>) {
+ let reader = self.reader;
+ let writer = self.writer;
+ (reader, writer)
+ }
+}
+
+pub struct JabberReader<S>(Reader<ReadHalf<S>>);
+
+impl<S> JabberReader<S> {
+ // TODO: consider taking a readhalf and creating peanuts::Reader here, only one inner
+ fn new(reader: Reader<ReadHalf<S>>) -> Self {
+ Self(reader)
+ }
+
+ fn unsplit(self, writer: JabberWriter<S>) -> JabberStream<S> {
+ JabberStream {
+ reader: self,
+ writer,
+ }
+ }
+
+ fn into_inner(self) -> Reader<ReadHalf<S>> {
+ self.0
+ }
+}
+
+impl<S> JabberReader<S>
+where
+ S: AsyncRead + Unpin,
+{
+ pub async fn try_close(&mut self) -> Result<()> {
+ self.read_end_tag().await?;
+ Ok(())
+ }
+}
+
+impl<S> std::ops::Deref for JabberReader<S> {
+ type Target = Reader<ReadHalf<S>>;
+
+ fn deref(&self) -> &Self::Target {
+ &self.0
+ }
+}
+
+impl<S> std::ops::DerefMut for JabberReader<S> {
+ fn deref_mut(&mut self) -> &mut Self::Target {
+ &mut self.0
+ }
+}
+
+pub struct JabberWriter<S>(Writer<WriteHalf<S>>);
+
+impl<S> JabberWriter<S> {
+ fn new(writer: Writer<WriteHalf<S>>) -> Self {
+ Self(writer)
+ }
+
+ fn unsplit(self, reader: JabberReader<S>) -> JabberStream<S> {
+ JabberStream {
+ reader,
+ writer: self,
+ }
+ }
+
+ fn into_inner(self) -> Writer<WriteHalf<S>> {
+ self.0
+ }
+}
+
+impl<S> JabberWriter<S>
+where
+ S: AsyncWrite + Unpin + Send,
+{
+ pub async fn try_close(&mut self) -> Result<()> {
+ self.write_end().await?;
+ Ok(())
+ }
+}
+
+impl<S> std::ops::Deref for JabberWriter<S> {
+ type Target = Writer<WriteHalf<S>>;
+
+ fn deref(&self) -> &Self::Target {
+ &self.0
+ }
+}
+
+impl<S> std::ops::DerefMut for JabberWriter<S> {
+ fn deref_mut(&mut self) -> &mut Self::Target {
+ &mut self.0
+ }
}
impl<S> JabberStream<S>
@@ -119,8 +214,8 @@ where
}
}
}
- let writer = self.writer.into_inner();
- let reader = self.reader.into_inner();
+ let writer = self.writer.into_inner().into_inner();
+ let reader = self.reader.into_inner().into_inner();
let stream = reader.unsplit(writer);
Ok(stream)
}
@@ -223,8 +318,8 @@ where
pub async fn start_stream(connection: S, server: &mut String) -> Result<Self> {
// client to server
let (reader, writer) = tokio::io::split(connection);
- let mut reader = Reader::new(reader);
- let mut writer = Writer::new(writer);
+ let mut reader = JabberReader::new(Reader::new(reader));
+ let mut writer = JabberWriter::new(Writer::new(writer));
// declaration
writer.write_declaration(XML_VERSION).await?;
@@ -262,7 +357,10 @@ where
}
pub fn into_inner(self) -> S {
- self.reader.into_inner().unsplit(self.writer.into_inner())
+ self.reader
+ .into_inner()
+ .into_inner()
+ .unsplit(self.writer.into_inner().into_inner())
}
pub async fn send_stanza(&mut self, stanza: &Stanza) -> Result<()> {
@@ -280,7 +378,11 @@ impl JabberStream<Unencrypted> {
let proceed: Proceed = self.reader.read().await?;
debug!("got proceed: {:?}", proceed);
let connector = TlsConnector::new().unwrap();
- let stream = self.reader.into_inner().unsplit(self.writer.into_inner());
+ let stream = self
+ .reader
+ .into_inner()
+ .into_inner()
+ .unsplit(self.writer.into_inner().into_inner());
if let Ok(tls_stream) = tokio_native_tls::TlsConnector::from(connector)
.connect(domain.as_ref(), stream)
.await