aboutsummaryrefslogblamecommitdiffstats
path: root/src/jabber.rs
blob: 599879dab0dc062d0095257004490b543b292273 (plain) (tree)
1
2
3
4
5
6
7
8
9
             
                   
 
                                     
                                                 
                              
                                                       
                                                                                                 
                         
                                               
                                              
                                                     
 
                                          
                        
                                                                                 
                                                 
                                                       
                               
               
                                
 
                      

                                 


                                  

 










                                      

                                         


                   
                
                 
                   
         
     
 
 

                 
                                             
                               
 













































                                                                            



                                                                                               

















                                                                                      


                                                     
                                                         
                                           






                                                                                



                                                                                                   








                                                
 
                 

                                                        
 
                      
                                                          
 
                                 

                                                                              
                                                
 
                           
 
                                            
                                                    
 
                                              

                                                                      
                                                             
                                           

                                          
         

              
     
 










                                                                  

                          
















































































                                                                                                












                                                                                
                                     





                                       



                                          
     
 






























                                                                        
                                                                                   




                                                                              




































                                                                                       
 
use std::str;
use std::sync::Arc;

use async_recursion::async_recursion;
use peanuts::element::{FromElement, IntoElement};
use peanuts::{Reader, Writer};
use rsasl::prelude::{Mechname, SASLClient, SASLConfig};
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufReader, BufWriter, ReadHalf, WriteHalf};
use tokio::time::timeout;
use tokio_native_tls::native_tls::TlsConnector;
use tracing::{debug, info, instrument, trace};
use trust_dns_resolver::proto::rr::domain::IntoLabel;

use crate::connection::{Tls, Unencrypted};
use crate::error::Error;
use crate::stanza::sasl::{Auth, Challenge, Mechanisms, Response, ServerResponse};
use crate::stanza::starttls::{Proceed, StartTls};
use crate::stanza::stream::{Feature, Features, Stream};
use crate::stanza::XML_VERSION;
use crate::JID;
use crate::{Connection, Result};

/// An XMPP client stream over a transport `S`, split into an XML reader and an XML writer.
pub struct Jabber<S> {
    /// XML reader over the read half of the transport.
    reader: Reader<ReadHalf<S>>,
    /// XML writer over the write half of the transport.
    writer: Writer<WriteHalf<S>>,
    /// The client's JID, if one was supplied.
    jid: Option<JID>,
    /// SASL configuration; when `None`, the negotiation methods never attempt SASL.
    auth: Option<Arc<SASLConfig>>,
    /// Server domain: passed to `Stream::new_client` when opening a stream, used as the
    /// TLS hostname in `starttls`, and overwritten with the server's stream `from` if present.
    server: String,
}

impl<S> Jabber<S>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    /// Builds a client from the two halves of an already-connected transport.
    ///
    /// The read half is wrapped in a `peanuts` [`Reader`] and the write half in a
    /// [`Writer`]; `jid`, `auth` and `server` are stored unchanged.
    pub fn new(
        reader: ReadHalf<S>,
        writer: WriteHalf<S>,
        jid: Option<JID>,
        auth: Option<Arc<SASLConfig>>,
        server: String,
    ) -> Self {
        Self {
            reader: Reader::new(reader),
            writer: Writer::new(writer),
            jid,
            auth,
            server,
        }
    }
}

