use std::{ borrow::Borrow, future::Future, pin::pin, sync::Arc, task::{ready, Poll}, }; use futures::{FutureExt, Sink, SinkExt, Stream, StreamExt}; use jid::ParseError; use rsasl::config::SASLConfig; use stanza::{ client::Stanza, sasl::Mechanisms, stream::{Feature, Features}, }; use tokio::sync::Mutex; use crate::{ connection::{Tls, Unencrypted}, jabber_stream::bound_stream::{BoundJabberReader, BoundJabberStream}, Connection, Error, JabberStream, Result, JID, }; // feed it client stanzas, receive client stanzas pub struct JabberClient { connection: Option>, jid: JID, // TODO: have reconnection be handled by another part, so creds don't need to be stored in object password: Arc, server: String, } impl JabberClient { pub fn new( jid: impl TryInto, password: impl ToString, ) -> Result { let jid = jid.try_into()?; let sasl_config = SASLConfig::with_credentials( None, jid.localpart.clone().ok_or(Error::NoLocalpart)?, password.to_string(), )?; Ok(JabberClient { connection: None, jid: jid.clone(), password: sasl_config, server: jid.domainpart, }) } pub fn jid(&self) -> JID { self.jid.clone() } pub async fn connect(&mut self) -> Result<()> { match &self.connection { Some(_) => Ok(()), None => { self.connection = Some( connect_and_login(&mut self.jid, self.password.clone(), &mut self.server) .await?, ); Ok(()) } } } pub(crate) fn into_inner(self) -> Result> { self.connection.ok_or(Error::Disconnected) } // pub async fn send_stanza(&mut self, stanza: &Stanza) -> Result<()> { // match &mut self.connection { // ConnectionState::Disconnected => return Err(Error::Disconnected), // ConnectionState::Connecting(_connecting) => return Err(Error::Connecting), // ConnectionState::Connected(jabber_stream) => { // Ok(jabber_stream.send_stanza(stanza).await?) // } // } // } } pub async fn connect_and_login( jid: &mut JID, auth: Arc, server: &mut String, ) -> Result> { let mut conn_state = Connecting::start(&server).await?; loop { match conn_state { Connecting::InsecureConnectionEstablised(tcp_stream) => { conn_state = Connecting::InsecureStreamStarted( JabberStream::start_stream(tcp_stream, server).await?, ) } Connecting::InsecureStreamStarted(jabber_stream) => { conn_state = Connecting::InsecureGotFeatures(jabber_stream.get_features().await?) } Connecting::InsecureGotFeatures((features, jabber_stream)) => { match features.negotiate().ok_or(Error::Negotiation)? { Feature::StartTls(_start_tls) => { conn_state = Connecting::StartTls(jabber_stream) } // TODO: better error _ => return Err(Error::TlsRequired), } } Connecting::StartTls(jabber_stream) => { conn_state = Connecting::ConnectionEstablished(jabber_stream.starttls(&server).await?) } Connecting::ConnectionEstablished(tls_stream) => { conn_state = Connecting::StreamStarted(JabberStream::start_stream(tls_stream, server).await?) } Connecting::StreamStarted(jabber_stream) => { conn_state = Connecting::GotFeatures(jabber_stream.get_features().await?) } Connecting::GotFeatures((features, jabber_stream)) => { match features.negotiate().ok_or(Error::Negotiation)? { Feature::StartTls(_start_tls) => return Err(Error::AlreadyTls), Feature::Sasl(mechanisms) => { conn_state = Connecting::Sasl(mechanisms, jabber_stream) } Feature::Bind => conn_state = Connecting::Bind(jabber_stream), Feature::Unknown => return Err(Error::Unsupported), } } Connecting::Sasl(mechanisms, jabber_stream) => { conn_state = Connecting::ConnectionEstablished( jabber_stream.sasl(mechanisms, auth.clone()).await?, ) } Connecting::Bind(jabber_stream) => { return Ok(jabber_stream.bind(jid).await?.to_bound_jabber()); } } } } pub enum Connecting { InsecureConnectionEstablised(Unencrypted), InsecureStreamStarted(JabberStream), InsecureGotFeatures((Features, JabberStream)), StartTls(JabberStream), ConnectionEstablished(Tls), StreamStarted(JabberStream), GotFeatures((Features, JabberStream)), Sasl(Mechanisms, JabberStream), Bind(JabberStream), } impl Connecting { pub async fn start(server: &str) -> Result { match Connection::connect(server).await? { Connection::Encrypted(tls_stream) => Ok(Connecting::ConnectionEstablished(tls_stream)), Connection::Unencrypted(tcp_stream) => { Ok(Connecting::InsecureConnectionEstablised(tcp_stream)) } } } } pub enum InsecureConnecting { Disconnected, ConnectionEstablished(Connection), PreStarttls(JabberStream), PreAuthenticated(JabberStream), Authenticated(Tls), PreBound(JabberStream), Bound(JabberStream), } #[cfg(test)] mod tests { use std::{sync::Arc, time::Duration}; use super::JabberClient; use futures::{SinkExt, StreamExt}; use stanza::{ client::{ iq::{Iq, IqType, Query}, Stanza, }, xep_0199::Ping, }; use test_log::test; use tokio::{sync::Mutex, time::sleep}; use tracing::info; #[test(tokio::test)] async fn login() { let mut client = JabberClient::new("test@blos.sm", "slayed").unwrap(); client.connect().await.unwrap(); sleep(Duration::from_secs(5)).await } #[test(tokio::test)] async fn ping_parallel() { let mut client = JabberClient::new("test@blos.sm", "slayed").unwrap(); client.connect().await.unwrap(); sleep(Duration::from_secs(5)).await; let jid = client.jid.clone(); let server = client.server.clone(); let (mut read, mut write) = client.into_inner().unwrap().split(); tokio::join!( async { write .write(&Stanza::Iq(Iq { from: Some(jid.clone()), id: "c2s1".to_string(), to: Some(server.clone().try_into().unwrap()), r#type: IqType::Get, lang: None, query: Some(Query::Ping(Ping)), errors: Vec::new(), })) .await .unwrap(); write .write(&Stanza::Iq(Iq { from: Some(jid.clone()), id: "c2s2".to_string(), to: Some(server.clone().try_into().unwrap()), r#type: IqType::Get, lang: None, query: Some(Query::Ping(Ping)), errors: Vec::new(), })) .await .unwrap(); }, async { for _ in 0..2 { let stanza = read.read::().await.unwrap(); info!("ping reply: {:#?}", stanza); } } ); } }