aboutsummaryrefslogblamecommitdiffstats
path: root/jabber/src/connection.rs
blob: b185eca82f7d542dd065df3144b7e2c844de656c (plain) (tree)
1
2
3
4
5
6
7
8
9







                                               
                                              
 
                  
                        



                                    
                
                     

                             


                 
                    
                          













                                                                     

                                                    
                                    
     
 
                 


                                                                                     
                                                    


                                           
                                                                                                   
                                                                                   

                                                                                    



                                                                                          
                                                                                     

                                                                                    



                     
                              

     

                                                                    




                                                    


                                                                







                                                               


                                                   





                                                                  
                                        

                                                                                           
                                                                     













                                                                                           
                                                                      














                                                                                          
                                                         









                                                                          
                 


                                                                                         

                                                                            


                                                       
                                           

     
                 


                                                                                      
                                           





                 
                       
 
                        
                        
                                                      
     
 








                                                     
 
use std::net::{IpAddr, SocketAddr};
use std::str;
use std::str::FromStr;

use tokio::net::TcpStream;
use tokio_native_tls::native_tls::TlsConnector;
// TODO: use rustls
use tokio_native_tls::TlsStream;
use tracing::{debug, info, instrument, trace};

use crate::Result;
use crate::{Error, JID};

/// A TCP stream wrapped in TLS; the payload of [`Connection::Encrypted`].
pub type Tls = TlsStream<TcpStream>;
/// A plain TCP stream; the payload of [`Connection::Unencrypted`].
pub type Unencrypted = TcpStream;

/// An open transport-level connection to an XMPP server.
///
/// Produced by [`Connection::connect`] and [`Connection::connect_user`]. The
/// variant records whether TLS was applied when the stream was opened.
#[derive(Debug)]
pub enum Connection {
    /// The TCP stream was wrapped in TLS when it was opened.
    Encrypted(Tls),
    /// The TCP stream was opened without TLS.
    Unencrypted(Unencrypted),
}

impl Connection {
    // #[instrument]
    /// stream not started
    // pub async fn ensure_tls(self) -> Result<Jabber<Tls>> {
    //     match self {
    //         Connection::Encrypted(j) => Ok(j),
    //         Connection::Unencrypted(mut j) => {
    //             j.start_stream().await?;
    //             info!("upgrading connection to tls");
    //             j.get_features().await?;
    //             let j = j.starttls().await?;
    //             Ok(j)
    //         }
    //     }
    // }

    pub async fn connect_user(jid: impl AsRef<str>) -> Result<Self> {
        let jid: JID = JID::from_str(jid.as_ref())?;
        let server = jid.domainpart.clone();
        Self::connect(&server).await
    }

    #[instrument]
    pub async fn connect(server: impl AsRef<str> + std::fmt::Debug) -> Result<Self> {
        info!("connecting to {}", server.as_ref());
        let sockets = Self::get_sockets(server.as_ref()).await;
        debug!("discovered sockets: {:?}", sockets);
        for (socket_addr, tls) in sockets {
            match tls {
                true => {
                    if let Ok(connection) = Self::connect_tls(socket_addr, server.as_ref()).await {
                        info!("connected via encrypted stream to {}", socket_addr);
                        // let (readhalf, writehalf) = tokio::io::split(connection);
                        return Ok(Self::Encrypted(connection));
                    }
                }
                false => {
                    if let Ok(connection) = Self::connect_unencrypted(socket_addr).await {
                        info!("connected via unencrypted stream to {}", socket_addr);
                        // let (readhalf, writehalf) = tokio::io::split(connection);
                        return Ok(Self::Unencrypted(connection));
                    }
                }
            }
        }
        Err(Error::Connection)
    }

