use rsasl::config::SASLConfig;
use stanza::{
    sasl::Mechanisms,
    stream::{Feature, Features},
};

use crate::{
    connection::{Tls, Unencrypted},
    jabber_stream::bound_stream::BoundJabberStream,
    Connection, Error, JabberStream, Result, JID,
};

/// Connects to `server`, upgrades to TLS if needed, authenticates with SASL,
/// performs the bind step and returns the resulting bound stream.
///
/// The login is driven as a state machine over [`Connecting`]. Each loop
/// iteration performs exactly one step and replaces `conn_state` with the next
/// state, until the `Bind` step returns.
///
/// Flow as implemented below:
/// - If [`Connecting::start`] yields an unencrypted connection, a stream is
///   started, its features are read, and the negotiated feature must be
///   STARTTLS. The connection is then upgraded with `starttls`.
/// - On a TLS connection, a stream is started and its features are read. The
///   feature chosen by `Features::negotiate` then decides the next step:
///   - SASL authenticates, after which a new stream is started on the
///     connection returned by `sasl`.
///   - Bind performs the final step.
/// - If `start` already yields an encrypted connection, the STARTTLS steps are
///   skipped.
///
/// `jid` is handed to `bind` as `&mut`. `server` is handed to `Connecting::start`,
/// to `starttls`, and (mutably) to `JabberStream::start_stream`.
/// NOTE(review): confirm whether `bind` / `start_stream` modify `jid` / `server`.
///
/// # Errors
///
/// - [`Error::NoLocalpart`] if `jid` has no localpart (the localpart is used
///   as the SASL credential name).
/// - [`Error::SASL`] if the SASL configuration cannot be built.
/// - [`Error::Negotiation`] if `Features::negotiate` returns `None`.
/// - [`Error::TlsRequired`] if the negotiated feature on an unencrypted stream
///   is anything other than STARTTLS.
/// - [`Error::AlreadyTls`] if STARTTLS is negotiated on an already-encrypted
///   stream.
/// - [`Error::Unsupported`] if the negotiated feature is `Feature::Unknown`.
/// - Any error propagated (via `?`) from connecting, starting a stream,
///   reading features, `starttls`, `sasl` or `bind`.
pub async fn connect_and_login(
    jid: &mut JID,
    password: impl AsRef,
    server: &mut String,
) -> Result> {
    // SASL credentials are built from the JID's localpart and `password`.
    // The first argument (`None`) is presumably rsasl's optional authorization
    // identity. TODO confirm against the rsasl docs.
    let auth = SASLConfig::with_credentials(
        None,
        jid.localpart.clone().ok_or(Error::NoLocalpart)?,
        password.as_ref().to_string(),
    )
    .map_err(|e| Error::SASL(e.into()))?;
    // `&server` is a `&&mut String`; it deref-coerces to the `&str` that
    // `Connecting::start` takes.
    let mut conn_state = Connecting::start(&server).await?;
    // NOTE(review): there is no step limit. After `Sasl` the machine returns to
    // `ConnectionEstablished` -> `StreamStarted` -> `GotFeatures`. A server that
    // keeps negotiating SASL would cause repeated authentication attempts.
    loop {
        match conn_state {
            // ---- Unencrypted path: only STARTTLS is accepted. ----
            Connecting::InsecureConnectionEstablised(tcp_stream) => {
                conn_state = Connecting::InsecureStreamStarted(
                    JabberStream::start_stream(tcp_stream, server).await?,
                )
            }
            Connecting::InsecureStreamStarted(jabber_stream) => {
                conn_state = Connecting::InsecureGotFeatures(jabber_stream.get_features().await?)
            }
            Connecting::InsecureGotFeatures((features, jabber_stream)) => {
                match features.negotiate().ok_or(Error::Negotiation)? {
                    Feature::StartTls(_start_tls) => {
                        conn_state = Connecting::StartTls(jabber_stream)
                    }
                    // TODO: better error
                    _ => return Err(Error::TlsRequired),
                }
            }
            Connecting::StartTls(jabber_stream) => {
                // The upgraded connection re-enters the machine at the
                // encrypted starting state.
                conn_state = Connecting::ConnectionEstablished(jabber_stream.starttls(&server).await?)
            }
            // ---- Encrypted path. ----
            Connecting::ConnectionEstablished(tls_stream) => {
                conn_state = Connecting::StreamStarted(JabberStream::start_stream(tls_stream, server).await?)
            }
            Connecting::StreamStarted(jabber_stream) => {
                conn_state = Connecting::GotFeatures(jabber_stream.get_features().await?)
            }
            Connecting::GotFeatures((features, jabber_stream)) => {
                match features.negotiate().ok_or(Error::Negotiation)? {
                    Feature::StartTls(_start_tls) => return Err(Error::AlreadyTls),
                    Feature::Sasl(mechanisms) => {
                        conn_state = Connecting::Sasl(mechanisms, jabber_stream)
                    }
                    Feature::Bind => conn_state = Connecting::Bind(jabber_stream),
                    Feature::Unknown => return Err(Error::Unsupported),
                }
            }
            Connecting::Sasl(mechanisms, jabber_stream) => {
                // `auth` is created once before the loop, so each use inside
                // the loop takes a clone.
                // `sasl` returns a connection on which a fresh stream is
                // started (back to `ConnectionEstablished`).
                conn_state = Connecting::ConnectionEstablished(
                    jabber_stream.sasl(mechanisms, auth.clone()).await?,
                )
            }
            Connecting::Bind(jabber_stream) => {
                // Final step: the only successful exit from the loop.
                return Ok(jabber_stream.bind(jid).await?.to_bound_jabber());
            }
        }
    }
}

