diff options
Diffstat (limited to 'jabber/src/connection.rs')
-rw-r--r-- | jabber/src/connection.rs | 184 |
1 files changed, 184 insertions, 0 deletions
diff --git a/jabber/src/connection.rs b/jabber/src/connection.rs new file mode 100644 index 0000000..bc5a282 --- /dev/null +++ b/jabber/src/connection.rs @@ -0,0 +1,184 @@ +use std::net::{IpAddr, SocketAddr}; +use std::str; +use std::str::FromStr; +use std::sync::Arc; + +use rsasl::config::SASLConfig; +use tokio::net::TcpStream; +use tokio_native_tls::native_tls::TlsConnector; +// TODO: use rustls +use tokio_native_tls::TlsStream; +use tracing::{debug, info, instrument, trace}; + +use crate::Result; +use crate::{Error, JID}; + +pub type Tls = TlsStream<TcpStream>; +pub type Unencrypted = TcpStream; + +#[derive(Debug)] +pub enum Connection { + Encrypted(Tls), + Unencrypted(Unencrypted), +} + +impl Connection { + // #[instrument] + /// stream not started + // pub async fn ensure_tls(self) -> Result<Jabber<Tls>> { + // match self { + // Connection::Encrypted(j) => Ok(j), + // Connection::Unencrypted(mut j) => { + // j.start_stream().await?; + // info!("upgrading connection to tls"); + // j.get_features().await?; + // let j = j.starttls().await?; + // Ok(j) + // } + // } + // } + + pub async fn connect_user(jid: impl AsRef<str>) -> Result<Self> { + let jid: JID = JID::from_str(jid.as_ref())?; + let server = jid.domainpart.clone(); + Self::connect(&server).await + } + + #[instrument] + pub async fn connect(server: impl AsRef<str> + std::fmt::Debug) -> Result<Self> { + info!("connecting to {}", server.as_ref()); + let sockets = Self::get_sockets(server.as_ref()).await; + debug!("discovered sockets: {:?}", sockets); + for (socket_addr, tls) in sockets { + match tls { + true => { + if let Ok(connection) = Self::connect_tls(socket_addr, server.as_ref()).await { + info!("connected via encrypted stream to {}", socket_addr); + // let (readhalf, writehalf) = tokio::io::split(connection); + return Ok(Self::Encrypted(connection)); + } + } + false => { + if let Ok(connection) = Self::connect_unencrypted(socket_addr).await { + info!("connected via unencrypted stream to {}", socket_addr); + // let (readhalf, writehalf) = tokio::io::split(connection); + return Ok(Self::Unencrypted(connection)); + } + } + } + } + Err(Error::Connection) + } + + #[instrument] + async fn get_sockets(address: &str) -> Vec<(SocketAddr, bool)> { + let mut socket_addrs = Vec::new(); + + // if it's a socket/ip then just return that + + // socket + trace!("checking if address is a socket address"); + if let Ok(socket_addr) = SocketAddr::from_str(address) { + debug!("{} is a socket address", address); + match socket_addr.port() { + 5223 => socket_addrs.push((socket_addr, true)), + _ => socket_addrs.push((socket_addr, false)), + } + + return socket_addrs; + } + // ip + trace!("checking if address is an ip"); + if let Ok(ip) = IpAddr::from_str(address) { + debug!("{} is an ip", address); + socket_addrs.push((SocketAddr::new(ip, 5222), false)); + socket_addrs.push((SocketAddr::new(ip, 5223), true)); + return socket_addrs; + } + + // otherwise resolve + debug!("resolving {}", address); + if let Ok(resolver) = trust_dns_resolver::AsyncResolver::tokio_from_system_conf() { + if let Ok(lookup) = resolver + .srv_lookup(format!("_xmpp-client._tcp.{}", address)) + .await + { + for srv in lookup { + resolver + .lookup_ip(srv.target().to_owned()) + .await + .map(|ips| { + for ip in ips { + socket_addrs.push((SocketAddr::new(ip, srv.port()), false)) + } + }); + } + } + if let Ok(lookup) = resolver + .srv_lookup(format!("_xmpps-client._tcp.{}", address)) + .await + { + for srv in lookup { + resolver + .lookup_ip(srv.target().to_owned()) + .await + .map(|ips| { + for ip in ips { + socket_addrs.push((SocketAddr::new(ip, srv.port()), true)) + } + }); + } + } + + // in case cannot connect through SRV records + resolver.lookup_ip(address).await.map(|ips| { + for ip in ips { + socket_addrs.push((SocketAddr::new(ip, 5222), false)); + socket_addrs.push((SocketAddr::new(ip, 5223), true)); + } + }); + } + socket_addrs + } + + /// establishes a connection to the server + #[instrument] + pub async fn connect_tls(socket_addr: SocketAddr, domain_name: &str) -> Result<Tls> { + let socket = TcpStream::connect(socket_addr) + .await + .map_err(|_| Error::Connection)?; + let connector = TlsConnector::new().map_err(|_| Error::Connection)?; + tokio_native_tls::TlsConnector::from(connector) + .connect(domain_name, socket) + .await + .map_err(|_| Error::Connection) + } + + #[instrument] + pub async fn connect_unencrypted(socket_addr: SocketAddr) -> Result<Unencrypted> { + TcpStream::connect(socket_addr) + .await + .map_err(|_| Error::Connection) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use test_log::test; + + #[test(tokio::test)] + async fn connect() { + Connection::connect("blos.sm").await.unwrap(); + } + + // #[test(tokio::test)] + // async fn test_tls() { + // Connection::connect("blos.sm", None, None) + // .await + // .unwrap() + // .ensure_tls() + // .await + // .unwrap(); + // } +} |