// path: root/src/client/encrypted.rs
// blob: a4bf0d1a601492f0eb29fbea0b4e45bc85ee1964

             
                
                     
                                           

                   

                   


                                                               


                                



                                            




                                                              
                                            





                                                                  
                                                

















                                                                                                   


                                                                     











                                                                              

































































































































































                                                                                  
 
use std::str;

use quick_xml::{
    de::Deserializer,
    events::{BytesDecl, BytesStart, Event},
    name::QName,
    se::Serializer,
    Reader, Writer,
};
use rsasl::prelude::{Mechname, SASLClient};
use serde::{Deserialize, Serialize};
use tokio::io::{AsyncWriteExt, BufReader, ReadHalf, WriteHalf};
use tokio::net::TcpStream;
use tokio_native_tls::TlsStream;

use crate::stanza::{
    sasl::{Auth, Challenge, Mechanisms},
    stream::{StreamFeature, StreamFeatures},
};
use crate::Jabber;
use crate::Result;

/// An XMPP client session running over an already-established TLS connection.
///
/// Holds the read and write halves of a TLS stream (presumably split from the
/// same connection — confirm at the construction site) plus the account
/// settings used while negotiating the stream.
pub struct JabberClient<'j> {
    /// Incoming side of the TLS stream, wrapped in a buffered `quick_xml` reader.
    reader: Reader<BufReader<ReadHalf<TlsStream<TcpStream>>>>,
    /// Outgoing side of the TLS stream; serialized XML is written to it directly.
    writer: WriteHalf<TlsStream<TcpStream>>,
    /// Account/connection settings: JID, server name and SASL configuration.
    // NOTE(review): `&'j mut Jabber<'j>` ties the mutable borrow to `Jabber`'s own
    // lifetime parameter, which usually keeps the `Jabber` mutably borrowed for
    // that whole lifetime; consider a separate lifetime if callers hit borrow errors.
    jabber: &'j mut Jabber<'j>,
}

impl<'j> JabberClient<'j> {
    pub fn new(
        reader: Reader<BufReader<ReadHalf<TlsStream<TcpStream>>>>,
        writer: WriteHalf<TlsStream<TcpStream>>,
        jabber: &'j mut Jabber<'j>,
    ) -> Self {
        Self {
            reader,
            writer,
            jabber,
        }
    }

    pub async fn start_stream(&mut self) -> Result<()> {
        let declaration = BytesDecl::new("1.0", None, None);
        let mut stream_element = BytesStart::new("stream:stream");
        stream_element.push_attribute(("from".as_bytes(), self.jabber.jid.to_string().as_bytes()));
        stream_element.push_attribute(("to".as_bytes(), self.jabber.server.as_bytes()));
        stream_element.push_attribute(("version", "1.0"));
        stream_element.push_attribute(("xml:lang", "en"));
        stream_element.push_attribute(("xmlns", "jabber:client"));
        stream_element.push_attribute(("xmlns:stream", "http://etherx.jabber.org/streams"));
        let mut writer = Writer::new(&mut self.writer);
        writer.write_event_async(Event::Decl(declaration)).await;
        writer.write_event_async(Event::Start(stream_element)).await;
        let mut buf = Vec::new();
        loop {
            match self.reader.read_event_into_async(&mut buf).await.unwrap() {
                Event::Start(e) => {
                    println!("{:?}", e);
                    break;
                }
                e => println!("decl: {:?}", e),
            };
        }
        Ok(())
    }