/// States of the login state machine driven by [`connect_and_login`].
///
/// Each variant owns the connection or stream needed for the next step. The
/// `Insecure*` / `StartTls` variants belong to the unencrypted path. The rest
/// belong to the encrypted path.
pub enum Connecting {
    /// Unencrypted connection opened; no stream started yet.
    /// NOTE(review): the variant name is misspelled ("Establised"). It is
    /// left unchanged because renaming a public variant would break callers.
    InsecureConnectionEstablised(Unencrypted),
    /// Stream started over the unencrypted connection; features not read yet.
    InsecureStreamStarted(JabberStream),
    /// Features received on the unencrypted stream, together with that stream.
    InsecureGotFeatures((Features, JabberStream)),
    /// STARTTLS was negotiated; the next step calls `starttls` on this stream.
    StartTls(JabberStream),
    /// Encrypted connection with no stream started on it yet. Reached from
    /// [`Connecting::start`], after `starttls`, and after `sasl`.
    ConnectionEstablished(Tls),
    /// Stream started over the encrypted connection; features not read yet.
    StreamStarted(JabberStream),
    /// Features received on the encrypted stream, together with that stream.
    GotFeatures((Features, JabberStream)),
    /// SASL was negotiated; holds the mechanisms from the features and the stream.
    Sasl(Mechanisms, JabberStream),
    /// Bind was negotiated; the next step calls `bind` and finishes the login.
    Bind(JabberStream),
}

impl Connecting {
    /// Opens the initial connection to `server` via `Connection::connect` and
    /// maps it to the matching starting state:
    /// - an encrypted connection maps to `ConnectionEstablished`;
    /// - an unencrypted one maps to `InsecureConnectionEstablised`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `Connection::connect`.
    pub async fn start(server: &str) -> Result {
        match Connection::connect(server).await? {
            Connection::Encrypted(tls_stream) => Ok(Connecting::ConnectionEstablished(tls_stream)),
            Connection::Unencrypted(tcp_stream) => {
                Ok(Connecting::InsecureConnectionEstablised(tcp_stream))
            }
        }
    }
}

/// An alternative set of login states.
///
/// NOTE(review): not referenced anywhere in the code visible in this file; it
/// looks like an earlier or unused representation of the login flow. Confirm
/// whether it is used elsewhere before relying on or removing it.
pub enum InsecureConnecting {
    Disconnected,
    ConnectionEstablished(Connection),
    PreStarttls(JabberStream),
    PreAuthenticated(JabberStream),
    Authenticated(Tls),
    PreBound(JabberStream),
    Bound(JabberStream),
}

// Integration tests: they connect to the live server `blos.sm` with the
// hard-coded account `test@blos.sm` / "slayed". They therefore need network
// access and a working account on that server.
// NOTE(review): consider marking them `#[ignore]` so offline `cargo test` runs
// do not fail.
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use jid::JID;
    use stanza::{
        client::{
            iq::{Iq, IqType, Query},
            Stanza,
        },
        xep_0199::Ping,
    };
    // `test` below is `test_log::test`, used as a wrapper around `tokio::test`.
    use test_log::test;
    use tokio::time::sleep;
    use tracing::info;

    use super::connect_and_login;

    /// Logs in, then keeps the client alive for 5 seconds before dropping it.
    #[test(tokio::test)]
    async fn login() {
        let mut jid: JID = "test@blos.sm".try_into().unwrap();
        let _client = connect_and_login(&mut jid, "slayed", &mut "blos.sm".to_string())
            .await
            .unwrap();
        sleep(Duration::from_secs(5)).await
    }

    /// Logs in, splits the client into read and write halves, and runs both
    /// halves concurrently:
    /// - the writer sends two XEP-0199 ping IQs (ids `c2s1` and `c2s2`) to
    ///   the server;
    /// - the reader reads and logs two incoming stanzas.
    ///
    /// NOTE(review): the replies are only logged, not checked. The test passes
    /// as long as both writes and both reads succeed.
    #[test(tokio::test)]
    async fn ping_parallel() {
        let mut jid: JID = "test@blos.sm".try_into().unwrap();
        let mut server = "blos.sm".to_string();
        let client = connect_and_login(&mut jid, "slayed", &mut server)
            .await
            .unwrap();
        let (mut read, mut write) = client.split();
        tokio::join!(
            async {
                write
                    .write(&Stanza::Iq(Iq {
                        from: Some(jid.clone()),
                        id: "c2s1".to_string(),
                        to: Some(server.clone().try_into().unwrap()),
                        r#type: IqType::Get,
                        lang: None,
                        query: Some(Query::Ping(Ping)),
                        errors: Vec::new(),
                    }))
                    .await
                    .unwrap();
                write
                    .write(&Stanza::Iq(Iq {
                        from: Some(jid.clone()),
                        id: "c2s2".to_string(),
                        to: Some(server.clone().try_into().unwrap()),
                        r#type: IqType::Get,
                        lang: None,
                        query: Some(Query::Ping(Ping)),
                        errors: Vec::new(),
                    }))
                    .await
                    .unwrap();
            },
            async {
                // Expect one incoming stanza per ping sent above.
                for _ in 0..2 {
                    let stanza = read.read::().await.unwrap();
                    info!("ping reply: {:#?}", stanza);
                }
            }
        );
    }
}