diff options
Diffstat (limited to 'src/connection.rs')
-rw-r--r-- | src/connection.rs | 35 |
1 files changed, 21 insertions, 14 deletions
diff --git a/src/connection.rs b/src/connection.rs index 65e9383..9e485d3 100644 --- a/src/connection.rs +++ b/src/connection.rs @@ -1,16 +1,18 @@ use std::net::{IpAddr, SocketAddr}; use std::str; use std::str::FromStr; +use std::sync::Arc; +use rsasl::config::SASLConfig; use tokio::net::TcpStream; use tokio_native_tls::native_tls::TlsConnector; // TODO: use rustls use tokio_native_tls::TlsStream; use tracing::{debug, info, instrument, trace}; -use crate::Error; use crate::Jabber; use crate::Result; +use crate::{Error, JID}; pub type Tls = TlsStream<TcpStream>; pub type Unencrypted = TcpStream; @@ -37,15 +39,20 @@ impl Connection { } } - // pub async fn connect_user<J: TryInto<JID>>(jid: J, password: String) -> Result<Self> { - // let server = jid.domainpart.clone(); - // let auth = SASLConfig::with_credentials(None, jid.localpart.clone().unwrap(), password)?; - // println!("auth: {:?}", auth); - // Self::connect(&server, jid.try_into()?, Some(auth)).await - // } + pub async fn connect_user(jid: impl AsRef<str>, password: String) -> Result<Self> { + let jid: JID = JID::from_str(jid.as_ref())?; + let server = jid.domainpart.clone(); + let auth = SASLConfig::with_credentials(None, jid.localpart.clone().unwrap(), password)?; + println!("auth: {:?}", auth); + Self::connect(&server, Some(jid), Some(auth)).await + } #[instrument] - pub async fn connect(server: &str) -> Result<Self> { + pub async fn connect( + server: &str, + jid: Option<JID>, + auth: Option<Arc<SASLConfig>>, + ) -> Result<Self> { info!("connecting to {}", server); let sockets = Self::get_sockets(&server).await; debug!("discovered sockets: {:?}", sockets); @@ -58,8 +65,8 @@ impl Connection { return Ok(Self::Encrypted(Jabber::new( readhalf, writehalf, - None, - None, + jid, + auth, server.to_owned(), ))); } @@ -71,8 +78,8 @@ impl Connection { return Ok(Self::Unencrypted(Jabber::new( readhalf, writehalf, - None, - None, + jid, + auth, server.to_owned(), ))); } @@ -181,12 +188,12 @@ mod tests { #[test(tokio::test)] async fn connect() { - Connection::connect("blos.sm").await.unwrap(); + Connection::connect("blos.sm", None, None).await.unwrap(); } #[test(tokio::test)] async fn test_tls() { - Connection::connect("blos.sm") + Connection::connect("blos.sm", None, None) .await .unwrap() .ensure_tls() |