    pub async fn get_node<'a>(&mut self) -> Result<String> {
        let mut buf = Vec::new();
        let mut txt = Vec::new();
        let mut qname_set = false;
        let mut qname: Option<Vec<u8>> = None;
        loop {
            match self.reader.read_event_into_async(&mut buf).await? {
                Event::Start(e) => {
                    if !qname_set {
                        qname = Some(e.name().into_inner().to_owned());
                        qname_set = true;
                    }
                    txt.push(b'<');
                    txt = txt
                        .into_iter()
                        .chain(buf.to_owned())
                        .chain(vec![b'>'])
                        .collect();
                }
                Event::End(e) => {
                    let mut end = false;
                    if e.name() == QName(qname.as_deref().unwrap()) {
                        end = true;
                    }
                    txt.push(b'<');
                    txt = txt
                        .into_iter()
                        .chain(buf.to_owned())
                        .chain(vec![b'>'])
                        .collect();
                    if end {
                        break;
                    }
                }
                Event::Text(_e) => {
                    txt = txt.into_iter().chain(buf.to_owned()).collect();
                }
                _ => {
                    txt.push(b'<');
                    txt = txt
                        .into_iter()
                        .chain(buf.to_owned())
                        .chain(vec![b'>'])
                        .collect();
                }
            }
            buf.clear();
        }
        println!("{:?}", txt);
        let decoded = str::from_utf8(&txt)?.to_owned();
        println!("{:?}", decoded);
        Ok(decoded)
    }

    pub async fn get_features(&mut self) -> Result<Vec<StreamFeature>> {
        let node = self.get_node().await?;
        let mut deserializer = Deserializer::from_str(&node);
        let features = StreamFeatures::deserialize(&mut deserializer).unwrap();
        println!("{:?}", features);
        Ok(features.features)
    }

    pub async fn negotiate(&mut self) -> Result<()> {
        loop {
            println!("loop");
            let features = &self.get_features().await?;
            println!("{:?}", features);
            match &features[0] {
                StreamFeature::Sasl(sasl) => {
                    println!("{:?}", sasl);
                    self.sasl(&sasl).await?;
                }
                StreamFeature::Bind => todo!(),
                x => println!("{:?}", x),
            }
        }
    }

    pub async fn sasl(&mut self, mechanisms: &Mechanisms) -> Result<()> {
        println!("{:?}", mechanisms);
        let sasl = SASLClient::new(self.jabber.auth.clone());
        let mut offered_mechs: Vec<&Mechname> = Vec::new();
        for mechanism in &mechanisms.mechanisms {
            offered_mechs.push(Mechname::parse(&mechanism.mechanism.as_bytes())?)
        }
        println!("{:?}", offered_mechs);
        let mut session = sasl.start_suggested(&offered_mechs)?;
        let selected_mechanism = session.get_mechname().as_str().to_owned();
        println!("selected mech: {:?}", selected_mechanism);
        let mut data: Option<Vec<u8>> = None;
        if !session.are_we_first() {
            // if not first mention the mechanism then get challenge data
            // mention mechanism
            let auth = Auth {
                ns: "urn:ietf:params:xml:ns:xmpp-sasl".to_owned(),
                mechanism: selected_mechanism.clone(),
                sasl_data: Some("=".to_owned()),
            };
            let mut buffer = String::new();
            let ser = Serializer::new(&mut buffer);
            auth.serialize(ser).unwrap();
            self.writer.write_all(buffer.as_bytes());
            // get challenge data
            let node = self.get_node().await?;
            let mut deserializer = Deserializer::from_str(&node);
            let challenge = Challenge::deserialize(&mut deserializer).unwrap();
            println!("challenge: {:?}", challenge);
            data = Some(challenge.sasl_data.as_bytes().to_owned());
            println!("we didn't go first");
        } else {
            // if first, mention mechanism and send data
            let mut sasl_data = Vec::new();
            session.step64(None, &mut sasl_data).unwrap();
            let auth = Auth {
                ns: "urn:ietf:params:xml:ns:xmpp-sasl".to_owned(),
                mechanism: selected_mechanism.clone(),
                sasl_data: Some(str::from_utf8(&sasl_data).unwrap().to_owned()),
            };
            let mut buffer = String::new();
            let ser = Serializer::new(&mut buffer);
            auth.serialize(ser).unwrap();
            println!("node: {:?}", buffer);
            self.writer.write_all(buffer.as_bytes()).await;
            println!("we went first");
            // get challenge data
            // TODO: check if needed
            // let node = self.get_node().await?;
            // println!("node: {:?}", node);
            // let mut deserializer = Deserializer::from_str(&node);
            // let challenge = Challenge::deserialize(&mut deserializer).unwrap();
            // println!("challenge: {:?}", challenge);
            // data = Some(challenge.sasl_data.as_bytes().to_owned());
        }

        // stepping the authentication exchange to completion
        let mut sasl_data = Vec::new();
        while {
            // decide if need to send more data over
            let state = session
                .step64(data.as_deref(), &mut sasl_data)
                .expect("step errored!");
            state.is_running()
        } {
            // While we aren't finished, receive more data from the other party
            let auth = Auth {
                ns: "urn:ietf:params:xml:ns:xmpp-sasl".to_owned(),
                mechanism: selected_mechanism.clone(),
                sasl_data: Some(str::from_utf8(&sasl_data).unwrap().to_owned()),
            };
            let mut buffer = String::new();
            let ser = Serializer::new(&mut buffer);
            auth.serialize(ser).unwrap();
            self.writer.write_all(buffer.as_bytes());
            let node = self.get_node().await?;
            let mut deserializer = Deserializer::from_str(&node);
            let challenge = Challenge::deserialize(&mut deserializer).unwrap();
            data = Some(challenge.sasl_data.as_bytes().to_owned());
        }
        self.start_stream().await?;
        Ok(())
    }
}