aboutsummaryrefslogblamecommitdiffstats
path: root/jabber/src/jabber_stream.rs
blob: 302350d1168d454ea01351506f444a9cd1d86770 (plain) (tree)
1
2
3
4
5
6
7
                              
                   
 
             
                                  
                              
                                                       






                                                                          
                                                            
                                               
                                 
 
                                          
                        
                  
 

                     

                                
































































































                                                                                         

 
                       
     

                                                               
 
                 
                                                                                                    


                                                           

                                                                                                

                                      


                                                 

                                                                            
                                      






























                                                                            


                                                                                     
                                                                                            


















                                                                                      
                                           






                                                                                


                                                                                         
                                                                                                


                 

                                                           

                                            

     
                 
                                                                
                                      
                                                   






                                              
                                                                                         













                                                       
                                                                              


                                     

                                    




































                                                                                 
                                                                              


                                     

                                    
















                                                                                 
     
 
                 
                                                                                   
                           
                                                            

                                                                
 
                      
                                                     
 
                                 






                                            
 
                           
 
                                            
                                                
 
                                              
                                                        
                                           
                                         
                                       
         
 
                                   
     
 
                 
                                                                     


                                                           
                            


                                  



                                                           
     
 



                                                                        

 
                                

                                                                                               





                                                         




                                                            

                                                                               

                  
                                  


                                          
     
 
 
                                            


                                                                        



                     
                                                    


                                                                        





                     



                             




                                                                                      
     


                        
































                                                                                          
     

                  
                     

                                                                                 





                                                                                    









                                                                                       
     
 
use std::str::{self, FromStr};
use std::sync::Arc;

use jid::JID;
use peanuts::element::IntoElement;
use peanuts::{Reader, Writer};
use rsasl::prelude::{Mechname, SASLClient, SASLConfig};
use stanza::bind::{Bind, BindType, FullJidType, ResourceType};
use stanza::client::iq::{Iq, IqType, Query};
use stanza::client::Stanza;
use stanza::sasl::{Auth, Challenge, Mechanisms, Response, ServerResponse};
use stanza::starttls::{Proceed, StartTls};
use stanza::stream::{Features, Stream};
use stanza::XML_VERSION;
use tokio::io::{AsyncRead, AsyncWrite, ReadHalf, WriteHalf};
use tokio_native_tls::native_tls::TlsConnector;
use tracing::{debug, instrument};

use crate::connection::{Tls, Unencrypted};
use crate::error::Error;
use crate::Result;

pub mod bound_stream;

/// An XMPP stream whose opening `<stream>` headers have already been exchanged
/// (streams started).
///
/// The connection is kept split into a read half and a write half so each side can
/// be driven independently.
pub struct JabberStream<S> {
    reader: JabberReader<S>,
    writer: JabberWriter<S>,
}

impl<S> JabberStream<S> {
    /// Separates the stream into its reader and writer halves.
    fn split(self) -> (JabberReader<S>, JabberWriter<S>) {
        let JabberStream { reader, writer } = self;
        (reader, writer)
    }
}

/// Read half of a [`JabberStream`], wrapping a `peanuts` XML reader over the
/// connection's read half.
pub struct JabberReader<S>(Reader<ReadHalf<S>>);

impl<S> JabberReader<S> {
    /// Wraps an already-constructed XML reader.
    // TODO: consider taking a readhalf and creating peanuts::Reader here, only one inner
    fn new(reader: Reader<ReadHalf<S>>) -> Self {
        JabberReader(reader)
    }

    /// Recombines this reader with `writer` into a full [`JabberStream`].
    fn unsplit(self, writer: JabberWriter<S>) -> JabberStream<S> {
        let reader = self;
        JabberStream { reader, writer }
    }

    /// Unwraps and returns the inner XML reader.
    fn into_inner(self) -> Reader<ReadHalf<S>> {
        let JabberReader(inner) = self;
        inner
    }
}

impl<S> JabberReader<S>
where
    S: AsyncRead + Unpin,
{
    pub async fn try_close(&mut self) -> Result<()> {
        self.read_end_tag().await?;
        Ok(())
    }
}

impl<S> std::ops::Deref for JabberReader<S> {
    type Target = Reader<ReadHalf<S>>;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

impl<S> std::ops::DerefMut for JabberReader<S> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.0
    }
}

