summaryrefslogblamecommitdiffstats
path: root/src/jabber.rs
blob: 1436bfaec0f4f65029cbab48197f61b9a73494fa (plain) (tree)
1
2
3
4
5
6
7
8
9
             
                   
 
                                                                 
                               

                                                                                      
                                  
 





                                          
 




                                             
                         


                                  

 











                                                                   


                   
                
                 
                   
         
     
 
 


                                      
 

                                                           

                                                        
 
                      

                                                                
 




                                                                          
                                                                                                  
                                               

                                                                
 
                           
 


                                                                                          
                                              



                                                                          

                                                                                                  




                                      

                                                                                              

                 
                    
         
 
                                              


                                                                             









                                                                                      


                                                       

                 
                                                   
                                             
         

     





                                                                       
                                                             





































                                                                                                





































                                                                              
use std::str;
use std::sync::Arc;

use quick_xml::{events::Event, se::Serializer, NsReader, Writer};
use rsasl::prelude::SASLConfig;
use serde::Serialize;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufReader, ReadHalf, WriteHalf};
use tracing::{debug, info, trace};

use crate::connection::{Tls, Unencrypted};
use crate::error::JabberError;
use crate::stanza::stream::Stream;
use crate::stanza::DECLARATION;
use crate::Result;
use crate::JID;

/// An XMPP connection whose transport `S` has been split into a read half
/// (parsed as namespace-aware XML) and a raw write half.
pub struct Jabber<S: AsyncRead + AsyncWrite + Unpin> {
    /// quick-xml namespace-resolving reader over the buffered read half.
    reader: NsReader<BufReader<ReadHalf<S>>>,
    /// Write half of the transport; outgoing bytes are written here directly.
    writer: WriteHalf<S>,
    /// Client JID, if one was supplied.
    jid: Option<JID>,
    /// SASL configuration, if one was supplied (presumably for authentication — TODO confirm).
    auth: Option<Arc<SASLConfig>>,
    /// Server domain. Passed to `Stream::new_client` when opening a stream and
    /// overwritten by the server's `from` attribute in `start_stream`.
    server: String,
}

impl<S: AsyncRead + AsyncWrite + Unpin> Jabber<S> {
    /// Builds a `Jabber` from the two halves of a split transport.
    ///
    /// The read half is wrapped in a tokio `BufReader` and then a quick-xml
    /// `NsReader`, so incoming XML is parsed with namespace resolution; the
    /// write half, `jid`, `auth` and `server` are stored unchanged.
    pub fn new(
        reader: ReadHalf<S>,
        writer: WriteHalf<S>,
        jid: Option<JID>,
        auth: Option<Arc<SASLConfig>>,
        server: String,
    ) -> Self {
        let buffered = BufReader::new(reader);
        Self {
            reader: NsReader::from_reader(buffered),
            writer,
            jid,
            auth,
            server,
        }
    }
}

impl<S> Jabber<S>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    // pub async fn negotiate(self) -> Result<Jabber<S>> {}

    /// Opens the client-to-server XML stream and validates the server's reply.
    ///
    /// Sends an XML declaration followed by the opening `stream:stream` element
    /// built from `self.server`. It then reads the server's answer: an optional
    /// XML declaration, followed by a start tag that must be a `stream` element
    /// in the `crate::stanza::stream::XMLNS` namespace. If that tag carries a
    /// `from` attribute, `self.server` is replaced with its value.
    ///
    /// # Errors
    ///
    /// * [`JabberError::BadStream`] is returned in two cases:
    ///   - the server declares an XML version other than `1.0`;
    ///   - its first element is not a `stream` start tag in the streams namespace.
    /// * The following errors are propagated via `?`:
    ///   - writing the declaration;
    ///   - converting `self.server` into the stream address;
    ///   - reading events;
    ///   - parsing the stream element's attributes;
    ///   - decoding a non-UTF-8 `from` value.
    ///
    /// # Panics
    ///
    /// Panics if serializing the stream header or writing it to the transport
    /// fails (see the TODO in the body).
    pub async fn start_stream(&mut self) -> Result<()> {
        // client to server

        // declaration
        let mut xmlwriter = Writer::new(&mut self.writer);
        xmlwriter.write_event_async(DECLARATION.clone()).await?;

        // opening stream element
        let server = &self.server.to_owned().try_into()?;
        let stream_element = Stream::new_client(None, server, None, "en");
        // TODO: nicer function to serialize to xml writer
        // NOTE(review): a serde serializer normally emits a complete element; confirm
        // that `Stream`'s Serialize impl leaves the stream tag open, since the stream
        // header must stay open for the whole session.
        let mut buffer = String::new();
        let ser = Serializer::with_root(&mut buffer, Some("stream:stream")).expect("stream name");
        // TODO: propagate these failures instead of panicking once `JabberError`
        // has conversions for serializer and raw socket I/O errors.
        stream_element.serialize(ser).unwrap();
        trace!("sent: {}", buffer);
        self.writer.write_all(buffer.as_bytes()).await.unwrap();

        // server to client

        // The server may or may not send an XML declaration first.
        let mut buf = Vec::new();
        let mut first_event = self.reader.read_resolved_event_into_async(&mut buf).await?;
        trace!("received: {:?}", first_event);
        if let (quick_xml::name::ResolveResult::Unbound, Event::Decl(decl)) = first_event {
            // A declaration whose version can't be read is tolerated, but an explicit
            // version other than 1.0 is rejected as an error rather than a panic,
            // because this is untrusted server input.
            if let Ok(version) = decl.version() {
                if version.as_ref() != b"1.0" {
                    return Err(JabberError::BadStream);
                }
            }
            // Each read appends to `buf`; clear it so it doesn't keep the declaration.
            buf.clear();
            first_event = self.reader.read_resolved_event_into_async(&mut buf).await?;
            trace!("received: {:?}", first_event);
        }

        // The next event must be the server's opening <stream:stream> tag.
        match first_event {
            (quick_xml::name::ResolveResult::Bound(ns), Event::Start(e))
                if ns.0 == crate::stanza::stream::XMLNS.as_bytes()
                    && e.local_name().into_inner() == b"stream" =>
            {
                // Adopt the server's self-reported domain if it sent one. Attribute
                // and UTF-8 errors are returned to the caller instead of dropped.
                e.attributes().try_for_each(|attr| -> Result<()> {
                    let attr = attr?;
                    if attr.key.into_inner() == b"from" {
                        self.server = str::from_utf8(&attr.value)?.to_owned();
                    }
                    Ok(())
                })
            }
            // TODO: distinct errors for a wrong namespace vs. an unexpected element
            _ => Err(JabberError::BadStream),
        }
    }
}