impl<S> Jabber<S>
where
    S: AsyncRead + AsyncWrite + Unpin + Send,
    Jabber<S>: std::fmt::Debug,
{
    pub async fn sasl(
        &mut self,
        mechanisms: Mechanisms,
        sasl_config: Arc<SASLConfig>,
    ) -> Result<()> {
        let sasl = SASLClient::new(sasl_config);
        let mut offered_mechs: Vec<&Mechname> = Vec::new();
        for mechanism in &mechanisms.mechanisms {
            offered_mechs.push(Mechname::parse(mechanism.as_bytes())?)
        }
        debug!("{:?}", offered_mechs);
        let mut session = sasl.start_suggested(&offered_mechs)?;
        let selected_mechanism = session.get_mechname().as_str().to_owned();
        debug!("selected mech: {:?}", selected_mechanism);
        let mut data: Option<Vec<u8>> = None;

        if !session.are_we_first() {
            // if not first mention the mechanism then get challenge data
            // mention mechanism
            let auth = Auth {
                mechanism: selected_mechanism,
                sasl_data: "=".to_string(),
            };
            self.writer.write_full(&auth).await?;
            // get challenge data
            let challenge: Challenge = self.reader.read().await?;
            debug!("challenge: {:?}", challenge);
            data = Some((*challenge).as_bytes().to_vec());
            debug!("we didn't go first");
        } else {
            // if first, mention mechanism and send data
            let mut sasl_data = Vec::new();
            session.step64(None, &mut sasl_data).unwrap();
            let auth = Auth {
                mechanism: selected_mechanism,
                sasl_data: str::from_utf8(&sasl_data)?.to_string(),
            };
            debug!("{:?}", auth);
            self.writer.write_full(&auth).await?;

            let server_response: ServerResponse = self.reader.read().await?;
            debug!("server_response: {:#?}", server_response);
            match server_response {
                ServerResponse::Challenge(challenge) => {
                    data = Some((*challenge).as_bytes().to_vec())
                }
                ServerResponse::Success(success) => {
                    data = success.clone().map(|success| success.as_bytes().to_vec())
                }
                ServerResponse::Failure(failure) => return Err(Error::Authentication(failure)),
            }
            debug!("we went first");
        }

        // stepping the authentication exchange to completion
        if data != None {
            debug!("data: {:?}", data);
            let mut sasl_data = Vec::new();
            while {
                // decide if need to send more data over
                let state = session
                    .step64(data.as_deref(), &mut sasl_data)
                    .expect("step errored!");
                state.is_running()
            } {
                // While we aren't finished, receive more data from the other party
                let response = Response::new(str::from_utf8(&sasl_data)?.to_string());
                debug!("response: {:?}", response);
                let stdout = tokio::io::stdout();
                let mut writer = Writer::new(stdout);
                writer.write_full(&response).await?;
                self.writer.write_full(&response).await?;
                debug!("response written");

                let server_response: ServerResponse = self.reader.read().await?;
                debug!("server_response: {:#?}", server_response);
                match server_response {
                    ServerResponse::Challenge(challenge) => {
                        data = Some((*challenge).as_bytes().to_vec())
                    }
                    ServerResponse::Success(success) => {
                        data = success.clone().map(|success| success.as_bytes().to_vec())
                    }
                    ServerResponse::Failure(failure) => return Err(Error::Authentication(failure)),
                }
            }
        }
        Ok(())
    }

    pub async fn bind(&mut self) -> Result<()> {
        todo!()
    }

    #[instrument]
    pub async fn start_stream(&mut self) -> Result<()> {
        // client to server

        // declaration
        self.writer.write_declaration(XML_VERSION).await?;

        // opening stream element
        let server = self.server.clone().try_into()?;
        let stream = Stream::new_client(None, server, None, "en".to_string());
        self.writer.write_start(&stream).await?;

        // server to client

        // may or may not send a declaration
        let decl = self.reader.read_prolog().await?;

        // receive stream element and validate
        let text = str::from_utf8(self.reader.buffer.data()).unwrap();
        debug!("data: {}", text);
        let stream: Stream = self.reader.read_start().await?;
        debug!("got stream: {:?}", stream);
        if let Some(from) = stream.from {
            self.server = from.to_string()
        }

        Ok(())
    }

    pub async fn get_features(&mut self) -> Result<Features> {
        debug!("getting features");
        let features: Features = self.reader.read().await?;
        debug!("got features: {:?}", features);
        Ok(features)
    }

    pub fn into_inner(self) -> S {
        self.reader.into_inner().unsplit(self.writer.into_inner())
    }
}

impl Jabber<Unencrypted> {
    pub async fn negotiate<S: AsyncRead + AsyncWrite + Unpin>(mut self) -> Result<Jabber<Tls>> {
        self.start_stream().await?;
        // TODO: timeout
        let features = self.get_features().await?.features;
        if let Some(Feature::StartTls(_)) = features
            .iter()
            .find(|feature| matches!(feature, Feature::StartTls(_s)))
        {
            let jabber = self.starttls().await?;
            let jabber = jabber.negotiate().await?;
            return Ok(jabber);
        } else {
            // TODO: better error
            return Err(Error::TlsRequired);
        }
    }

    #[async_recursion]
    pub async fn negotiate_tls_optional(mut self) -> Result<Connection> {
        self.start_stream().await?;
        // TODO: timeout
        let features = self.get_features().await?.features;
        if let Some(Feature::StartTls(_)) = features
            .iter()
            .find(|feature| matches!(feature, Feature::StartTls(_s)))
        {
            let jabber = self.starttls().await?;
            let jabber = jabber.negotiate().await?;
            return Ok(Connection::Encrypted(jabber));
        } else if let (Some(sasl_config), Some(Feature::Sasl(mechanisms))) = (
            self.auth.clone(),
            features
                .iter()
                .find(|feature| matches!(feature, Feature::Sasl(_))),
        ) {
            self.sasl(mechanisms.clone(), sasl_config).await?;
            let jabber = self.negotiate_tls_optional().await?;
            Ok(jabber)
        } else if let Some(Feature::Bind) = features
            .iter()
            .find(|feature| matches!(feature, Feature::Bind))
        {
            self.bind().await?;
            Ok(Connection::Unencrypted(self))
        } else {
            // TODO: better error
            return Err(Error::Negotiation);
        }
    }
}