/// Write half of a [`JabberStream`], wrapping a `peanuts` XML writer over the
/// connection's write half.
pub struct JabberWriter<S>(Writer<WriteHalf<S>>);

impl<S> JabberWriter<S> {
    /// Wraps an already-constructed XML writer.
    fn new(writer: Writer<WriteHalf<S>>) -> Self {
        JabberWriter(writer)
    }

    /// Recombines this writer with `reader` into a full [`JabberStream`].
    fn unsplit(self, reader: JabberReader<S>) -> JabberStream<S> {
        let writer = self;
        JabberStream { reader, writer }
    }

    /// Unwraps and returns the inner XML writer.
    fn into_inner(self) -> Writer<WriteHalf<S>> {
        let JabberWriter(inner) = self;
        inner
    }
}

impl<S> JabberWriter<S>
where
    S: AsyncWrite + Unpin + Send,
{
    pub async fn try_close(&mut self) -> Result<()> {
        self.write_end().await?;
        Ok(())
    }
}

impl<S> std::ops::Deref for JabberWriter<S> {
    type Target = Writer<WriteHalf<S>>;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

impl<S> std::ops::DerefMut for JabberWriter<S> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.0
    }
}

impl<S> JabberStream<S>
where
    S: AsyncRead + AsyncWrite + Unpin + Send + std::fmt::Debug,
    JabberStream<S>: std::fmt::Debug,
{
    #[instrument]
    pub async fn sasl(mut self, mechanisms: Mechanisms, sasl_config: Arc<SASLConfig>) -> Result<S> {
        let sasl = SASLClient::new(sasl_config);
        let mut offered_mechs: Vec<&Mechname> = Vec::new();
        for mechanism in &mechanisms.mechanisms {
            offered_mechs
                .push(Mechname::parse(mechanism.as_bytes()).map_err(|e| Error::SASL(e.into()))?)
        }
        debug!("{:?}", offered_mechs);
        let mut session = sasl
            .start_suggested(&offered_mechs)
            .map_err(|e| Error::SASL(e.into()))?;
        let selected_mechanism = session.get_mechname().as_str().to_owned();
        debug!("selected mech: {:?}", selected_mechanism);
        let mut data: Option<Vec<u8>>;

        if !session.are_we_first() {
            // if not first mention the mechanism then get challenge data
            // mention mechanism
            let auth = Auth {
                mechanism: selected_mechanism,
                sasl_data: "=".to_string(),
            };
            self.writer.write_full(&auth).await?;
            // get challenge data
            let challenge: Challenge = self.reader.read().await?;
            debug!("challenge: {:?}", challenge);
            data = Some((*challenge).as_bytes().to_vec());
            debug!("we didn't go first");
        } else {
            // if first, mention mechanism and send data
            let mut sasl_data = Vec::new();
            session.step64(None, &mut sasl_data).unwrap();
            let auth = Auth {
                mechanism: selected_mechanism,
                sasl_data: str::from_utf8(&sasl_data)?.to_string(),
            };
            debug!("{:?}", auth);
            self.writer.write_full(&auth).await?;

            let server_response: ServerResponse = self.reader.read().await?;
            debug!("server_response: {:#?}", server_response);
            match server_response {
                ServerResponse::Challenge(challenge) => {
                    data = Some((*challenge).as_bytes().to_vec())
                }
                ServerResponse::Success(success) => {
                    data = success.clone().map(|success| success.as_bytes().to_vec())
                }
                ServerResponse::Failure(failure) => return Err(Error::SASL(failure.into())),
            }
            debug!("we went first");
        }

        // stepping the authentication exchange to completion
        if data != None {
            debug!("data: {:?}", data);
            let mut sasl_data = Vec::new();
            while {
                // decide if need to send more data over
                let state = session
                    .step64(data.as_deref(), &mut sasl_data)
                    .expect("step errored!");
                state.is_running()
            } {
                // While we aren't finished, receive more data from the other party
                let response = Response::new(str::from_utf8(&sasl_data)?.to_string());
                debug!("response: {:?}", response);
                self.writer.write_full(&response).await?;
                debug!("response written");

                let server_response: ServerResponse = self.reader.read().await?;
                debug!("server_response: {:#?}", server_response);
                match server_response {
                    ServerResponse::Challenge(challenge) => {
                        data = Some((*challenge).as_bytes().to_vec())
                    }
                    ServerResponse::Success(success) => {
                        data = success.clone().map(|success| success.as_bytes().to_vec())
                    }
                    ServerResponse::Failure(failure) => return Err(Error::SASL(failure.into())),
                }
            }
        }
        let writer = self.writer.into_inner().into_inner();
        let reader = self.reader.into_inner().into_inner();
        let stream = reader.unsplit(writer);
        Ok(stream)
    }

    #[instrument]
    pub async fn bind(mut self, jid: &mut JID) -> Result<Self> {
        let iq_id = nanoid::nanoid!();
        if let Some(resource) = &jid.resourcepart {
            let iq = Iq {
                from: None,
                id: iq_id.clone(),
                to: None,
                r#type: IqType::Set,
                lang: None,
                query: Some(Query::Bind(Bind {
                    r#type: Some(BindType::Resource(ResourceType(resource.to_string()))),
                })),
                errors: Vec::new(),
            };
            self.writer.write_full(&iq).await?;
            let result: Iq = self.reader.read().await?;
            match result {
                Iq {
                    from: _,
                    id,
                    to: _,
                    r#type: IqType::Result,
                    lang: _,
                    query:
                        Some(Query::Bind(Bind {
                            r#type: Some(BindType::Jid(FullJidType(new_jid))),
                        })),
                    errors: _,
                } if id == iq_id => {
                    *jid = new_jid;
                    return Ok(self);
                }
                Iq {
                    from: _,
                    id,
                    to: _,
                    r#type: IqType::Error,
                    lang: _,
                    query: None,
                    errors,
                } if id == iq_id => {
                    return Err(Error::ClientError(
                        errors.first().ok_or(Error::MissingError)?.clone(),
                    ))
                }
                _ => return Err(Error::UnexpectedElement(result.into_element())),
            }
        } else {
            let iq = Iq {
                from: None,
                id: iq_id.clone(),
                to: None,
                r#type: IqType::Set,
                lang: None,
                query: Some(Query::Bind(Bind { r#type: None })),
                errors: Vec::new(),
            };
            self.writer.write_full(&iq).await?;
            let result: Iq = self.reader.read().await?;
            match result {
                Iq {
                    from: _,
                    id,
                    to: _,
                    r#type: IqType::Result,
                    lang: _,
                    query:
                        Some(Query::Bind(Bind {
                            r#type: Some(BindType::Jid(FullJidType(new_jid))),
                        })),
                    errors: _,
                } if id == iq_id => {
                    *jid = new_jid;
                    return Ok(self);
                }
                Iq {
                    from: _,
                    id,
                    to: _,
                    r#type: IqType::Error,
                    lang: _,
                    query: None,
                    errors,
                } if id == iq_id => {
                    return Err(Error::ClientError(
                        errors.first().ok_or(Error::MissingError)?.clone(),
                    ))
                }
                _ => return Err(Error::UnexpectedElement(result.into_element())),
            }
        }
    }

    #[instrument]
    pub async fn start_stream(connection: S, server: &mut String) -> Result<Self> {
        // client to server
        let (reader, writer) = tokio::io::split(connection);
        let mut reader = JabberReader::new(Reader::new(reader));
        let mut writer = JabberWriter::new(Writer::new(writer));

        // declaration
        writer.write_declaration(XML_VERSION).await?;

        // opening stream element
        let stream = Stream::new_client(
            None,
            JID::from_str(server.as_ref())?,
            None,
            "en".to_string(),
        );
        writer.write_start(&stream).await?;

        // server to client

        // may or may not send a declaration
        let _decl = reader.read_prolog().await?;

        // receive stream element and validate
        let stream: Stream = reader.read_start().await?;
        debug!("got stream: {:?}", stream);
        if let Some(from) = stream.from {
            *server = from.to_string();
        }

        Ok(Self { reader, writer })
    }

    #[instrument]
    pub async fn get_features(mut self) -> Result<(Features, Self)> {
        debug!("getting features");
        let features: Features = self.reader.read().await?;
        debug!("got features: {:?}", features);
        Ok((features, self))
    }

    pub fn into_inner(self) -> S {
        self.reader
            .into_inner()
            .into_inner()
            .unsplit(self.writer.into_inner().into_inner())
    }

    pub async fn send_stanza(&mut self, stanza: &Stanza) -> Result<()> {
        self.writer.write(stanza).await?;
        Ok(())
    }
}

impl JabberStream<Unencrypted> {
    #[instrument]
    pub async fn starttls(mut self, domain: impl AsRef<str> + std::fmt::Debug) -> Result<Tls> {
        self.writer
            .write_full(&StartTls { required: false })
            .await?;
        let proceed: Proceed = self.reader.read().await?;
        debug!("got proceed: {:?}", proceed);
        let connector = TlsConnector::new().unwrap();
        let stream = self
            .reader
            .into_inner()
            .into_inner()
            .unsplit(self.writer.into_inner().into_inner());
        if let Ok(tls_stream) = tokio_native_tls::TlsConnector::from(connector)
            .connect(domain.as_ref(), stream)
            .await
        {
            return Ok(tls_stream);
        } else {
            return Err(Error::Connection);
        }
    }
}

/// Debug output names only the transport kind; the reader/writer halves are omitted.
impl std::fmt::Debug for JabberStream<Tls> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("Jabber");
        builder.field("connection", &"tls");
        builder.finish()
    }
}

/// Debug output names only the transport kind; the reader/writer halves are omitted.
impl std::fmt::Debug for JabberStream<Unencrypted> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("Jabber");
        builder.field("connection", &"unencrypted");
        builder.finish()
    }
}