// pub async fn get_features(&mut self) -> Result<Vec<StreamFeature>> {
//     Element::read(&mut self.reader).await?.try_into()
// }

impl Jabber<Unencrypted> {
    /// Intended to upgrade this plaintext connection to a `Jabber<Tls>`.
    ///
    /// # Panics
    ///
    /// Not implemented yet: always panics via `todo!()`.
    pub async fn starttls(&mut self) -> Result<Jabber<Tls>> {
        todo!()
    }
    // Earlier draft of `starttls`, kept for reference.
    // NOTE(review): the draft refers to `self.jabber` and
    // `super::encrypted::JabberClient`, which do not match the current `Jabber`
    // struct. It also calls `into_inner()` on `self.reader`/`self.writer`, which
    // cannot be done through `&mut self`. It needs reworking before it can be
    // re-enabled.
    //     let mut starttls_element = BytesStart::new("starttls");
    //     starttls_element.push_attribute(("xmlns", "urn:ietf:params:xml:ns:xmpp-tls"));
    //     self.writer
    //         .write_event_async(Event::Empty(starttls_element))
    //         .await
    //         .unwrap();
    //     let mut buf = Vec::new();
    //     match self.reader.read_event_into_async(&mut buf).await.unwrap() {
    //         Event::Empty(e) => match e.name() {
    //             QName(b"proceed") => {
    //                 let connector = TlsConnector::new().unwrap();
    //                 let stream = self
    //                     .reader
    //                     .into_inner()
    //                     .into_inner()
    //                     .unsplit(self.writer.into_inner());
    //                 if let Ok(tlsstream) = tokio_native_tls::TlsConnector::from(connector)
    //                     .connect(&self.jabber.server, stream)
    //                     .await
    //                 {
    //                     let (read, write) = tokio::io::split(tlsstream);
    //                     let reader = Reader::from_reader(BufReader::new(read));
    //                     let writer = Writer::new(write);
    //                     let mut client =
    //                         super::encrypted::JabberClient::new(reader, writer, self.jabber);
    //                     client.start_stream().await?;
    //                     return Ok(client);
    //                 }
    //             }
    //             QName(_) => return Err(JabberError::TlsNegotiation),
    //         },
    //         _ => return Err(JabberError::TlsNegotiation),
    //     }
    //     Err(JabberError::TlsNegotiation)
    // }
}

impl std::fmt::Debug for Jabber<Tls> {
    /// Formats the connection metadata, labelling the transport as `"tls"`.
    ///
    /// The reader and writer halves are omitted from the output.
    /// NOTE(review): `auth` is printed through its own `Debug` impl; confirm that
    /// impl does not expose credentials in logs.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("Jabber");
        out.field("connection", &"tls");
        out.field("jid", &self.jid);
        out.field("auth", &self.auth);
        out.field("server", &self.server);
        out.finish()
    }
}

impl std::fmt::Debug for Jabber<Unencrypted> {
    /// Formats the connection metadata, labelling the transport as `"unencrypted"`.
    ///
    /// The reader and writer halves are omitted from the output.
    /// NOTE(review): `auth` is printed through its own `Debug` impl; confirm that
    /// impl does not expose credentials in logs.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("Jabber");
        out.field("connection", &"unencrypted");
        out.field("jid", &self.jid);
        out.field("auth", &self.auth);
        out.field("server", &self.server);
        out.finish()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::connection::Connection;
    use test_log::test;

    /// Live check: connects to `blos.sm` and opens a stream over whichever
    /// transport `Connection::connect` returned. Requires network access.
    #[test(tokio::test)]
    async fn start_stream() {
        let outcome = match Connection::connect("blos.sm").await.unwrap() {
            Connection::Encrypted(mut tls) => tls.start_stream().await,
            Connection::Unencrypted(mut plain) => plain.start_stream().await,
        };
        outcome.unwrap();
    }
}