aboutsummaryrefslogtreecommitdiffstats
path: root/jabber/src/client.rs
diff options
context:
space:
mode:
authorLibravatar cel 🌸 <cel@bunny.garden>2025-01-12 21:19:07 +0000
committerLibravatar cel 🌸 <cel@bunny.garden>2025-01-12 21:19:07 +0000
commite6c97ab82880ad4cd12b05bc1c8f2a0a3413735c (patch)
tree372426b3286bd9dca98b328536153df61cf8a74c /jabber/src/client.rs
parent0e5f09b2bd05690f3d28f7076629031fcc2cc6e6 (diff)
downloadluz-e6c97ab82880ad4cd12b05bc1c8f2a0a3413735c.tar.gz
luz-e6c97ab82880ad4cd12b05bc1c8f2a0a3413735c.tar.bz2
luz-e6c97ab82880ad4cd12b05bc1c8f2a0a3413735c.zip
implement stream splitting and closing
Diffstat (limited to 'jabber/src/client.rs')
-rw-r--r--jabber/src/client.rs283
1 files changed, 68 insertions, 215 deletions
diff --git a/jabber/src/client.rs b/jabber/src/client.rs
index 2e59d98..9d32682 100644
--- a/jabber/src/client.rs
+++ b/jabber/src/client.rs
@@ -18,13 +18,13 @@ use tokio::sync::Mutex;
use crate::{
connection::{Tls, Unencrypted},
- jabber_stream::bound_stream::BoundJabberStream,
+ jabber_stream::bound_stream::{BoundJabberReader, BoundJabberStream},
Connection, Error, JabberStream, Result, JID,
};
// feed it client stanzas, receive client stanzas
pub struct JabberClient {
- connection: ConnectionState,
+ connection: Option<BoundJabberStream<Tls>>,
jid: JID,
// TODO: have reconnection be handled by another part, so creds don't need to be stored in object
password: Arc<SASLConfig>,
@@ -43,7 +43,7 @@ impl JabberClient {
password.to_string(),
)?;
Ok(JabberClient {
- connection: ConnectionState::Disconnected,
+ connection: None,
jid: jid.clone(),
password: sasl_config,
server: jid.domainpart,
@@ -56,25 +56,19 @@ impl JabberClient {
pub async fn connect(&mut self) -> Result<()> {
match &self.connection {
- ConnectionState::Disconnected => {
- // TODO: actually set the self.connection as it is connecting, make more asynchronous (mutex while connecting?)
- // perhaps use take_mut?
- self.connection = ConnectionState::Disconnected
- .connect(&mut self.jid, self.password.clone(), &mut self.server)
- .await?;
+ Some(_) => Ok(()),
+ None => {
+ self.connection = Some(
+ connect_and_login(&mut self.jid, self.password.clone(), &mut self.server)
+ .await?,
+ );
Ok(())
}
- ConnectionState::Connecting(_connecting) => Err(Error::AlreadyConnecting),
- ConnectionState::Connected(_jabber_stream) => Ok(()),
}
}
- pub(crate) fn inner(self) -> Result<BoundJabberStream<Tls>> {
- match self.connection {
- ConnectionState::Disconnected => return Err(Error::Disconnected),
- ConnectionState::Connecting(_connecting) => return Err(Error::Connecting),
- ConnectionState::Connected(jabber_stream) => return Ok(jabber_stream),
- }
+ pub(crate) fn into_inner(self) -> Result<BoundJabberStream<Tls>> {
+ self.connection.ok_or(Error::Disconnected)
}
// pub async fn send_stanza(&mut self, stanza: &Stanza) -> Result<()> {
@@ -88,203 +82,59 @@ impl JabberClient {
// }
}
-impl Sink<Stanza> for JabberClient {
- type Error = Error;
-
- fn poll_ready(
- self: std::pin::Pin<&mut Self>,
- cx: &mut std::task::Context<'_>,
- ) -> Poll<std::result::Result<(), Self::Error>> {
- self.get_mut().connection.poll_ready_unpin(cx)
- }
-
- fn start_send(
- self: std::pin::Pin<&mut Self>,
- item: Stanza,
- ) -> std::result::Result<(), Self::Error> {
- self.get_mut().connection.start_send_unpin(item)
- }
-
- fn poll_flush(
- self: std::pin::Pin<&mut Self>,
- cx: &mut std::task::Context<'_>,
- ) -> Poll<std::result::Result<(), Self::Error>> {
- self.get_mut().connection.poll_flush_unpin(cx)
- }
-
- fn poll_close(
- self: std::pin::Pin<&mut Self>,
- cx: &mut std::task::Context<'_>,
- ) -> Poll<std::result::Result<(), Self::Error>> {
- self.get_mut().connection.poll_flush_unpin(cx)
- }
-}
-
-impl Stream for JabberClient {
- type Item = Result<Stanza>;
-
- fn poll_next(
- self: std::pin::Pin<&mut Self>,
- cx: &mut std::task::Context<'_>,
- ) -> Poll<Option<Self::Item>> {
- self.get_mut().connection.poll_next_unpin(cx)
- }
-}
-
-pub enum ConnectionState {
- Disconnected,
- Connecting(Connecting),
- Connected(BoundJabberStream<Tls>),
-}
-
-impl Sink<Stanza> for ConnectionState {
- type Error = Error;
-
- fn poll_ready(
- self: std::pin::Pin<&mut Self>,
- cx: &mut std::task::Context<'_>,
- ) -> Poll<std::result::Result<(), Self::Error>> {
- match self.get_mut() {
- ConnectionState::Disconnected => Poll::Ready(Err(Error::Disconnected)),
- ConnectionState::Connecting(_connecting) => Poll::Pending,
- ConnectionState::Connected(bound_jabber_stream) => {
- bound_jabber_stream.poll_ready_unpin(cx)
+pub async fn connect_and_login(
+ jid: &mut JID,
+ auth: Arc<SASLConfig>,
+ server: &mut String,
+) -> Result<BoundJabberStream<Tls>> {
+ let mut conn_state = Connecting::start(&server).await?;
+ loop {
+ match conn_state {
+ Connecting::InsecureConnectionEstablised(tcp_stream) => {
+ conn_state = Connecting::InsecureStreamStarted(
+ JabberStream::start_stream(tcp_stream, server).await?,
+ )
}
- }
- }
-
- fn start_send(
- self: std::pin::Pin<&mut Self>,
- item: Stanza,
- ) -> std::result::Result<(), Self::Error> {
- match self.get_mut() {
- ConnectionState::Disconnected => Err(Error::Disconnected),
- ConnectionState::Connecting(_connecting) => Err(Error::Connecting),
- ConnectionState::Connected(bound_jabber_stream) => {
- bound_jabber_stream.start_send_unpin(item)
+ Connecting::InsecureStreamStarted(jabber_stream) => {
+ conn_state = Connecting::InsecureGotFeatures(jabber_stream.get_features().await?)
}
- }
- }
-
- fn poll_flush(
- self: std::pin::Pin<&mut Self>,
- cx: &mut std::task::Context<'_>,
- ) -> Poll<std::result::Result<(), Self::Error>> {
- match self.get_mut() {
- ConnectionState::Disconnected => Poll::Ready(Err(Error::Disconnected)),
- ConnectionState::Connecting(_connecting) => Poll::Pending,
- ConnectionState::Connected(bound_jabber_stream) => {
- bound_jabber_stream.poll_flush_unpin(cx)
+ Connecting::InsecureGotFeatures((features, jabber_stream)) => {
+ match features.negotiate().ok_or(Error::Negotiation)? {
+ Feature::StartTls(_start_tls) => {
+ conn_state = Connecting::StartTls(jabber_stream)
+ }
+ // TODO: better error
+ _ => return Err(Error::TlsRequired),
+ }
}
- }
- }
-
- fn poll_close(
- self: std::pin::Pin<&mut Self>,
- cx: &mut std::task::Context<'_>,
- ) -> Poll<std::result::Result<(), Self::Error>> {
- match self.get_mut() {
- ConnectionState::Disconnected => Poll::Ready(Err(Error::Disconnected)),
- ConnectionState::Connecting(_connecting) => Poll::Pending,
- ConnectionState::Connected(bound_jabber_stream) => {
- bound_jabber_stream.poll_close_unpin(cx)
+ Connecting::StartTls(jabber_stream) => {
+ conn_state =
+ Connecting::ConnectionEstablished(jabber_stream.starttls(&server).await?)
}
- }
- }
-}
-
-impl Stream for ConnectionState {
- type Item = Result<Stanza>;
-
- fn poll_next(
- self: std::pin::Pin<&mut Self>,
- cx: &mut std::task::Context<'_>,
- ) -> Poll<Option<Self::Item>> {
- match self.get_mut() {
- ConnectionState::Disconnected => Poll::Ready(Some(Err(Error::Disconnected))),
- ConnectionState::Connecting(_connecting) => Poll::Pending,
- ConnectionState::Connected(bound_jabber_stream) => {
- bound_jabber_stream.poll_next_unpin(cx)
+ Connecting::ConnectionEstablished(tls_stream) => {
+ conn_state =
+ Connecting::StreamStarted(JabberStream::start_stream(tls_stream, server).await?)
}
- }
- }
-}
-
-impl ConnectionState {
- pub async fn connect(
- mut self,
- jid: &mut JID,
- auth: Arc<SASLConfig>,
- server: &mut String,
- ) -> Result<Self> {
- loop {
- match self {
- ConnectionState::Disconnected => {
- self = ConnectionState::Connecting(Connecting::start(&server).await?);
- }
- ConnectionState::Connecting(connecting) => match connecting {
- Connecting::InsecureConnectionEstablised(tcp_stream) => {
- self = ConnectionState::Connecting(Connecting::InsecureStreamStarted(
- JabberStream::start_stream(tcp_stream, server).await?,
- ))
- }
- Connecting::InsecureStreamStarted(jabber_stream) => {
- self = ConnectionState::Connecting(Connecting::InsecureGotFeatures(
- jabber_stream.get_features().await?,
- ))
- }
- Connecting::InsecureGotFeatures((features, jabber_stream)) => {
- match features.negotiate().ok_or(Error::Negotiation)? {
- Feature::StartTls(_start_tls) => {
- self =
- ConnectionState::Connecting(Connecting::StartTls(jabber_stream))
- }
- // TODO: better error
- _ => return Err(Error::TlsRequired),
- }
- }
- Connecting::StartTls(jabber_stream) => {
- self = ConnectionState::Connecting(Connecting::ConnectionEstablished(
- jabber_stream.starttls(&server).await?,
- ))
- }
- Connecting::ConnectionEstablished(tls_stream) => {
- self = ConnectionState::Connecting(Connecting::StreamStarted(
- JabberStream::start_stream(tls_stream, server).await?,
- ))
- }
- Connecting::StreamStarted(jabber_stream) => {
- self = ConnectionState::Connecting(Connecting::GotFeatures(
- jabber_stream.get_features().await?,
- ))
- }
- Connecting::GotFeatures((features, jabber_stream)) => {
- match features.negotiate().ok_or(Error::Negotiation)? {
- Feature::StartTls(_start_tls) => return Err(Error::AlreadyTls),
- Feature::Sasl(mechanisms) => {
- self = ConnectionState::Connecting(Connecting::Sasl(
- mechanisms,
- jabber_stream,
- ))
- }
- Feature::Bind => {
- self = ConnectionState::Connecting(Connecting::Bind(jabber_stream))
- }
- Feature::Unknown => return Err(Error::Unsupported),
- }
- }
- Connecting::Sasl(mechanisms, jabber_stream) => {
- self = ConnectionState::Connecting(Connecting::ConnectionEstablished(
- jabber_stream.sasl(mechanisms, auth.clone()).await?,
- ))
- }
- Connecting::Bind(jabber_stream) => {
- self = ConnectionState::Connected(
- jabber_stream.bind(jid).await?.to_bound_jabber(),
- )
+ Connecting::StreamStarted(jabber_stream) => {
+ conn_state = Connecting::GotFeatures(jabber_stream.get_features().await?)
+ }
+ Connecting::GotFeatures((features, jabber_stream)) => {
+ match features.negotiate().ok_or(Error::Negotiation)? {
+ Feature::StartTls(_start_tls) => return Err(Error::AlreadyTls),
+ Feature::Sasl(mechanisms) => {
+ conn_state = Connecting::Sasl(mechanisms, jabber_stream)
}
- },
- connected => return Ok(connected),
+ Feature::Bind => conn_state = Connecting::Bind(jabber_stream),
+ Feature::Unknown => return Err(Error::Unsupported),
+ }
+ }
+ Connecting::Sasl(mechanisms, jabber_stream) => {
+ conn_state = Connecting::ConnectionEstablished(
+ jabber_stream.sasl(mechanisms, auth.clone()).await?,
+ )
+ }
+ Connecting::Bind(jabber_stream) => {
+ return Ok(jabber_stream.bind(jid).await?.to_bound_jabber());
}
}
}
@@ -354,12 +204,12 @@ mod tests {
sleep(Duration::from_secs(5)).await;
let jid = client.jid.clone();
let server = client.server.clone();
- let (mut write, mut read) = client.split();
+ let (mut read, mut write) = client.into_inner().unwrap().split();
tokio::join!(
async {
write
- .send(Stanza::Iq(Iq {
+ .write(&Stanza::Iq(Iq {
from: Some(jid.clone()),
id: "c2s1".to_string(),
to: Some(server.clone().try_into().unwrap()),
@@ -368,9 +218,10 @@ mod tests {
query: Some(Query::Ping(Ping)),
errors: Vec::new(),
}))
- .await;
+ .await
+ .unwrap();
write
- .send(Stanza::Iq(Iq {
+ .write(&Stanza::Iq(Iq {
from: Some(jid.clone()),
id: "c2s2".to_string(),
to: Some(server.clone().try_into().unwrap()),
@@ -379,11 +230,13 @@ mod tests {
query: Some(Query::Ping(Ping)),
errors: Vec::new(),
}))
- .await;
+ .await
+ .unwrap();
},
async {
- while let Some(stanza) = read.next().await {
- info!("{:#?}", stanza);
+ for _ in 0..2 {
+ let stanza = read.read::<Stanza>().await.unwrap();
+ info!("ping reply: {:#?}", stanza);
}
}
);