// Networked tests. Their bodies are commented out: they connect to a live server
// ("blos.sm") and appear to target an older API (e.g. a zero-argument
// `start_stream`), so each test currently passes trivially.
#[cfg(test)]
mod tests {
    use test_log::test;

    /// Placeholder: opening a stream against a live server.
    #[test(tokio::test)]
    async fn start_stream() {
        // let connection = Connection::connect("blos.sm", None, None).await.unwrap();
        // match connection {
        //     Connection::Encrypted(mut c) => c.start_stream().await.unwrap(),
        //     Connection::Unencrypted(mut c) => c.start_stream().await.unwrap(),
        // }
    }

    /// Placeholder: SASL authentication against a live server.
    #[test(tokio::test)]
    async fn sasl() {
        // let mut jabber = Connection::connect_user("test@blos.sm", "slayed".to_string())
        //     .await
        //     .unwrap()
        //     .ensure_tls()
        //     .await
        //     .unwrap();
        // let text = str::from_utf8(jabber.reader.buffer.data()).unwrap();
        // println!("data: {}", text);
        // jabber.start_stream().await.unwrap();

        // let text = str::from_utf8(jabber.reader.buffer.data()).unwrap();
        // println!("data: {}", text);
        // jabber.reader.read_buf().await.unwrap();
        // let text = str::from_utf8(jabber.reader.buffer.data()).unwrap();
        // println!("data: {}", text);

        // let features = jabber.get_features().await.unwrap();
        // let (sasl_config, feature) = (
        //     jabber.auth.clone().unwrap(),
        //     features
        //         .features
        //         .iter()
        //         .find(|feature| matches!(feature, Feature::Sasl(_)))
        //         .unwrap(),
        // );
        // match feature {
        //     Feature::StartTls(_start_tls) => todo!(),
        //     Feature::Sasl(mechanisms) => {
        //         jabber.sasl(mechanisms.clone(), sasl_config).await.unwrap();
        //     }
        //     Feature::Bind => todo!(),
        //     Feature::Unknown => todo!(),
        // }
    }

    /// Placeholder: driving a stream as a stanza sink.
    // Uses the same `test_log` wrapper as the other tests so tracing output is captured.
    #[test(tokio::test)]
    async fn sink() {
        // let mut client = JabberClient::new("test@blos.sm", "slayed").unwrap();
        // client.connect().await.unwrap();
        // let stream = client.inner().unwrap();
        // let sink = sink::unfold(stream, |mut stream, stanza: Stanza| async move {
        //     stream.writer.write(&stanza).await?;
        //     Ok::<JabberStream<Tls>, Error>(stream)
        // });
        // todo!()
        // let _jabber = Connection::connect_user("test@blos.sm", "slayed".to_string())
        //     .await
        //     .unwrap()
        //     .ensure_tls()
        //     .await
        //     .unwrap()
        //     .negotiate()
        //     .await
        //     .unwrap();
        // sleep(Duration::from_secs(5)).await
    }
}