summaryrefslogblamecommitdiffstats
path: root/src/connection.rs
blob: 24f7745cc74536fd7e2f7943d65858723222fe54 (plain) (tree)

































































































































































                                                                                                    
use std::net::{IpAddr, SocketAddr};
use std::str;
use std::str::FromStr;

use tokio::net::TcpStream;
use tokio_native_tls::native_tls::TlsConnector;
// TODO: use rustls
use tokio_native_tls::TlsStream;

use crate::Jabber;
use crate::JabberError;
use crate::Result;

/// Plain TCP transport with no encryption layer.
pub type Unencrypted = TcpStream;
/// TCP transport wrapped in a `tokio_native_tls` TLS stream.
pub type Tls = TlsStream<TcpStream>;

/// A client session, tagged by whether its underlying transport is encrypted.
pub enum Connection {
    /// Session running over a TLS stream.
    Encrypted(Jabber<Tls>),
    /// Session running over plaintext TCP; [`Connection::ensure_tls`] upgrades it.
    Unencrypted(Jabber<Unencrypted>),
}

impl Connection {
    pub async fn ensure_tls(self) -> Result<Jabber<Tls>> {
        match self {
            Connection::Encrypted(j) => Ok(j),
            Connection::Unencrypted(j) => Ok(j.starttls().await?),
        }
    }

    // pub async fn connect_user<J: TryInto<JID>>(jid: J, password: String) -> Result<Self> {
    //     let server = jid.domainpart.clone();
    //     let auth = SASLConfig::with_credentials(None, jid.localpart.clone().unwrap(), password)?;
    //     println!("auth: {:?}", auth);
    //     Self::connect(&server, jid.try_into()?, Some(auth)).await
    // }

    async fn connect(server: &str) -> Result<Self> {
        let sockets = Self::get_sockets(&server).await;
        for (socket_addr, tls) in sockets {
            match tls {
                true => {
                    if let Ok(connection) = Self::connect_tls(socket_addr, &server).await {
                        let (readhalf, writehalf) = tokio::io::split(connection);
                        return Ok(Self::Encrypted(Jabber::new(
                            readhalf,
                            writehalf,
                            None,
                            None,
                            server.to_owned(),
                        )));
                    }
                }
                false => {
                    if let Ok(connection) = Self::connect_unencrypted(socket_addr).await {
                        let (readhalf, writehalf) = tokio::io::split(connection);
                        return Ok(Self::Unencrypted(Jabber::new(
                            readhalf,
                            writehalf,
                            None,
                            None,
                            server.to_owned(),
                        )));
                    }
                }
            }
        }
        Err(JabberError::Connection)
    }

    async fn get_sockets(domain: &str) -> Vec<(SocketAddr, bool)> {
        let mut socket_addrs = Vec::new();

        // if it's a socket/ip then just return that

        // socket
        if let Ok(socket_addr) = SocketAddr::from_str(domain) {
            match socket_addr.port() {
                5223 => socket_addrs.push((socket_addr, true)),
                _ => socket_addrs.push((socket_addr, false)),
            }

            return socket_addrs;
        }
        // ip
        if let Ok(ip) = IpAddr::from_str(domain) {
            socket_addrs.push((SocketAddr::new(ip, 5222), false));
            socket_addrs.push((SocketAddr::new(ip, 5223), true));
            return socket_addrs;
        }

        // otherwise resolve
        if let Ok(resolver) = trust_dns_resolver::AsyncResolver::tokio_from_system_conf() {
            if let Ok(lookup) = resolver
                .srv_lookup(format!("_xmpp-client._tcp.{}", domain))
                .await
            {
                for srv in lookup {
                    resolver
                        .lookup_ip(srv.target().to_owned())
                        .await
                        .map(|ips| {
                            for ip in ips {
                                socket_addrs.push((SocketAddr::new(ip, srv.port()), false))
                            }
                        });
                }
            }
            if let Ok(lookup) = resolver
                .srv_lookup(format!("_xmpps-client._tcp.{}", domain))
                .await
            {
                for srv in lookup {
                    resolver
                        .lookup_ip(srv.target().to_owned())
                        .await
                        .map(|ips| {
                            for ip in ips {
                                socket_addrs.push((SocketAddr::new(ip, srv.port()), true))
                            }
                        });
                }
            }

            // in case cannot connect through SRV records
            resolver.lookup_ip(domain).await.map(|ips| {
                for ip in ips {
                    socket_addrs.push((SocketAddr::new(ip, 5222), false));
                    socket_addrs.push((SocketAddr::new(ip, 5223), true));
                }
            });
        }
        socket_addrs
    }

    /// establishes a connection to the server
    pub async fn connect_tls(socket_addr: SocketAddr, domain_name: &str) -> Result<Tls> {
        let socket = TcpStream::connect(socket_addr)
            .await
            .map_err(|_| JabberError::Connection)?;
        let connector = TlsConnector::new().map_err(|_| JabberError::Connection)?;
        tokio_native_tls::TlsConnector::from(connector)
            .connect(domain_name, socket)
            .await
            .map_err(|_| JabberError::Connection)
    }

    pub async fn connect_unencrypted(socket_addr: SocketAddr) -> Result<Unencrypted> {
        TcpStream::connect(socket_addr)
            .await
            .map_err(|_| JabberError::Connection)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// End-to-end connection to a live server; needs DNS resolution and outbound network access,
    /// so it is skipped by default. Run with `cargo test -- --ignored`.
    #[tokio::test]
    #[ignore = "requires network access to blos.sm"]
    async fn connect() {
        Connection::connect("blos.sm").await.unwrap();
    }

    /// A literal socket address on 5223 is returned alone and flagged as direct TLS.
    #[tokio::test]
    async fn get_sockets_socket_literal_direct_tls_port() {
        let addr: SocketAddr = "127.0.0.1:5223".parse().unwrap();
        assert_eq!(Connection::get_sockets("127.0.0.1:5223").await, vec![(addr, true)]);
    }

    /// A literal socket address on any other port is returned alone and flagged as plaintext.
    #[tokio::test]
    async fn get_sockets_socket_literal_other_port() {
        let addr: SocketAddr = "[::1]:5222".parse().unwrap();
        assert_eq!(Connection::get_sockets("[::1]:5222").await, vec![(addr, false)]);
    }

    /// A bare IP address yields the plaintext port first, then the direct-TLS port.
    #[tokio::test]
    async fn get_sockets_ip_literal_tries_both_ports() {
        let ip: IpAddr = "192.0.2.1".parse().unwrap();
        assert_eq!(
            Connection::get_sockets("192.0.2.1").await,
            vec![
                (SocketAddr::new(ip, 5222), false),
                (SocketAddr::new(ip, 5223), true),
            ]
        );
    }
}