summaryrefslogblamecommitdiffstats
path: root/src/connection.rs
blob: b42711eedf23104a0bb0560578f04ddf1ca66b46 (plain) (tree)
1
2
3
4
5
6
7
8
9







                                               
                                              







                                    
                





                                     
                 


                                                          



                                                     









                                                                                                    
                 
                                                        
                                          
                                                       
                                                    



                                                                                           
                                                                                   











                                                                                          
                                                                                     














                                                                                 

                                                                    




                                                    


                                                                







                                                               


                                                   





                                                                  
                                        

                                                                                           
                                                                     













                                                                                           
                                                                      














                                                                                          
                                                         









                                                                          
                 










                                                                                         
                 









                                                                                      
                       
 
                        



                                                      
use std::net::{IpAddr, SocketAddr};
use std::str;
use std::str::FromStr;

use tokio::net::TcpStream;
use tokio_native_tls::native_tls::TlsConnector;
// TODO: use rustls
use tokio_native_tls::TlsStream;
use tracing::{debug, info, instrument, trace};

use crate::Jabber;
use crate::JabberError;
use crate::Result;

/// Plain TCP transport, used before (or without) a TLS upgrade.
pub type Unencrypted = TcpStream;
/// TCP transport wrapped in a native-tls session.
pub type Tls = TlsStream<TcpStream>;

/// A connection to an XMPP server, tagged by whether its transport is
/// already encrypted.
#[derive(Debug)]
pub enum Connection {
    /// The underlying stream is already TLS-encrypted.
    Encrypted(Jabber<Tls>),
    /// The underlying stream is plain TCP; `Connection::ensure_tls` upgrades
    /// it by calling `starttls` on the wrapped `Jabber`.
    Unencrypted(Jabber<Unencrypted>),
}

impl Connection {
    /// Consumes the connection and returns a TLS-encrypted stream.
    ///
    /// An already encrypted connection is returned unchanged; an unencrypted
    /// one is upgraded by calling `starttls` on the wrapped `Jabber`.
    ///
    /// # Errors
    /// Propagates any error produced by the STARTTLS upgrade.
    #[instrument]
    pub async fn ensure_tls(self) -> Result<Jabber<Tls>> {
        match self {
            Connection::Encrypted(j) => Ok(j),
            Connection::Unencrypted(mut j) => {
                info!("upgrading connection to tls");
                Ok(j.starttls().await?)
            }
        }
    }

    // pub async fn connect_user<J: TryInto<JID>>(jid: J, password: String) -> Result<Self> {
    //     let server = jid.domainpart.clone();
    //     let auth = SASLConfig::with_credentials(None, jid.localpart.clone().unwrap(), password)?;
    //     println!("auth: {:?}", auth);
    //     Self::connect(&server, jid.try_into()?, Some(auth)).await
    // }

    /// Connects to `server` by trying each candidate endpoint produced by
    /// `get_sockets`, in order, and returning the first one that succeeds.
    ///
    /// `server` may be a socket address, a bare IP address, or a domain name.
    /// It is also used as the TLS server name for direct-TLS candidates and is
    /// passed on to `Jabber::new`.
    ///
    /// # Errors
    /// Returns [`JabberError::Connection`] when no candidate could be reached,
    /// including when no candidates were discovered at all.
    #[instrument]
    pub async fn connect(server: &str) -> Result<Self> {
        info!("connecting to {}", server);
        let sockets = Self::get_sockets(server).await;
        debug!("discovered sockets: {:?}", sockets);
        for (socket_addr, tls) in sockets {
            if tls {
                match Self::connect_tls(socket_addr, server).await {
                    Ok(connection) => {
                        info!("connected via encrypted stream to {}", socket_addr);
                        let (readhalf, writehalf) = tokio::io::split(connection);
                        return Ok(Self::Encrypted(Jabber::new(
                            readhalf,
                            writehalf,
                            None,
                            None,
                            server.to_owned(),
                        )));
                    }
                    // Not fatal: fall through to the next candidate.
                    Err(_) => debug!("encrypted connection to {} failed", socket_addr),
                }
            } else {
                match Self::connect_unencrypted(socket_addr).await {
                    Ok(connection) => {
                        info!("connected via unencrypted stream to {}", socket_addr);
                        let (readhalf, writehalf) = tokio::io::split(connection);
                        return Ok(Self::Unencrypted(Jabber::new(
                            readhalf,
                            writehalf,
                            None,
                            None,
                            server.to_owned(),
                        )));
                    }
                    // Not fatal: fall through to the next candidate.
                    Err(_) => debug!("unencrypted connection to {} failed", socket_addr),
                }
            }
        }
        Err(JabberError::Connection)
    }

