diff options
Diffstat (limited to '')
-rw-r--r-- | src/connection.rs | 162 |
1 files changed, 162 insertions, 0 deletions
diff --git a/src/connection.rs b/src/connection.rs new file mode 100644 index 0000000..24f7745 --- /dev/null +++ b/src/connection.rs @@ -0,0 +1,162 @@ +use std::net::{IpAddr, SocketAddr}; +use std::str; +use std::str::FromStr; + +use tokio::net::TcpStream; +use tokio_native_tls::native_tls::TlsConnector; +// TODO: use rustls +use tokio_native_tls::TlsStream; + +use crate::Jabber; +use crate::JabberError; +use crate::Result; + +pub type Tls = TlsStream<TcpStream>; +pub type Unencrypted = TcpStream; + +pub enum Connection { + Encrypted(Jabber<Tls>), + Unencrypted(Jabber<Unencrypted>), +} + +impl Connection { + pub async fn ensure_tls(self) -> Result<Jabber<Tls>> { + match self { + Connection::Encrypted(j) => Ok(j), + Connection::Unencrypted(j) => Ok(j.starttls().await?), + } + } + + // pub async fn connect_user<J: TryInto<JID>>(jid: J, password: String) -> Result<Self> { + // let server = jid.domainpart.clone(); + // let auth = SASLConfig::with_credentials(None, jid.localpart.clone().unwrap(), password)?; + // println!("auth: {:?}", auth); + // Self::connect(&server, jid.try_into()?, Some(auth)).await + // } + + async fn connect(server: &str) -> Result<Self> { + let sockets = Self::get_sockets(&server).await; + for (socket_addr, tls) in sockets { + match tls { + true => { + if let Ok(connection) = Self::connect_tls(socket_addr, &server).await { + let (readhalf, writehalf) = tokio::io::split(connection); + return Ok(Self::Encrypted(Jabber::new( + readhalf, + writehalf, + None, + None, + server.to_owned(), + ))); + } + } + false => { + if let Ok(connection) = Self::connect_unencrypted(socket_addr).await { + let (readhalf, writehalf) = tokio::io::split(connection); + return Ok(Self::Unencrypted(Jabber::new( + readhalf, + writehalf, + None, + None, + server.to_owned(), + ))); + } + } + } + } + Err(JabberError::Connection) + } + + async fn get_sockets(domain: &str) -> Vec<(SocketAddr, bool)> { + let mut socket_addrs = Vec::new(); + + // if it's a socket/ip then just return that + + // socket + if let Ok(socket_addr) = SocketAddr::from_str(domain) { + match socket_addr.port() { + 5223 => socket_addrs.push((socket_addr, true)), + _ => socket_addrs.push((socket_addr, false)), + } + + return socket_addrs; + } + // ip + if let Ok(ip) = IpAddr::from_str(domain) { + socket_addrs.push((SocketAddr::new(ip, 5222), false)); + socket_addrs.push((SocketAddr::new(ip, 5223), true)); + return socket_addrs; + } + + // otherwise resolve + if let Ok(resolver) = trust_dns_resolver::AsyncResolver::tokio_from_system_conf() { + if let Ok(lookup) = resolver + .srv_lookup(format!("_xmpp-client._tcp.{}", domain)) + .await + { + for srv in lookup { + resolver + .lookup_ip(srv.target().to_owned()) + .await + .map(|ips| { + for ip in ips { + socket_addrs.push((SocketAddr::new(ip, srv.port()), false)) + } + }); + } + } + if let Ok(lookup) = resolver + .srv_lookup(format!("_xmpps-client._tcp.{}", domain)) + .await + { + for srv in lookup { + resolver + .lookup_ip(srv.target().to_owned()) + .await + .map(|ips| { + for ip in ips { + socket_addrs.push((SocketAddr::new(ip, srv.port()), true)) + } + }); + } + } + + // in case cannot connect through SRV records + resolver.lookup_ip(domain).await.map(|ips| { + for ip in ips { + socket_addrs.push((SocketAddr::new(ip, 5222), false)); + socket_addrs.push((SocketAddr::new(ip, 5223), true)); + } + }); + } + socket_addrs + } + + /// establishes a connection to the server + pub async fn connect_tls(socket_addr: SocketAddr, domain_name: &str) -> Result<Tls> { + let socket = TcpStream::connect(socket_addr) + .await + .map_err(|_| JabberError::Connection)?; + let connector = TlsConnector::new().map_err(|_| JabberError::Connection)?; + tokio_native_tls::TlsConnector::from(connector) + .connect(domain_name, socket) + .await + .map_err(|_| JabberError::Connection) + } + + pub async fn connect_unencrypted(socket_addr: SocketAddr) -> Result<Unencrypted> { + TcpStream::connect(socket_addr) + .await + .map_err(|_| JabberError::Connection) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn connect() { + Connection::connect("blos.sm").await.unwrap(); + } +} |