aboutsummaryrefslogtreecommitdiffstats
path: root/jabber/src/connection.rs
diff options
context:
space:
mode:
Diffstat (limited to 'jabber/src/connection.rs')
-rw-r--r--jabber/src/connection.rs184
1 files changed, 184 insertions, 0 deletions
diff --git a/jabber/src/connection.rs b/jabber/src/connection.rs
new file mode 100644
index 0000000..bc5a282
--- /dev/null
+++ b/jabber/src/connection.rs
@@ -0,0 +1,184 @@
+use std::net::{IpAddr, SocketAddr};
+use std::str;
+use std::str::FromStr;
+use std::sync::Arc;
+
+use rsasl::config::SASLConfig;
+use tokio::net::TcpStream;
+use tokio_native_tls::native_tls::TlsConnector;
+// TODO: use rustls
+use tokio_native_tls::TlsStream;
+use tracing::{debug, info, instrument, trace};
+
+use crate::Result;
+use crate::{Error, JID};
+
+pub type Tls = TlsStream<TcpStream>;
+pub type Unencrypted = TcpStream;
+
+#[derive(Debug)]
+pub enum Connection {
+ Encrypted(Tls),
+ Unencrypted(Unencrypted),
+}
+
+impl Connection {
+ // #[instrument]
+ /// stream not started
+ // pub async fn ensure_tls(self) -> Result<Jabber<Tls>> {
+ // match self {
+ // Connection::Encrypted(j) => Ok(j),
+ // Connection::Unencrypted(mut j) => {
+ // j.start_stream().await?;
+ // info!("upgrading connection to tls");
+ // j.get_features().await?;
+ // let j = j.starttls().await?;
+ // Ok(j)
+ // }
+ // }
+ // }
+
+ pub async fn connect_user(jid: impl AsRef<str>) -> Result<Self> {
+ let jid: JID = JID::from_str(jid.as_ref())?;
+ let server = jid.domainpart.clone();
+ Self::connect(&server).await
+ }
+
+ #[instrument]
+ pub async fn connect(server: impl AsRef<str> + std::fmt::Debug) -> Result<Self> {
+ info!("connecting to {}", server.as_ref());
+ let sockets = Self::get_sockets(server.as_ref()).await;
+ debug!("discovered sockets: {:?}", sockets);
+ for (socket_addr, tls) in sockets {
+ match tls {
+ true => {
+ if let Ok(connection) = Self::connect_tls(socket_addr, server.as_ref()).await {
+ info!("connected via encrypted stream to {}", socket_addr);
+ // let (readhalf, writehalf) = tokio::io::split(connection);
+ return Ok(Self::Encrypted(connection));
+ }
+ }
+ false => {
+ if let Ok(connection) = Self::connect_unencrypted(socket_addr).await {
+ info!("connected via unencrypted stream to {}", socket_addr);
+ // let (readhalf, writehalf) = tokio::io::split(connection);
+ return Ok(Self::Unencrypted(connection));
+ }
+ }
+ }
+ }
+ Err(Error::Connection)
+ }
+
+ #[instrument]
+ async fn get_sockets(address: &str) -> Vec<(SocketAddr, bool)> {
+ let mut socket_addrs = Vec::new();
+
+ // if it's a socket/ip then just return that
+
+ // socket
+ trace!("checking if address is a socket address");
+ if let Ok(socket_addr) = SocketAddr::from_str(address) {
+ debug!("{} is a socket address", address);
+ match socket_addr.port() {
+ 5223 => socket_addrs.push((socket_addr, true)),
+ _ => socket_addrs.push((socket_addr, false)),
+ }
+
+ return socket_addrs;
+ }
+ // ip
+ trace!("checking if address is an ip");
+ if let Ok(ip) = IpAddr::from_str(address) {
+ debug!("{} is an ip", address);
+ socket_addrs.push((SocketAddr::new(ip, 5222), false));
+ socket_addrs.push((SocketAddr::new(ip, 5223), true));
+ return socket_addrs;
+ }
+
+ // otherwise resolve
+ debug!("resolving {}", address);
+ if let Ok(resolver) = trust_dns_resolver::AsyncResolver::tokio_from_system_conf() {
+ if let Ok(lookup) = resolver
+ .srv_lookup(format!("_xmpp-client._tcp.{}", address))
+ .await
+ {
+ for srv in lookup {
+ resolver
+ .lookup_ip(srv.target().to_owned())
+ .await
+ .map(|ips| {
+ for ip in ips {
+ socket_addrs.push((SocketAddr::new(ip, srv.port()), false))
+ }
+ });
+ }
+ }
+ if let Ok(lookup) = resolver
+ .srv_lookup(format!("_xmpps-client._tcp.{}", address))
+ .await
+ {
+ for srv in lookup {
+ resolver
+ .lookup_ip(srv.target().to_owned())
+ .await
+ .map(|ips| {
+ for ip in ips {
+ socket_addrs.push((SocketAddr::new(ip, srv.port()), true))
+ }
+ });
+ }
+ }
+
+ // in case cannot connect through SRV records
+ resolver.lookup_ip(address).await.map(|ips| {
+ for ip in ips {
+ socket_addrs.push((SocketAddr::new(ip, 5222), false));
+ socket_addrs.push((SocketAddr::new(ip, 5223), true));
+ }
+ });
+ }
+ socket_addrs
+ }
+
+ /// establishes a connection to the server
+ #[instrument]
+ pub async fn connect_tls(socket_addr: SocketAddr, domain_name: &str) -> Result<Tls> {
+ let socket = TcpStream::connect(socket_addr)
+ .await
+ .map_err(|_| Error::Connection)?;
+ let connector = TlsConnector::new().map_err(|_| Error::Connection)?;
+ tokio_native_tls::TlsConnector::from(connector)
+ .connect(domain_name, socket)
+ .await
+ .map_err(|_| Error::Connection)
+ }
+
+ #[instrument]
+ pub async fn connect_unencrypted(socket_addr: SocketAddr) -> Result<Unencrypted> {
+ TcpStream::connect(socket_addr)
+ .await
+ .map_err(|_| Error::Connection)
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use test_log::test;
+
+ #[test(tokio::test)]
+ async fn connect() {
+ Connection::connect("blos.sm").await.unwrap();
+ }
+
+ // #[test(tokio::test)]
+ // async fn test_tls() {
+ // Connection::connect("blos.sm", None, None)
+ // .await
+ // .unwrap()
+ // .ensure_tls()
+ // .await
+ // .unwrap();
+ // }
+}