    /// Builds the ordered list of endpoints to try for `address`.
    ///
    /// Each entry pairs a socket address with `true` when the endpoint is to be
    /// opened with direct TLS, or `false` for plain TCP.
    ///
    /// - A literal socket address is the only candidate; port 5223 is treated
    ///   as direct TLS, any other port as plain TCP.
    /// - A bare IP address yields port 5222 (plain) followed by 5223 (TLS).
    /// - Otherwise `_xmpp-client._tcp` (plain) and then `_xmpps-client._tcp`
    ///   (TLS) SRV records are resolved, each set ordered by ascending SRV
    ///   priority, followed by the domain's own addresses on 5222/5223 as a
    ///   fallback.
    ///
    /// Duplicate entries are dropped, keeping the first occurrence, so an
    /// unreachable endpoint is not attempted twice. DNS failures are logged at
    /// debug level and contribute no candidates, so the result may be empty.
    #[instrument]
    async fn get_sockets(address: &str) -> Vec<(SocketAddr, bool)> {
        /// Appends `candidate` unless an identical entry is already present.
        fn push_unique(addrs: &mut Vec<(SocketAddr, bool)>, candidate: (SocketAddr, bool)) {
            if !addrs.contains(&candidate) {
                addrs.push(candidate);
            }
        }

        // socket address: use it verbatim
        trace!("checking if address is a socket address");
        if let Ok(socket_addr) = SocketAddr::from_str(address) {
            debug!("{} is a socket address", address);
            return vec![(socket_addr, socket_addr.port() == 5223)];
        }

        // bare ip: try the default STARTTLS port, then the direct-TLS port
        trace!("checking if address is an ip");
        if let Ok(ip) = IpAddr::from_str(address) {
            debug!("{} is an ip", address);
            return vec![
                (SocketAddr::new(ip, 5222), false),
                (SocketAddr::new(ip, 5223), true),
            ];
        }

        // otherwise resolve
        debug!("resolving {}", address);
        let resolver = match trust_dns_resolver::AsyncResolver::tokio_from_system_conf() {
            Ok(resolver) => resolver,
            Err(e) => {
                debug!("could not create dns resolver: {}", e);
                return Vec::new();
            }
        };

        let mut socket_addrs = Vec::new();
        for (service, tls) in [("_xmpp-client", false), ("_xmpps-client", true)] {
            let name = format!("{}._tcp.{}", service, address);
            let lookup = match resolver.srv_lookup(name.as_str()).await {
                Ok(lookup) => lookup,
                Err(e) => {
                    debug!("srv lookup for {} failed: {}", name, e);
                    continue;
                }
            };
            // Lower SRV priority values must be tried first. The sort is stable,
            // so records of equal priority keep the resolver's order; SRV
            // weights are not used for load balancing.
            let mut records: Vec<_> = lookup.into_iter().collect();
            records.sort_by_key(|srv| srv.priority());
            for srv in records {
                match resolver.lookup_ip(srv.target().to_owned()).await {
                    Ok(ips) => {
                        for ip in ips {
                            push_unique(&mut socket_addrs, (SocketAddr::new(ip, srv.port()), tls));
                        }
                    }
                    Err(e) => debug!("could not resolve srv target {}: {}", srv.target(), e),
                }
            }
        }

        // in case cannot connect through SRV records
        match resolver.lookup_ip(address).await {
            Ok(ips) => {
                for ip in ips {
                    push_unique(&mut socket_addrs, (SocketAddr::new(ip, 5222), false));
                    push_unique(&mut socket_addrs, (SocketAddr::new(ip, 5223), true));
                }
            }
            Err(e) => debug!("address lookup for {} failed: {}", address, e),
        }

        socket_addrs
    }

    /// Opens a TCP connection to `socket_addr` and performs a TLS handshake,
    /// verifying the server certificate against `domain_name`.
    ///
    /// # Errors
    /// Returns [`JabberError::Connection`] if the TCP connect, TLS connector
    /// construction, or handshake fails; the underlying cause is logged at
    /// debug level.
    #[instrument]
    pub async fn connect_tls(socket_addr: SocketAddr, domain_name: &str) -> Result<Tls> {
        let socket = TcpStream::connect(socket_addr).await.map_err(|e| {
            debug!("tcp connection to {} failed: {}", socket_addr, e);
            JabberError::Connection
        })?;
        let connector = TlsConnector::new().map_err(|e| {
            debug!("could not build tls connector: {}", e);
            JabberError::Connection
        })?;
        tokio_native_tls::TlsConnector::from(connector)
            .connect(domain_name, socket)
            .await
            .map_err(|e| {
                debug!("tls handshake with {} failed: {}", socket_addr, e);
                JabberError::Connection
            })
    }

    /// Opens a plain TCP connection to `socket_addr`.
    ///
    /// # Errors
    /// Returns [`JabberError::Connection`] if the connect fails; the underlying
    /// I/O error is logged at debug level.
    #[instrument]
    pub async fn connect_unencrypted(socket_addr: SocketAddr) -> Result<Unencrypted> {
        TcpStream::connect(socket_addr).await.map_err(|e| {
            debug!("tcp connection to {} failed: {}", socket_addr, e);
            JabberError::Connection
        })
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use test_log::test;

    /// End-to-end connection to a public XMPP server.
    ///
    /// NOTE: requires network access and working DNS, so a failure here may be
    /// environmental rather than a code regression.
    #[test(tokio::test)]
    async fn connect() {
        Connection::connect("blos.sm").await.unwrap();
    }

    /// A literal socket address is used verbatim; only port 5223 means direct TLS.
    #[test(tokio::test)]
    async fn get_sockets_literal_socket_addr() {
        let plain: SocketAddr = "127.0.0.1:5222".parse().unwrap();
        assert_eq!(
            Connection::get_sockets("127.0.0.1:5222").await,
            vec![(plain, false)]
        );

        let tls: SocketAddr = "127.0.0.1:5223".parse().unwrap();
        assert_eq!(
            Connection::get_sockets("127.0.0.1:5223").await,
            vec![(tls, true)]
        );
    }

    /// A bare IP expands to the plain port 5222 first, then direct-TLS port 5223.
    #[test(tokio::test)]
    async fn get_sockets_literal_ip() {
        let ip: IpAddr = "::1".parse().unwrap();
        assert_eq!(
            Connection::get_sockets("::1").await,
            vec![
                (SocketAddr::new(ip, 5222), false),
                (SocketAddr::new(ip, 5223), true),
            ]
        );
    }
}