use std::net::{IpAddr, SocketAddr}; use std::str; use std::str::FromStr; use tokio::net::TcpStream; use tokio_native_tls::native_tls::TlsConnector; // TODO: use rustls use tokio_native_tls::TlsStream; use tracing::{debug, info, instrument, trace}; use crate::Result; use crate::{Error, JID}; pub type Tls = TlsStream; pub type Unencrypted = TcpStream; #[derive(Debug)] pub enum Connection { Encrypted(Tls), Unencrypted(Unencrypted), } impl Connection { // #[instrument] /// stream not started // pub async fn ensure_tls(self) -> Result> { // match self { // Connection::Encrypted(j) => Ok(j), // Connection::Unencrypted(mut j) => { // j.start_stream().await?; // info!("upgrading connection to tls"); // j.get_features().await?; // let j = j.starttls().await?; // Ok(j) // } // } // } pub async fn connect_user(jid: impl AsRef) -> Result { let jid: JID = JID::from_str(jid.as_ref())?; let server = jid.domainpart.clone(); Self::connect(&server).await } #[instrument] pub async fn connect(server: impl AsRef + std::fmt::Debug) -> Result { info!("connecting to {}", server.as_ref()); let sockets = Self::get_sockets(server.as_ref()).await; debug!("discovered sockets: {:?}", sockets); for (socket_addr, tls) in sockets { match tls { true => { if let Ok(connection) = Self::connect_tls(socket_addr, server.as_ref()).await { info!("connected via encrypted stream to {}", socket_addr); // let (readhalf, writehalf) = tokio::io::split(connection); return Ok(Self::Encrypted(connection)); } } false => { if let Ok(connection) = Self::connect_unencrypted(socket_addr).await { info!("connected via unencrypted stream to {}", socket_addr); // let (readhalf, writehalf) = tokio::io::split(connection); return Ok(Self::Unencrypted(connection)); } } } } Err(Error::Connection) } #[instrument] async fn get_sockets(address: &str) -> Vec<(SocketAddr, bool)> { let mut socket_addrs = Vec::new(); // if it's a socket/ip then just return that // socket trace!("checking if address is a socket address"); if let Ok(socket_addr) = SocketAddr::from_str(address) { debug!("{} is a socket address", address); match socket_addr.port() { 5223 => socket_addrs.push((socket_addr, true)), _ => socket_addrs.push((socket_addr, false)), } return socket_addrs; } // ip trace!("checking if address is an ip"); if let Ok(ip) = IpAddr::from_str(address) { debug!("{} is an ip", address); socket_addrs.push((SocketAddr::new(ip, 5222), false)); socket_addrs.push((SocketAddr::new(ip, 5223), true)); return socket_addrs; } // otherwise resolve debug!("resolving {}", address); if let Ok(resolver) = trust_dns_resolver::AsyncResolver::tokio_from_system_conf() { if let Ok(lookup) = resolver .srv_lookup(format!("_xmpp-client._tcp.{}", address)) .await { for srv in lookup { resolver .lookup_ip(srv.target().to_owned()) .await .map(|ips| { for ip in ips { socket_addrs.push((SocketAddr::new(ip, srv.port()), false)) } }); } } if let Ok(lookup) = resolver .srv_lookup(format!("_xmpps-client._tcp.{}", address)) .await { for srv in lookup { resolver .lookup_ip(srv.target().to_owned()) .await .map(|ips| { for ip in ips { socket_addrs.push((SocketAddr::new(ip, srv.port()), true)) } }); } } // in case cannot connect through SRV records resolver.lookup_ip(address).await.map(|ips| { for ip in ips { socket_addrs.push((SocketAddr::new(ip, 5222), false)); socket_addrs.push((SocketAddr::new(ip, 5223), true)); } }); } socket_addrs } /// establishes a connection to the server #[instrument] pub async fn connect_tls(socket_addr: SocketAddr, domain_name: &str) -> Result { let socket = TcpStream::connect(socket_addr) .await .map_err(|_| Error::Connection)?; let connector = TlsConnector::new().map_err(|_| Error::Connection)?; tokio_native_tls::TlsConnector::from(connector) .connect(domain_name, socket) .await .map_err(|_| Error::Connection) } #[instrument] pub async fn connect_unencrypted(socket_addr: SocketAddr) -> Result { TcpStream::connect(socket_addr) .await .map_err(|_| Error::Connection) } } #[cfg(test)] mod tests { use super::*; use test_log::test; #[test(tokio::test)] async fn connect() { Connection::connect("blos.sm").await.unwrap(); } // #[test(tokio::test)] // async fn test_tls() { // Connection::connect("blos.sm", None, None) // .await // .unwrap() // .ensure_tls() // .await // .unwrap(); // } }