use std::net::{IpAddr, SocketAddr};
use std::str;
use std::str::FromStr;
use tokio::net::TcpStream;
use tokio_native_tls::native_tls::TlsConnector;
// TODO: use rustls
use tokio_native_tls::TlsStream;
use tracing::{debug, info, instrument, trace};
use crate::Result;
use crate::{Error, JID};
/// A TCP stream wrapped in a native-tls session.
pub type Tls = TlsStream<TcpStream>;
/// A bare TCP stream with no transport encryption.
pub type Unencrypted = TcpStream;

/// An open transport-level connection to a server, as produced by
/// `Connection::connect`.
#[derive(Debug)]
pub enum Connection {
    /// The TCP stream has completed a TLS handshake.
    Encrypted(TlsStream<TcpStream>),
    /// Plain TCP; nothing has been negotiated on top of it yet.
    Unencrypted(TcpStream),
}
impl Connection {
// #[instrument]
/// stream not started
// pub async fn ensure_tls(self) -> Result<Jabber<Tls>> {
// match self {
// Connection::Encrypted(j) => Ok(j),
// Connection::Unencrypted(mut j) => {
// j.start_stream().await?;
// info!("upgrading connection to tls");
// j.get_features().await?;
// let j = j.starttls().await?;
// Ok(j)
// }
// }
// }
pub async fn connect_user(jid: impl AsRef<str>) -> Result<Self> {
let jid: JID = JID::from_str(jid.as_ref())?;
let server = jid.domainpart.clone();
Self::connect(&server).await
}
#[instrument]
pub async fn connect(server: impl AsRef<str> + std::fmt::Debug) -> Result<Self> {
info!("connecting to {}", server.as_ref());
let sockets = Self::get_sockets(server.as_ref()).await;
debug!("discovered sockets: {:?}", sockets);
for (socket_addr, tls) in sockets {
match tls {
true => {
if let Ok(connection) = Self::connect_tls(socket_addr, server.as_ref()).await {
info!("connected via encrypted stream to {}", socket_addr);
// let (readhalf, writehalf) = tokio::io::split(connection);
return Ok(Self::Encrypted(connection));
}
}
false => {
if let Ok(connection) = Self::connect_unencrypted(socket_addr).await {
info!("connected via unencrypted stream to {}", socket_addr);
// let (readhalf, writehalf) = tokio::io::split(connection);
return Ok(Self::Unencrypted(connection));
}
}
}
}
Err(Error::Connection)
}
#[instrument]
async fn get_sockets(address: &str) -> Vec<(SocketAddr, bool)> {
let mut socket_addrs = Vec::new();
// if it's a socket/ip then just return that
// socket
trace!("checking if address is a socket address");
if let Ok(socket_addr) = SocketAddr::from_str(address) {
debug!("{} is a socket address", address);
match socket_addr.port() {
5223 => socket_addrs.push((socket_addr, true)),
_ => socket_addrs.push((socket_addr, false)),
}
return socket_addrs;
}
// ip
trace!("checking if address is an ip");
if let Ok(ip) = IpAddr::from_str(address) {
debug!("{} is an ip", address);
socket_addrs.push((SocketAddr::new(ip, 5222), false));
socket_addrs.push((SocketAddr::new(ip, 5223), true));
return socket_addrs;
}
// otherwise resolve
debug!("resolving {}", address);
if let Ok(resolver) = trust_dns_resolver::AsyncResolver::tokio_from_system_conf() {
if let Ok(lookup) = resolver
.srv_lookup(format!("_xmpp-client._tcp.{}", address))
.await
{
for srv in lookup {
resolver
.lookup_ip(srv.target().to_owned())
.await
.map(|ips| {
for ip in ips {
socket_addrs.push((SocketAddr::new(ip, srv.port()), false))
}
});
}
}
if let Ok(lookup) = resolver
.srv_lookup(format!("_xmpps-client._tcp.{}", address))
.await
{
for srv in lookup {
resolver
.lookup_ip(srv.target().to_owned())
.await
.map(|ips| {
for ip in ips {
socket_addrs.push((SocketAddr::new(ip, srv.port()), true))
}
});
}
}
// in case cannot connect through SRV records
resolver.lookup_ip(address).await.map(|ips| {
for ip in ips {
socket_addrs.push((SocketAddr::new(ip, 5222), false));
socket_addrs.push((SocketAddr::new(ip, 5223), true));
}
});
}
socket_addrs
}
/// establishes a connection to the server
#[instrument]
pub async fn connect_tls(socket_addr: SocketAddr, domain_name: &str) -> Result<Tls> {
let socket = TcpStream::connect(socket_addr)
.await
.map_err(|_| Error::Connection)?;
let connector = TlsConnector::new().map_err(|_| Error::Connection)?;
tokio_native_tls::TlsConnector::from(connector)
.connect(domain_name, socket)
.await
.map_err(|_| Error::Connection)
}
#[instrument]
pub async fn connect_unencrypted(socket_addr: SocketAddr) -> Result<Unencrypted> {
TcpStream::connect(socket_addr)
.await
.map_err(|_| Error::Connection)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use test_log::test;

    /// End-to-end check against a live server. It needs outbound network and
    /// DNS access, so it is skipped by default; run it with
    /// `cargo test -- --ignored`.
    #[test(tokio::test)]
    #[ignore = "requires network access to blos.sm"]
    async fn connect() {
        Connection::connect("blos.sm").await.unwrap();
    }

    /// A literal socket address on the direct-TLS port is used as-is and
    /// flagged as TLS. No DNS is involved.
    #[test(tokio::test)]
    async fn socket_address_on_tls_port_is_direct_tls() {
        let addr: std::net::SocketAddr = "127.0.0.1:5223".parse().unwrap();
        assert_eq!(
            Connection::get_sockets("127.0.0.1:5223").await,
            vec![(addr, true)]
        );
    }

    /// A literal socket address on any other port is flagged as plaintext.
    #[test(tokio::test)]
    async fn socket_address_on_other_port_is_plaintext() {
        let addr: std::net::SocketAddr = "127.0.0.1:5222".parse().unwrap();
        assert_eq!(
            Connection::get_sockets("127.0.0.1:5222").await,
            vec![(addr, false)]
        );
    }

    /// A bare IP address yields the plaintext port first, then the TLS port.
    #[test(tokio::test)]
    async fn ip_address_yields_both_default_ports() {
        let plain: std::net::SocketAddr = "[::1]:5222".parse().unwrap();
        let tls: std::net::SocketAddr = "[::1]:5223".parse().unwrap();
        assert_eq!(
            Connection::get_sockets("::1").await,
            vec![(plain, false), (tls, true)]
        );
    }

    // #[test(tokio::test)]
    // async fn test_tls() {
    // Connection::connect("blos.sm", None, None)
    // .await
    // .unwrap()
    // .ensure_tls()
    // .await
    // .unwrap();
    // }
}