summaryrefslogblamecommitdiffstats
path: root/src/jabber.rs
blob: a1f62727c1f7d9780d215680125b4163bac5fef1 (plain) (tree)


































































































































                                                                                           
use std::marker::PhantomData;
use std::net::{IpAddr, SocketAddr};
use std::str::FromStr;

use quick_xml::{Reader, Writer};
use tokio::io::BufReader;
use tokio::net::TcpStream;
use tokio_native_tls::native_tls::TlsConnector;

use crate::client;
use crate::client::JabberClientType;
use crate::jid::JID;
use crate::{JabberError, Result};

/// Account and connection settings for a Jabber/XMPP client.
///
/// Built with [`Jabber::new`], which initialises `server` from the JID's domainpart.
pub struct Jabber<'j> {
    /// The account's JID. Its domainpart is what gets parsed or DNS-resolved to find
    /// candidate sockets when connecting.
    pub jid: JID,
    /// Account password. It is not read in this file; presumably the client uses it for
    /// authentication — TODO confirm.
    pub password: String,
    /// Domain name presented to the TLS connector when opening an encrypted connection.
    pub server: String,
    // NOTE(review): carries the `'j` lifetime without storing data; it appears to exist so
    // that `connect(&'j mut self)` can tie the returned client to this borrow — confirm.
    _marker: PhantomData<&'j ()>,
}

impl<'j> Jabber<'j> {
    /// Conventional XMPP client port, used for plaintext stream connections.
    const CLIENT_PORT: u16 = 5222;
    /// Conventional port for direct (implicit) TLS client connections.
    const DIRECT_TLS_PORT: u16 = 5223;

    /// Creates a configuration for `jid`, using the JID's domainpart as `server`.
    pub fn new(jid: JID, password: String) -> Self {
        let server = jid.domainpart.clone();
        Self {
            jid,
            password,
            server,
            _marker: PhantomData,
        }
    }

    /// Returns the candidate socket addresses for `self.jid.domainpart`, in the order they
    /// should be tried. Each address is paired with `true` when it should be connected over
    /// direct TLS.
    ///
    /// - A literal socket address is returned alone. It is treated as TLS only on port 5223.
    /// - A literal IP address yields `ip:5222` (plaintext) followed by `ip:5223` (TLS).
    /// - Otherwise the resolver is queried in this order:
    ///   1. `_xmpp-client._tcp` SRV targets (plaintext).
    ///   2. `_xmpps-client._tcp` SRV targets (TLS).
    ///   3. As a fallback, every address of the domain itself on 5222 and on 5223.
    ///
    /// DNS failures are not reported. A failed lookup simply contributes no addresses, and if
    /// the system resolver cannot be constructed the result is empty.
    async fn get_sockets(&self) -> Vec<(SocketAddr, bool)> {
        let domain = &self.jid.domainpart;

        // Literal socket address: use it verbatim.
        if let Ok(socket_addr) = SocketAddr::from_str(domain) {
            return vec![(socket_addr, socket_addr.port() == Self::DIRECT_TLS_PORT)];
        }

        // Literal IP address: try both well-known ports, plaintext first.
        if let Ok(ip) = IpAddr::from_str(domain) {
            return vec![
                (SocketAddr::new(ip, Self::CLIENT_PORT), false),
                (SocketAddr::new(ip, Self::DIRECT_TLS_PORT), true),
            ];
        }

        let mut socket_addrs = Vec::new();
        let resolver = match trust_dns_resolver::AsyncResolver::tokio_from_system_conf() {
            Ok(resolver) => resolver,
            Err(_) => return socket_addrs,
        };

        // SRV records: the plaintext service first, then the direct-TLS service.
        for (service, is_tls) in [("_xmpp-client", false), ("_xmpps-client", true)] {
            let lookup = match resolver
                .srv_lookup(format!("{}._tcp.{}", service, domain))
                .await
            {
                Ok(lookup) => lookup,
                Err(_) => continue,
            };
            for srv in lookup {
                if let Ok(ips) = resolver.lookup_ip(srv.target().to_owned()).await {
                    for ip in ips {
                        socket_addrs.push((SocketAddr::new(ip, srv.port()), is_tls));
                    }
                }
            }
        }

        // Fallback in case no SRV target turns out to be connectable: the domain's own addresses.
        if let Ok(ips) = resolver.lookup_ip(domain).await {
            for ip in ips {
                socket_addrs.push((SocketAddr::new(ip, Self::CLIENT_PORT), false));
                socket_addrs.push((SocketAddr::new(ip, Self::DIRECT_TLS_PORT), true));
            }
        }

        socket_addrs
    }

    /// Connects to the first reachable candidate returned by `get_sockets`, trying each in order.
    ///
    /// - TLS candidates are wrapped in a native-TLS session that uses `self.server` as the TLS
    ///   domain, and produce [`JabberClientType::Encrypted`].
    /// - All other candidates produce [`JabberClientType::Unencrypted`].
    ///
    /// Each attempt is announced on stdout. A failed TCP connect, TLS connector setup or TLS
    /// handshake skips to the next candidate; none of them panic.
    ///
    /// # Errors
    ///
    /// Returns [`JabberError::ConnectionError`] when no candidate could be connected.
    pub async fn connect(&'j mut self) -> Result<JabberClientType> {
        for (socket_addr, is_tls) in self.get_sockets().await {
            println!("trying {}", socket_addr);

            // An unreachable address is not fatal: fall through to the remaining candidates.
            let socket = match TcpStream::connect(socket_addr).await {
                Ok(socket) => socket,
                Err(_) => continue,
            };

            if is_tls {
                let connector = match TlsConnector::new() {
                    Ok(connector) => connector,
                    Err(_) => continue,
                };
                if let Ok(stream) = tokio_native_tls::TlsConnector::from(connector)
                    .connect(&self.server, socket)
                    .await
                {
                    let (read, write) = tokio::io::split(stream);
                    let reader = Reader::from_reader(BufReader::new(read));
                    let writer = Writer::new(write);
                    return Ok(JabberClientType::Encrypted(
                        client::encrypted::JabberClient::new(reader, writer, self),
                    ));
                }
            } else {
                let (read, write) = tokio::io::split(socket);
                let reader = Reader::from_reader(BufReader::new(read));
                let writer = Writer::new(write);
                return Ok(JabberClientType::Unencrypted(
                    client::unencrypted::JabberClient::new(reader, writer, self),
                ));
            }
        }
        Err(JabberError::ConnectionError)
    }
}