impl Jabber<Tls> {
    #[async_recursion]
    pub async fn negotiate(mut self) -> Result<Jabber<Tls>> {
        self.start_stream().await?;
        let features = self.get_features().await?.features;

        if let (Some(sasl_config), Some(Feature::Sasl(mechanisms))) = (
            self.auth.clone(),
            features
                .iter()
                .find(|feature| matches!(feature, Feature::Sasl(_))),
        ) {
            // TODO: avoid clone
            self.sasl(mechanisms.clone(), sasl_config).await?;
            let jabber = self.negotiate().await?;
            Ok(jabber)
        } else if let Some(Feature::Bind) = features
            .iter()
            .find(|feature| matches!(feature, Feature::Bind))
        {
            self.bind().await?;
            Ok(self)
        } else {
            // TODO: better error
            return Err(Error::Negotiation);
        }
    }
}

impl Jabber<Unencrypted> {
    pub async fn starttls(mut self) -> Result<Jabber<Tls>> {
        self.writer
            .write_full(&StartTls { required: false })
            .await?;
        let proceed: Proceed = self.reader.read().await?;
        debug!("got proceed: {:?}", proceed);
        let connector = TlsConnector::new().unwrap();
        let stream = self.reader.into_inner().unsplit(self.writer.into_inner());
        if let Ok(tlsstream) = tokio_native_tls::TlsConnector::from(connector)
            .connect(&self.server, stream)
            .await
        {
            let (read, write) = tokio::io::split(tlsstream);
            let client = Jabber::new(
                read,
                write,
                self.jid.to_owned(),
                self.auth.to_owned(),
                self.server.to_owned(),
            );
            return Ok(client);
        } else {
            return Err(Error::Connection);
        }
    }
}

impl std::fmt::Debug for Jabber<Tls> {
    /// Formats the connection kind ("tls"), JID, SASL config and server.
    /// The reader and writer are omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("Jabber");
        builder.field("connection", &"tls");
        builder.field("jid", &self.jid);
        builder.field("auth", &self.auth);
        builder.field("server", &self.server);
        builder.finish()
    }
}

impl std::fmt::Debug for Jabber<Unencrypted> {
    /// Formats the connection kind ("unencrypted"), JID, SASL config and server.
    /// The reader and writer are omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("Jabber");
        builder.field("connection", &"unencrypted");
        builder.field("jid", &self.jid);
        builder.field("auth", &self.auth);
        builder.field("server", &self.server);
        builder.finish()
    }
}

// Integration tests: both connect to the live host "blos.sm", so they need network access
// to that server (and, for `sasl`, a working "test@blos.sm" account).
#[cfg(test)]
mod tests {
    use super::*;
    use crate::connection::Connection;
    use test_log::test;

    /// Connects to blos.sm and opens a stream on whichever connection kind is returned.
    #[test(tokio::test)]
    async fn start_stream() {
        let connection = Connection::connect("blos.sm", None, None).await.unwrap();
        match connection {
            Connection::Encrypted(mut c) => c.start_stream().await.unwrap(),
            Connection::Unencrypted(mut c) => c.start_stream().await.unwrap(),
        }
    }

    /// Connects as test@blos.sm over TLS, opens a stream, reads the features and runs
    /// SASL with the advertised mechanisms.
    #[test(tokio::test)]
    async fn sasl() {
        let mut jabber = Connection::connect_user("test@blos.sm", "slayed".to_string())
            .await
            .unwrap()
            .ensure_tls()
            .await
            .unwrap();
        // Diagnostic dumps of the reader's raw buffer at each stage; these unwraps panic
        // if the buffered bytes are not valid UTF-8.
        let text = str::from_utf8(jabber.reader.buffer.data()).unwrap();
        println!("data: {}", text);
        jabber.start_stream().await.unwrap();

        let text = str::from_utf8(jabber.reader.buffer.data()).unwrap();
        println!("data: {}", text);
        // Explicitly pull more bytes into the reader's buffer before dumping it again.
        jabber.reader.read_buf().await.unwrap();
        let text = str::from_utf8(jabber.reader.buffer.data()).unwrap();
        println!("data: {}", text);

        let features = jabber.get_features().await.unwrap();
        // Panics if no SASL config is set or the server does not advertise SASL.
        let (sasl_config, feature) = (
            jabber.auth.clone().unwrap(),
            features
                .features
                .iter()
                .find(|feature| matches!(feature, Feature::Sasl(_)))
                .unwrap(),
        );
        // `feature` was found by matching `Feature::Sasl`, so only that arm is reachable.
        match feature {
            Feature::StartTls(_start_tls) => todo!(),
            Feature::Sasl(mechanisms) => {
                jabber.sasl(mechanisms.clone(), sasl_config).await.unwrap();
            }
            Feature::Bind => todo!(),
            Feature::Unknown => todo!(),
        }
    }
}