diff options
author | 2025-01-12 21:19:07 +0000 | |
---|---|---|
committer | 2025-01-12 21:19:07 +0000 | |
commit | e6c97ab82880ad4cd12b05bc1c8f2a0a3413735c (patch) | |
tree | 372426b3286bd9dca98b328536153df61cf8a74c /jabber/src/client.rs | |
parent | 0e5f09b2bd05690f3d28f7076629031fcc2cc6e6 (diff) | |
download | luz-e6c97ab82880ad4cd12b05bc1c8f2a0a3413735c.tar.gz luz-e6c97ab82880ad4cd12b05bc1c8f2a0a3413735c.tar.bz2 luz-e6c97ab82880ad4cd12b05bc1c8f2a0a3413735c.zip |
implement stream splitting and closing
Diffstat (limited to 'jabber/src/client.rs')
-rw-r--r-- | jabber/src/client.rs | 283 |
1 files changed, 68 insertions, 215 deletions
diff --git a/jabber/src/client.rs b/jabber/src/client.rs index 2e59d98..9d32682 100644 --- a/jabber/src/client.rs +++ b/jabber/src/client.rs @@ -18,13 +18,13 @@ use tokio::sync::Mutex; use crate::{ connection::{Tls, Unencrypted}, - jabber_stream::bound_stream::BoundJabberStream, + jabber_stream::bound_stream::{BoundJabberReader, BoundJabberStream}, Connection, Error, JabberStream, Result, JID, }; // feed it client stanzas, receive client stanzas pub struct JabberClient { - connection: ConnectionState, + connection: Option<BoundJabberStream<Tls>>, jid: JID, // TODO: have reconnection be handled by another part, so creds don't need to be stored in object password: Arc<SASLConfig>, @@ -43,7 +43,7 @@ impl JabberClient { password.to_string(), )?; Ok(JabberClient { - connection: ConnectionState::Disconnected, + connection: None, jid: jid.clone(), password: sasl_config, server: jid.domainpart, @@ -56,25 +56,19 @@ impl JabberClient { pub async fn connect(&mut self) -> Result<()> { match &self.connection { - ConnectionState::Disconnected => { - // TODO: actually set the self.connection as it is connecting, make more asynchronous (mutex while connecting?) - // perhaps use take_mut? - self.connection = ConnectionState::Disconnected - .connect(&mut self.jid, self.password.clone(), &mut self.server) - .await?; + Some(_) => Ok(()), + None => { + self.connection = Some( + connect_and_login(&mut self.jid, self.password.clone(), &mut self.server) + .await?, + ); Ok(()) } - ConnectionState::Connecting(_connecting) => Err(Error::AlreadyConnecting), - ConnectionState::Connected(_jabber_stream) => Ok(()), } } - pub(crate) fn inner(self) -> Result<BoundJabberStream<Tls>> { - match self.connection { - ConnectionState::Disconnected => return Err(Error::Disconnected), - ConnectionState::Connecting(_connecting) => return Err(Error::Connecting), - ConnectionState::Connected(jabber_stream) => return Ok(jabber_stream), - } + pub(crate) fn into_inner(self) -> Result<BoundJabberStream<Tls>> { + self.connection.ok_or(Error::Disconnected) } // pub async fn send_stanza(&mut self, stanza: &Stanza) -> Result<()> { @@ -88,203 +82,59 @@ impl JabberClient { // } } -impl Sink<Stanza> for JabberClient { - type Error = Error; - - fn poll_ready( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> Poll<std::result::Result<(), Self::Error>> { - self.get_mut().connection.poll_ready_unpin(cx) - } - - fn start_send( - self: std::pin::Pin<&mut Self>, - item: Stanza, - ) -> std::result::Result<(), Self::Error> { - self.get_mut().connection.start_send_unpin(item) - } - - fn poll_flush( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> Poll<std::result::Result<(), Self::Error>> { - self.get_mut().connection.poll_flush_unpin(cx) - } - - fn poll_close( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> Poll<std::result::Result<(), Self::Error>> { - self.get_mut().connection.poll_flush_unpin(cx) - } -} - -impl Stream for JabberClient { - type Item = Result<Stanza>; - - fn poll_next( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> Poll<Option<Self::Item>> { - self.get_mut().connection.poll_next_unpin(cx) - } -} - -pub enum ConnectionState { - Disconnected, - Connecting(Connecting), - Connected(BoundJabberStream<Tls>), -} - -impl Sink<Stanza> for ConnectionState { - type Error = Error; - - fn poll_ready( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> Poll<std::result::Result<(), Self::Error>> { - match self.get_mut() { - ConnectionState::Disconnected => Poll::Ready(Err(Error::Disconnected)), - ConnectionState::Connecting(_connecting) => Poll::Pending, - ConnectionState::Connected(bound_jabber_stream) => { - bound_jabber_stream.poll_ready_unpin(cx) +pub async fn connect_and_login( + jid: &mut JID, + auth: Arc<SASLConfig>, + server: &mut String, +) -> Result<BoundJabberStream<Tls>> { + let mut conn_state = Connecting::start(&server).await?; + loop { + match conn_state { + Connecting::InsecureConnectionEstablised(tcp_stream) => { + conn_state = Connecting::InsecureStreamStarted( + JabberStream::start_stream(tcp_stream, server).await?, + ) } - } - } - - fn start_send( - self: std::pin::Pin<&mut Self>, - item: Stanza, - ) -> std::result::Result<(), Self::Error> { - match self.get_mut() { - ConnectionState::Disconnected => Err(Error::Disconnected), - ConnectionState::Connecting(_connecting) => Err(Error::Connecting), - ConnectionState::Connected(bound_jabber_stream) => { - bound_jabber_stream.start_send_unpin(item) + Connecting::InsecureStreamStarted(jabber_stream) => { + conn_state = Connecting::InsecureGotFeatures(jabber_stream.get_features().await?) } - } - } - - fn poll_flush( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> Poll<std::result::Result<(), Self::Error>> { - match self.get_mut() { - ConnectionState::Disconnected => Poll::Ready(Err(Error::Disconnected)), - ConnectionState::Connecting(_connecting) => Poll::Pending, - ConnectionState::Connected(bound_jabber_stream) => { - bound_jabber_stream.poll_flush_unpin(cx) + Connecting::InsecureGotFeatures((features, jabber_stream)) => { + match features.negotiate().ok_or(Error::Negotiation)? { + Feature::StartTls(_start_tls) => { + conn_state = Connecting::StartTls(jabber_stream) + } + // TODO: better error + _ => return Err(Error::TlsRequired), + } } - } - } - - fn poll_close( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> Poll<std::result::Result<(), Self::Error>> { - match self.get_mut() { - ConnectionState::Disconnected => Poll::Ready(Err(Error::Disconnected)), - ConnectionState::Connecting(_connecting) => Poll::Pending, - ConnectionState::Connected(bound_jabber_stream) => { - bound_jabber_stream.poll_close_unpin(cx) + Connecting::StartTls(jabber_stream) => { + conn_state = + Connecting::ConnectionEstablished(jabber_stream.starttls(&server).await?) } - } - } -} - -impl Stream for ConnectionState { - type Item = Result<Stanza>; - - fn poll_next( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> Poll<Option<Self::Item>> { - match self.get_mut() { - ConnectionState::Disconnected => Poll::Ready(Some(Err(Error::Disconnected))), - ConnectionState::Connecting(_connecting) => Poll::Pending, - ConnectionState::Connected(bound_jabber_stream) => { - bound_jabber_stream.poll_next_unpin(cx) + Connecting::ConnectionEstablished(tls_stream) => { + conn_state = + Connecting::StreamStarted(JabberStream::start_stream(tls_stream, server).await?) } - } - } -} - -impl ConnectionState { - pub async fn connect( - mut self, - jid: &mut JID, - auth: Arc<SASLConfig>, - server: &mut String, - ) -> Result<Self> { - loop { - match self { - ConnectionState::Disconnected => { - self = ConnectionState::Connecting(Connecting::start(&server).await?); - } - ConnectionState::Connecting(connecting) => match connecting { - Connecting::InsecureConnectionEstablised(tcp_stream) => { - self = ConnectionState::Connecting(Connecting::InsecureStreamStarted( - JabberStream::start_stream(tcp_stream, server).await?, - )) - } - Connecting::InsecureStreamStarted(jabber_stream) => { - self = ConnectionState::Connecting(Connecting::InsecureGotFeatures( - jabber_stream.get_features().await?, - )) - } - Connecting::InsecureGotFeatures((features, jabber_stream)) => { - match features.negotiate().ok_or(Error::Negotiation)? { - Feature::StartTls(_start_tls) => { - self = - ConnectionState::Connecting(Connecting::StartTls(jabber_stream)) - } - // TODO: better error - _ => return Err(Error::TlsRequired), - } - } - Connecting::StartTls(jabber_stream) => { - self = ConnectionState::Connecting(Connecting::ConnectionEstablished( - jabber_stream.starttls(&server).await?, - )) - } - Connecting::ConnectionEstablished(tls_stream) => { - self = ConnectionState::Connecting(Connecting::StreamStarted( - JabberStream::start_stream(tls_stream, server).await?, - )) - } - Connecting::StreamStarted(jabber_stream) => { - self = ConnectionState::Connecting(Connecting::GotFeatures( - jabber_stream.get_features().await?, - )) - } - Connecting::GotFeatures((features, jabber_stream)) => { - match features.negotiate().ok_or(Error::Negotiation)? { - Feature::StartTls(_start_tls) => return Err(Error::AlreadyTls), - Feature::Sasl(mechanisms) => { - self = ConnectionState::Connecting(Connecting::Sasl( - mechanisms, - jabber_stream, - )) - } - Feature::Bind => { - self = ConnectionState::Connecting(Connecting::Bind(jabber_stream)) - } - Feature::Unknown => return Err(Error::Unsupported), - } - } - Connecting::Sasl(mechanisms, jabber_stream) => { - self = ConnectionState::Connecting(Connecting::ConnectionEstablished( - jabber_stream.sasl(mechanisms, auth.clone()).await?, - )) - } - Connecting::Bind(jabber_stream) => { - self = ConnectionState::Connected( - jabber_stream.bind(jid).await?.to_bound_jabber(), - ) + Connecting::StreamStarted(jabber_stream) => { + conn_state = Connecting::GotFeatures(jabber_stream.get_features().await?) + } + Connecting::GotFeatures((features, jabber_stream)) => { + match features.negotiate().ok_or(Error::Negotiation)? { + Feature::StartTls(_start_tls) => return Err(Error::AlreadyTls), + Feature::Sasl(mechanisms) => { + conn_state = Connecting::Sasl(mechanisms, jabber_stream) } - }, - connected => return Ok(connected), + Feature::Bind => conn_state = Connecting::Bind(jabber_stream), + Feature::Unknown => return Err(Error::Unsupported), + } + } + Connecting::Sasl(mechanisms, jabber_stream) => { + conn_state = Connecting::ConnectionEstablished( + jabber_stream.sasl(mechanisms, auth.clone()).await?, + ) + } + Connecting::Bind(jabber_stream) => { + return Ok(jabber_stream.bind(jid).await?.to_bound_jabber()); } } } @@ -354,12 +204,12 @@ mod tests { sleep(Duration::from_secs(5)).await; let jid = client.jid.clone(); let server = client.server.clone(); - let (mut write, mut read) = client.split(); + let (mut read, mut write) = client.into_inner().unwrap().split(); tokio::join!( async { write - .send(Stanza::Iq(Iq { + .write(&Stanza::Iq(Iq { from: Some(jid.clone()), id: "c2s1".to_string(), to: Some(server.clone().try_into().unwrap()), @@ -368,9 +218,10 @@ mod tests { query: Some(Query::Ping(Ping)), errors: Vec::new(), })) - .await; + .await + .unwrap(); write - .send(Stanza::Iq(Iq { + .write(&Stanza::Iq(Iq { from: Some(jid.clone()), id: "c2s2".to_string(), to: Some(server.clone().try_into().unwrap()), @@ -379,11 +230,13 @@ mod tests { query: Some(Query::Ping(Ping)), errors: Vec::new(), })) - .await; + .await + .unwrap(); }, async { - while let Some(stanza) = read.next().await { - info!("{:#?}", stanza); + for _ in 0..2 { + let stanza = read.read::<Stanza>().await.unwrap(); + info!("ping reply: {:#?}", stanza); } } ); |