From e6c97ab82880ad4cd12b05bc1c8f2a0a3413735c Mon Sep 17 00:00:00 2001 From: cel 🌸 Date: Sun, 12 Jan 2025 21:19:07 +0000 Subject: implement stream splitting and closing --- jabber/src/jabber_stream.rs | 118 +++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 110 insertions(+), 8 deletions(-) (limited to 'jabber/src/jabber_stream.rs') diff --git a/jabber/src/jabber_stream.rs b/jabber/src/jabber_stream.rs index 89890a8..384e6e4 100644 --- a/jabber/src/jabber_stream.rs +++ b/jabber/src/jabber_stream.rs @@ -26,8 +26,103 @@ pub mod bound_stream; // open stream (streams started) pub struct JabberStream { - reader: Reader>, - pub(crate) writer: Writer>, + reader: JabberReader, + writer: JabberWriter, +} + +impl JabberStream { + fn split(self) -> (JabberReader, JabberWriter) { + let reader = self.reader; + let writer = self.writer; + (reader, writer) + } +} + +pub struct JabberReader(Reader>); + +impl JabberReader { + // TODO: consider taking a readhalf and creating peanuts::Reader here, only one inner + fn new(reader: Reader>) -> Self { + Self(reader) + } + + fn unsplit(self, writer: JabberWriter) -> JabberStream { + JabberStream { + reader: self, + writer, + } + } + + fn into_inner(self) -> Reader> { + self.0 + } +} + +impl JabberReader +where + S: AsyncRead + Unpin, +{ + pub async fn try_close(&mut self) -> Result<()> { + self.read_end_tag().await?; + Ok(()) + } +} + +impl std::ops::Deref for JabberReader { + type Target = Reader>; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl std::ops::DerefMut for JabberReader { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +pub struct JabberWriter(Writer>); + +impl JabberWriter { + fn new(writer: Writer>) -> Self { + Self(writer) + } + + fn unsplit(self, reader: JabberReader) -> JabberStream { + JabberStream { + reader, + writer: self, + } + } + + fn into_inner(self) -> Writer> { + self.0 + } +} + +impl JabberWriter +where + S: AsyncWrite + Unpin + Send, +{ + pub async fn try_close(&mut self) -> Result<()> { + self.write_end().await?; + Ok(()) + } +} + +impl std::ops::Deref for JabberWriter { + type Target = Writer>; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl std::ops::DerefMut for JabberWriter { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } } impl JabberStream @@ -119,8 +214,8 @@ where } } } - let writer = self.writer.into_inner(); - let reader = self.reader.into_inner(); + let writer = self.writer.into_inner().into_inner(); + let reader = self.reader.into_inner().into_inner(); let stream = reader.unsplit(writer); Ok(stream) } @@ -223,8 +318,8 @@ where pub async fn start_stream(connection: S, server: &mut String) -> Result { // client to server let (reader, writer) = tokio::io::split(connection); - let mut reader = Reader::new(reader); - let mut writer = Writer::new(writer); + let mut reader = JabberReader::new(Reader::new(reader)); + let mut writer = JabberWriter::new(Writer::new(writer)); // declaration writer.write_declaration(XML_VERSION).await?; @@ -262,7 +357,10 @@ where } pub fn into_inner(self) -> S { - self.reader.into_inner().unsplit(self.writer.into_inner()) + self.reader + .into_inner() + .into_inner() + .unsplit(self.writer.into_inner().into_inner()) } pub async fn send_stanza(&mut self, stanza: &Stanza) -> Result<()> { @@ -280,7 +378,11 @@ impl JabberStream { let proceed: Proceed = self.reader.read().await?; debug!("got proceed: {:?}", proceed); let connector = TlsConnector::new().unwrap(); - let stream = self.reader.into_inner().unsplit(self.writer.into_inner()); + let stream = self + .reader + .into_inner() + .into_inner() + .unsplit(self.writer.into_inner().into_inner()); if let Ok(tls_stream) = tokio_native_tls::TlsConnector::from(connector) .connect(domain.as_ref(), stream) .await -- cgit