    #[instrument]
    async fn get_sockets(address: &str) -> Vec<(SocketAddr, bool)> {
        let mut socket_addrs = Vec::new();

        // if it's a socket/ip then just return that

        // socket
        trace!("checking if address is a socket address");
        if let Ok(socket_addr) = SocketAddr::from_str(address) {
            debug!("{} is a socket address", address);
            match socket_addr.port() {
                5223 => socket_addrs.push((socket_addr, true)),
                _ => socket_addrs.push((socket_addr, false)),
            }

            return socket_addrs;
        }
        // ip
        trace!("checking if address is an ip");
        if let Ok(ip) = IpAddr::from_str(address) {
            debug!("{} is an ip", address);
            socket_addrs.push((SocketAddr::new(ip, 5222), false));
            socket_addrs.push((SocketAddr::new(ip, 5223), true));
            return socket_addrs;
        }

        // otherwise resolve
        debug!("resolving {}", address);
        if let Ok(resolver) = trust_dns_resolver::AsyncResolver::tokio_from_system_conf() {
            if let Ok(lookup) = resolver
                .srv_lookup(format!("_xmpp-client._tcp.{}", address))
                .await
            {
                for srv in lookup {
                    resolver
                        .lookup_ip(srv.target().to_owned())
                        .await
                        .map(|ips| {
                            for ip in ips {
                                socket_addrs.push((SocketAddr::new(ip, srv.port()), false))
                            }
                        });
                }
            }
            if let Ok(lookup) = resolver
                .srv_lookup(format!("_xmpps-client._tcp.{}", address))
                .await
            {
                for srv in lookup {
                    resolver
                        .lookup_ip(srv.target().to_owned())
                        .await
                        .map(|ips| {
                            for ip in ips {
                                socket_addrs.push((SocketAddr::new(ip, srv.port()), true))
                            }
                        });
                }
            }

            // in case cannot connect through SRV records
            resolver.lookup_ip(address).await.map(|ips| {
                for ip in ips {
                    socket_addrs.push((SocketAddr::new(ip, 5222), false));
                    socket_addrs.push((SocketAddr::new(ip, 5223), true));
                }
            });
        }
        socket_addrs
    }

    /// establishes a connection to the server
    #[instrument]
    pub async fn connect_tls(socket_addr: SocketAddr, domain_name: &str) -> Result<Tls> {
        let socket = TcpStream::connect(socket_addr)
            .await
            .map_err(|_| Error::Connection)?;
        let connector = TlsConnector::new().map_err(|_| Error::Connection)?;
        tokio_native_tls::TlsConnector::from(connector)
            .connect(domain_name, socket)
            .await
            .map_err(|_| Error::Connection)
    }

    #[instrument]
    pub async fn connect_unencrypted(socket_addr: SocketAddr) -> Result<Unencrypted> {
        TcpStream::connect(socket_addr)
            .await
            .map_err(|_| Error::Connection)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use test_log::test;

    /// Connects to a live server, so it needs network access.
    /// Run explicitly with `cargo test -- --ignored`.
    #[test(tokio::test)]
    #[ignore = "requires network access to blos.sm"]
    async fn connect() {
        Connection::connect("blos.sm").await.unwrap();
    }

    /// A literal socket address is the only candidate; only port 5223 means direct TLS.
    #[test(tokio::test)]
    async fn socket_address_is_used_verbatim() {
        let plain: SocketAddr = "127.0.0.1:5222".parse().unwrap();
        assert_eq!(
            Connection::get_sockets("127.0.0.1:5222").await,
            vec![(plain, false)]
        );
        let tls: SocketAddr = "127.0.0.1:5223".parse().unwrap();
        assert_eq!(
            Connection::get_sockets("127.0.0.1:5223").await,
            vec![(tls, true)]
        );
    }

    /// A bare IP address is tried on 5222 (no TLS) and then 5223 (direct TLS).
    #[test(tokio::test)]
    async fn ip_address_tries_both_default_ports() {
        let ip: IpAddr = "::1".parse().unwrap();
        assert_eq!(
            Connection::get_sockets("::1").await,
            vec![
                (SocketAddr::new(ip, 5222), false),
                (SocketAddr::new(ip, 5223), true),
            ]
        );
    }

    // #[test(tokio::test)]
    // async fn test_tls() {
    //     Connection::connect("blos.sm", None, None)
    //         .await
    //         .unwrap()
    //         .ensure_tls()
    //         .await
    //         .unwrap();
    // }
}