diff options
Diffstat (limited to '')
| -rw-r--r-- | src/client/encrypted.rs | 130 | ||||
| -rw-r--r-- | src/client/mod.rs | 11 | ||||
| -rw-r--r-- | src/client/unencrypted.rs | 8 | ||||
| -rw-r--r-- | src/error.rs | 13 | ||||
| -rw-r--r-- | src/jabber.rs | 2 | ||||
| -rw-r--r-- | src/stanza/mod.rs | 74 | ||||
| -rw-r--r-- | src/stanza/sasl.rs | 163 | ||||
| -rw-r--r-- | src/stanza/stream.rs | 20 | 
8 files changed, 353 insertions, 68 deletions
| diff --git a/src/client/encrypted.rs b/src/client/encrypted.rs index 898dc23..e8b7271 100644 --- a/src/client/encrypted.rs +++ b/src/client/encrypted.rs @@ -1,13 +1,23 @@ +use std::str; +  use quick_xml::{      events::{BytesDecl, Event}, +    name::QName,      Reader, Writer,  }; +use rsasl::prelude::{Mechname, SASLClient};  use tokio::io::{BufReader, ReadHalf, WriteHalf};  use tokio::net::TcpStream;  use tokio_native_tls::TlsStream; -use crate::stanza::stream::{Stream, StreamFeature}; -use crate::stanza::Element; +use crate::stanza::{ +    sasl::{Auth, Response}, +    stream::{Stream, StreamFeature}, +}; +use crate::stanza::{ +    sasl::{Challenge, Success}, +    Element, +};  use crate::Jabber;  use crate::Result; @@ -48,27 +58,111 @@ impl<'j> JabberClient<'j> {          Ok(())      } -    pub async fn get_features(&mut self) -> Result<Option<Vec<StreamFeature>>> { -        if let Some(features) = Element::read(&mut self.reader).await? { -            Ok(Some(features.try_into()?)) -        } else { -            Ok(None) -        } +    pub async fn get_features(&mut self) -> Result<Vec<StreamFeature>> { +        Element::read(&mut self.reader).await?.try_into()      }      pub async fn negotiate(&mut self) -> Result<()> {          loop {              println!("loop"); -            let features = &self.get_features().await?; -            println!("{:?}", features); -            // match &features[0] { -            //     StreamFeature::Sasl(sasl) => { -            //         println!("{:?}", sasl); -            //         todo!() -            //     } -            //     StreamFeature::Bind => todo!(), -            //     x => println!("{:?}", x), -            // } +            let features = self.get_features().await?; +            println!("features: {:?}", features); +            match &features[0] { +                StreamFeature::Sasl(sasl) => { +                    println!("sasl?"); +                    self.sasl(&sasl).await?; +                } +                StreamFeature::Bind => todo!(), +                x => println!("{:?}", x), +            } +        } +    } + +    pub async fn sasl(&mut self, mechanisms: &Vec<String>) -> Result<()> { +        println!("{:?}", mechanisms); +        let sasl = SASLClient::new(self.jabber.auth.clone()); +        let mut offered_mechs: Vec<&Mechname> = Vec::new(); +        for mechanism in mechanisms { +            offered_mechs.push(Mechname::parse(mechanism.as_bytes())?)          } +        println!("{:?}", offered_mechs); +        let mut session = sasl.start_suggested(&offered_mechs)?; +        let selected_mechanism = session.get_mechname().as_str().to_owned(); +        println!("selected mech: {:?}", selected_mechanism); +        let mut data: Option<Vec<u8>> = None; +        if !session.are_we_first() { +            // if not first mention the mechanism then get challenge data +            // mention mechanism +            let auth = Auth { +                mechanism: selected_mechanism.as_str(), +                sasl_data: "=", +            }; +            Into::<Element>::into(auth).write(&mut self.writer).await?; +            // get challenge data +            let challenge = &Element::read(&mut self.reader).await?; +            let challenge: Challenge = challenge.try_into()?; +            println!("challenge: {:?}", challenge); +            data = Some(challenge.sasl_data.to_owned()); +            println!("we didn't go first"); +        } else { +            // if first, mention mechanism and send data +            let mut sasl_data = Vec::new(); +            session.step64(None, &mut sasl_data).unwrap(); +            let auth = Auth { +                mechanism: selected_mechanism.as_str(), +                sasl_data: str::from_utf8(&sasl_data)?, +            }; +            println!("{:?}", auth); +            Into::<Element>::into(auth).write(&mut self.writer).await?; + +            let server_response = Element::read(&mut self.reader).await?; +            println!("server_response: {:#?}", server_response); +            match TryInto::<Challenge>::try_into(&server_response) { +                Ok(challenge) => data = Some(challenge.sasl_data.to_owned()), +                Err(_) => { +                    let success = TryInto::<Success>::try_into(&server_response)?; +                    if let Some(sasl_data) = success.sasl_data { +                        data = Some(sasl_data.to_owned()) +                    } +                } +            } +            println!("we went first"); +        } + +        // stepping the authentication exchange to completion +        if data != None { +            println!("data: {:?}", data); +            let mut sasl_data = Vec::new(); +            while { +                // decide if need to send more data over +                let state = session +                    .step64(data.as_deref(), &mut sasl_data) +                    .expect("step errored!"); +                state.is_running() +            } { +                // While we aren't finished, receive more data from the other party +                let response = Response { +                    sasl_data: str::from_utf8(&sasl_data)?, +                }; +                println!("response: {:?}", response); +                Into::<Element>::into(response) +                    .write(&mut self.writer) +                    .await?; + +                let server_response = Element::read(&mut self.reader).await?; +                println!("server_response: {:?}", server_response); +                match TryInto::<Challenge>::try_into(&server_response) { +                    Ok(challenge) => data = Some(challenge.sasl_data.to_owned()), +                    Err(_) => { +                        let success = TryInto::<Success>::try_into(&server_response)?; +                        if let Some(sasl_data) = success.sasl_data { +                            data = Some(sasl_data.to_owned()) +                        } +                    } +                } +            } +        } +        self.start_stream().await?; +        Ok(())      }  } diff --git a/src/client/mod.rs b/src/client/mod.rs index d545923..280e0a1 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -17,14 +17,11 @@ impl<'j> JabberClientType<'j> {          match self {              Self::Encrypted(c) => Ok(c),              Self::Unencrypted(mut c) => { -                if let Some(features) = c.get_features().await? { -                    if features.contains(&StreamFeature::StartTls) { -                        Ok(c.starttls().await?) -                    } else { -                        Err(JabberError::StartTlsUnavailable) -                    } +                let features = c.get_features().await?; +                if features.contains(&StreamFeature::StartTls) { +                    Ok(c.starttls().await?)                  } else { -                    Err(JabberError::NoFeatures) +                    Err(JabberError::StartTlsUnavailable)                  }              }          } diff --git a/src/client/unencrypted.rs b/src/client/unencrypted.rs index dcd10c6..27b0a5f 100644 --- a/src/client/unencrypted.rs +++ b/src/client/unencrypted.rs @@ -50,12 +50,8 @@ impl<'j> JabberClient<'j> {          Ok(())      } -    pub async fn get_features(&mut self) -> Result<Option<Vec<StreamFeature>>> { -        if let Some(features) = Element::read(&mut self.reader).await? { -            Ok(Some(features.try_into()?)) -        } else { -            Ok(None) -        } +    pub async fn get_features(&mut self) -> Result<Vec<StreamFeature>> { +        Element::read(&mut self.reader).await?.try_into()      }      pub async fn starttls(mut self) -> Result<super::encrypted::JabberClient<'j>> { diff --git a/src/error.rs b/src/error.rs index 7f704e5..17bfbef 100644 --- a/src/error.rs +++ b/src/error.rs @@ -18,6 +18,7 @@ pub enum JabberError {      NoFeatures,      UnknownNamespace,      ParseError, +    UnexpectedEnd,      XML(quick_xml::Error),      SASL(SASLError),      Element(ElementError<'static>), @@ -28,6 +29,8 @@ pub enum JabberError {  pub enum SASLError {      SASL(rsasl::prelude::SASLError),      MechanismName(MechanismNameError), +    NoChallenge, +    NoSuccess,  }  impl From<rsasl::prelude::SASLError> for JabberError { @@ -37,8 +40,14 @@ impl From<rsasl::prelude::SASLError> for JabberError {  }  impl From<MechanismNameError> for JabberError { -    fn from(value: MechanismNameError) -> Self { -        Self::SASL(SASLError::MechanismName(value)) +    fn from(e: MechanismNameError) -> Self { +        Self::SASL(SASLError::MechanismName(e)) +    } +} + +impl From<SASLError> for JabberError { +    fn from(e: SASLError) -> Self { +        Self::SASL(e)      }  } diff --git a/src/jabber.rs b/src/jabber.rs index a48751c..1a7eddb 100644 --- a/src/jabber.rs +++ b/src/jabber.rs @@ -24,7 +24,7 @@ pub struct Jabber<'j> {  impl<'j> Jabber<'j> {      pub fn new(jid: JID, password: String) -> Result<Self> {          let server = jid.domainpart.clone(); -        let auth = SASLConfig::with_credentials(None, jid.as_bare().to_string(), password)?; +        let auth = SASLConfig::with_credentials(None, jid.localpart.clone().unwrap(), password)?;          println!("auth: {:?}", auth);          Ok(Self {              jid, diff --git a/src/stanza/mod.rs b/src/stanza/mod.rs index 16f3bdd..c29b1a2 100644 --- a/src/stanza/mod.rs +++ b/src/stanza/mod.rs @@ -9,12 +9,12 @@ use quick_xml::events::Event;  use quick_xml::{Reader, Writer};  use tokio::io::{AsyncBufRead, AsyncWrite}; -use crate::Result; +use crate::JabberError; -#[derive(Debug)] +#[derive(Clone, Debug)]  pub struct Element<'e> {      pub event: Event<'e>, -    pub content: Option<Vec<Element<'e>>>, +    pub children: Option<Vec<Element<'e>>>,  }  impl<'e: 'async_recursion, 'async_recursion> Element<'e> { @@ -23,7 +23,7 @@ impl<'e: 'async_recursion, 'async_recursion> Element<'e> {          writer: &'life0 mut Writer<W>,      ) -> ::core::pin::Pin<          Box< -            dyn ::core::future::Future<Output = Result<()>> +            dyn ::core::future::Future<Output = Result<(), JabberError>>                  + 'async_recursion                  + ::core::marker::Send,          >, @@ -36,9 +36,9 @@ impl<'e: 'async_recursion, 'async_recursion> Element<'e> {              match &self.event {                  Event::Start(e) => {                      writer.write_event_async(Event::Start(e.clone())).await?; -                    if let Some(content) = &self.content { -                        for _e in content { -                            self.write(writer).await?; +                    if let Some(children) = &self.children { +                        for e in children { +                            e.write(writer).await?;                          }                      }                      writer.write_event_async(Event::End(e.to_end())).await?; @@ -54,7 +54,7 @@ impl<'e> Element<'e> {      pub async fn write_start<W: AsyncWrite + Unpin + Send>(          &self,          writer: &mut Writer<W>, -    ) -> Result<()> { +    ) -> Result<(), JabberError> {          match self.event.as_ref() {              Event::Start(e) => Ok(writer.write_event_async(Event::Start(e.clone())).await?),              e => Err(ElementError::NotAStart(e.clone().into_owned()).into()), @@ -64,7 +64,7 @@ impl<'e> Element<'e> {      pub async fn write_end<W: AsyncWrite + Unpin + Send>(          &self,          writer: &mut Writer<W>, -    ) -> Result<()> { +    ) -> Result<(), JabberError> {          match self.event.as_ref() {              Event::Start(e) => Ok(writer                  .write_event_async(Event::End(e.clone().to_end())) @@ -76,28 +76,38 @@ impl<'e> Element<'e> {      #[async_recursion]      pub async fn read<R: AsyncBufRead + Unpin + Send>(          reader: &mut Reader<R>, -    ) -> Result<Option<Self>> { +    ) -> Result<Self, JabberError> { +        let element = Self::read_recursive(reader) +            .await? +            .ok_or(JabberError::UnexpectedEnd); +        element +    } + +    #[async_recursion] +    async fn read_recursive<R: AsyncBufRead + Unpin + Send>( +        reader: &mut Reader<R>, +    ) -> Result<Option<Self>, JabberError> {          let mut buf = Vec::new();          let event = reader.read_event_into_async(&mut buf).await?;          match event {              Event::Start(e) => { -                let mut content_vec = Vec::new(); -                while let Some(sub_element) = Element::read(reader).await? { -                    content_vec.push(sub_element) +                let mut children_vec = Vec::new(); +                while let Some(sub_element) = Element::read_recursive(reader).await? { +                    children_vec.push(sub_element)                  } -                let mut content = None; -                if !content_vec.is_empty() { -                    content = Some(content_vec) +                let mut children = None; +                if !children_vec.is_empty() { +                    children = Some(children_vec)                  }                  Ok(Some(Self {                      event: Event::Start(e.into_owned()), -                    content, +                    children,                  }))              }              Event::End(_) => Ok(None),              e => Ok(Some(Self {                  event: e.into_owned(), -                content: None, +                children: None,              })),          }      } @@ -105,14 +115,14 @@ impl<'e> Element<'e> {      #[async_recursion]      pub async fn read_start<R: AsyncBufRead + Unpin + Send>(          reader: &mut Reader<R>, -    ) -> Result<Self> { +    ) -> Result<Self, JabberError> {          let mut buf = Vec::new();          let event = reader.read_event_into_async(&mut buf).await?;          match event {              Event::Start(e) => {                  return Ok(Self {                      event: Event::Start(e.into_owned()), -                    content: None, +                    children: None,                  })              }              e => Err(ElementError::NotAStart(e.into_owned()).into()), @@ -120,7 +130,31 @@ impl<'e> Element<'e> {      }  } +/// if there is only one child in the vec of children, will return that element +pub fn child<'p, 'e>(element: &'p Element<'e>) -> Result<&'p Element<'e>, ElementError<'static>> { +    if let Some(children) = &element.children { +        if children.len() == 1 { +            return Ok(&children[0]); +        } else { +            return Err(ElementError::MultipleChildren); +        } +    } +    Err(ElementError::NoChildren) +} + +/// returns reference to children +pub fn children<'p, 'e>( +    element: &'p Element<'e>, +) -> Result<&'p Vec<Element<'e>>, ElementError<'e>> { +    if let Some(children) = &element.children { +        return Ok(children); +    } +    Err(ElementError::NoChildren) +} +  #[derive(Debug)]  pub enum ElementError<'e> {      NotAStart(Event<'e>), +    NoChildren, +    MultipleChildren,  } diff --git a/src/stanza/sasl.rs b/src/stanza/sasl.rs index 1f77ffa..bbf3f41 100644 --- a/src/stanza/sasl.rs +++ b/src/stanza/sasl.rs @@ -1,8 +1,163 @@ -pub struct Auth { -    pub mechanism: String, -    pub sasl_data: Option<String>, +use quick_xml::{ +    events::{BytesStart, BytesText, Event}, +    name::QName, +}; + +use crate::error::SASLError; +use crate::JabberError; + +use super::Element; + +const XMLNS: &str = "urn:ietf:params:xml:ns:xmpp-sasl"; + +#[derive(Debug)] +pub struct Auth<'e> { +    pub mechanism: &'e str, +    pub sasl_data: &'e str, +} + +impl<'e> Auth<'e> { +    fn event(&self) -> Event<'e> { +        let mut start = BytesStart::new("auth"); +        start.push_attribute(("xmlns", XMLNS)); +        start.push_attribute(("mechanism", self.mechanism)); +        Event::Start(start) +    } + +    fn children(&self) -> Option<Vec<Element<'e>>> { +        let sasl = BytesText::from_escaped(self.sasl_data); +        let sasl = Element { +            event: Event::Text(sasl), +            children: None, +        }; +        Some(vec![sasl]) +    }  } +impl<'e> Into<Element<'e>> for Auth<'e> { +    fn into(self) -> Element<'e> { +        Element { +            event: self.event(), +            children: self.children(), +        } +    } +} + +#[derive(Debug)]  pub struct Challenge { -    pub sasl_data: String, +    pub sasl_data: Vec<u8>, +} + +impl<'e> TryFrom<&Element<'e>> for Challenge { +    type Error = JabberError; + +    fn try_from(element: &Element<'e>) -> Result<Challenge, Self::Error> { +        if let Event::Start(start) = &element.event { +            if start.name() == QName(b"challenge") { +                let sasl_data: &Element<'_> = super::child(element)?; +                if let Event::Text(sasl_data) = &sasl_data.event { +                    let s = sasl_data.clone(); +                    let s = s.into_inner(); +                    let s = s.to_vec(); +                    return Ok(Challenge { sasl_data: s }); +                } +            } +        } +        Err(SASLError::NoChallenge.into()) +    } +} + +// impl<'e> TryFrom<Element<'e>> for Challenge { +//     type Error = JabberError; + +//     fn try_from(element: Element<'e>) -> Result<Challenge, Self::Error> { +//         if let Event::Start(start) = &element.event { +//             if start.name() == QName(b"challenge") { +//                 println!("one"); +//                 if let Some(children) = element.children.as_deref() { +//                     if children.len() == 1 { +//                         let sasl_data = children.first().unwrap(); +//                         if let Event::Text(sasl_data) = &sasl_data.event { +//                             return Ok(Challenge { +//                                 sasl_data: sasl_data.clone().into_inner().to_vec(), +//                             }); +//                         } else { +//                             return Err(SASLError::NoChallenge.into()); +//                         } +//                     } else { +//                         return Err(SASLError::NoChallenge.into()); +//                     } +//                 } else { +//                     return Err(SASLError::NoChallenge.into()); +//                 } +//             } +//         } +//         Err(SASLError::NoChallenge.into()) +//     } +// } + +#[derive(Debug)] +pub struct Response<'e> { +    pub sasl_data: &'e str, +} + +impl<'e> Response<'e> { +    fn event(&self) -> Event<'e> { +        let mut start = BytesStart::new("response"); +        start.push_attribute(("xmlns", XMLNS)); +        Event::Start(start) +    } + +    fn children(&self) -> Option<Vec<Element<'e>>> { +        let sasl = BytesText::from_escaped(self.sasl_data); +        let sasl = Element { +            event: Event::Text(sasl), +            children: None, +        }; +        Some(vec![sasl]) +    } +} + +impl<'e> Into<Element<'e>> for Response<'e> { +    fn into(self) -> Element<'e> { +        Element { +            event: self.event(), +            children: self.children(), +        } +    } +} + +#[derive(Debug)] +pub struct Success { +    pub sasl_data: Option<Vec<u8>>, +} + +impl<'e> TryFrom<&Element<'e>> for Success { +    type Error = JabberError; + +    fn try_from(element: &Element<'e>) -> Result<Success, Self::Error> { +        match &element.event { +            Event::Start(start) => { +                if start.name() == QName(b"success") { +                    match super::child(element) { +                        Ok(sasl_data) => { +                            if let Event::Text(sasl_data) = &sasl_data.event { +                                return Ok(Success { +                                    sasl_data: Some(sasl_data.clone().into_inner().to_vec()), +                                }); +                            } +                        } +                        Err(_) => return Ok(Success { sasl_data: None }), +                    }; +                } +            } +            Event::Empty(empty) => { +                if empty.name() == QName(b"success") { +                    return Ok(Success { sasl_data: None }); +                } +            } +            _ => {} +        } +        Err(SASLError::NoSuccess.into()) +    }  } diff --git a/src/stanza/stream.rs b/src/stanza/stream.rs index 32f449d..66741b8 100644 --- a/src/stanza/stream.rs +++ b/src/stanza/stream.rs @@ -58,7 +58,7 @@ impl Stream {          }      } -    fn build(&self) -> BytesStart { +    fn event(&self) -> Event<'static> {          let mut start = BytesStart::new("stream:stream");          if let Some(from) = &self.from {              start.push_attribute(("from", from.to_string().as_str())); @@ -80,15 +80,15 @@ impl Stream {              XMLNS::Server => start.push_attribute(("xmlns", XMLNS::Server.into())),          }          start.push_attribute(("xmlns:stream", XMLNS_STREAM)); -        start +        Event::Start(start)      }  }  impl<'e> Into<Element<'e>> for Stream {      fn into(self) -> Element<'e> {          Element { -            event: Event::Start(self.build().to_owned()), -            content: None, +            event: self.event(), +            children: None,          }      }  } @@ -153,17 +153,17 @@ impl<'e> TryFrom<Element<'e>> for Vec<StreamFeature> {      fn try_from(features_element: Element) -> Result<Self> {          let mut features = Vec::new(); -        if let Some(content) = features_element.content { -            for feature_element in content { +        if let Some(children) = features_element.children { +            for feature_element in children {                  match feature_element.event {                      Event::Start(e) => match e.name() {                          QName(b"starttls") => features.push(StreamFeature::StartTls),                          QName(b"mechanisms") => {                              let mut mechanisms = Vec::new(); -                            if let Some(content) = feature_element.content { -                                for mechanism_element in content { -                                    if let Some(content) = mechanism_element.content { -                                        for mechanism_text in content { +                            if let Some(children) = feature_element.children { +                                for mechanism_element in children { +                                    if let Some(children) = mechanism_element.children { +                                        for mechanism_text in children {                                              match mechanism_text.event {                                                  Event::Text(e) => mechanisms                                                      .push(str::from_utf8(e.as_ref())?.to_owned()), | 
