diff options
| author | 2023-07-12 21:11:20 +0100 | |
|---|---|---|
| committer | 2023-07-12 21:11:20 +0100 | |
| commit | 322b2a3b46348ec1c5acbc538de93310c9030b96 (patch) | |
| tree | e447920e2414c4d3d99ce021785f0fe8103d378a /src/stanza | |
| parent | c9683935f1e94a701be3e6efe0634dbc63c861de (diff) | |
| download | luz-322b2a3b46348ec1c5acbc538de93310c9030b96.tar.gz luz-322b2a3b46348ec1c5acbc538de93310c9030b96.tar.bz2 luz-322b2a3b46348ec1c5acbc538de93310c9030b96.zip | |
reimplement sasl (with SCRAM!)
Diffstat (limited to '')
| -rw-r--r-- | src/stanza/mod.rs | 74 | ||||
| -rw-r--r-- | src/stanza/sasl.rs | 163 | ||||
| -rw-r--r-- | src/stanza/stream.rs | 20 | 
3 files changed, 223 insertions, 34 deletions
| diff --git a/src/stanza/mod.rs b/src/stanza/mod.rs index 16f3bdd..c29b1a2 100644 --- a/src/stanza/mod.rs +++ b/src/stanza/mod.rs @@ -9,12 +9,12 @@ use quick_xml::events::Event;  use quick_xml::{Reader, Writer};  use tokio::io::{AsyncBufRead, AsyncWrite}; -use crate::Result; +use crate::JabberError; -#[derive(Debug)] +#[derive(Clone, Debug)]  pub struct Element<'e> {      pub event: Event<'e>, -    pub content: Option<Vec<Element<'e>>>, +    pub children: Option<Vec<Element<'e>>>,  }  impl<'e: 'async_recursion, 'async_recursion> Element<'e> { @@ -23,7 +23,7 @@ impl<'e: 'async_recursion, 'async_recursion> Element<'e> {          writer: &'life0 mut Writer<W>,      ) -> ::core::pin::Pin<          Box< -            dyn ::core::future::Future<Output = Result<()>> +            dyn ::core::future::Future<Output = Result<(), JabberError>>                  + 'async_recursion                  + ::core::marker::Send,          >, @@ -36,9 +36,9 @@ impl<'e: 'async_recursion, 'async_recursion> Element<'e> {              match &self.event {                  Event::Start(e) => {                      writer.write_event_async(Event::Start(e.clone())).await?; -                    if let Some(content) = &self.content { -                        for _e in content { -                            self.write(writer).await?; +                    if let Some(children) = &self.children { +                        for e in children { +                            e.write(writer).await?;                          }                      }                      writer.write_event_async(Event::End(e.to_end())).await?; @@ -54,7 +54,7 @@ impl<'e> Element<'e> {      pub async fn write_start<W: AsyncWrite + Unpin + Send>(          &self,          writer: &mut Writer<W>, -    ) -> Result<()> { +    ) -> Result<(), JabberError> {          match self.event.as_ref() {              Event::Start(e) => Ok(writer.write_event_async(Event::Start(e.clone())).await?),              e => Err(ElementError::NotAStart(e.clone().into_owned()).into()), @@ -64,7 +64,7 @@ impl<'e> Element<'e> {      pub async fn write_end<W: AsyncWrite + Unpin + Send>(          &self,          writer: &mut Writer<W>, -    ) -> Result<()> { +    ) -> Result<(), JabberError> {          match self.event.as_ref() {              Event::Start(e) => Ok(writer                  .write_event_async(Event::End(e.clone().to_end())) @@ -76,28 +76,38 @@ impl<'e> Element<'e> {      #[async_recursion]      pub async fn read<R: AsyncBufRead + Unpin + Send>(          reader: &mut Reader<R>, -    ) -> Result<Option<Self>> { +    ) -> Result<Self, JabberError> { +        let element = Self::read_recursive(reader) +            .await? +            .ok_or(JabberError::UnexpectedEnd); +        element +    } + +    #[async_recursion] +    async fn read_recursive<R: AsyncBufRead + Unpin + Send>( +        reader: &mut Reader<R>, +    ) -> Result<Option<Self>, JabberError> {          let mut buf = Vec::new();          let event = reader.read_event_into_async(&mut buf).await?;          match event {              Event::Start(e) => { -                let mut content_vec = Vec::new(); -                while let Some(sub_element) = Element::read(reader).await? { -                    content_vec.push(sub_element) +                let mut children_vec = Vec::new(); +                while let Some(sub_element) = Element::read_recursive(reader).await? { +                    children_vec.push(sub_element)                  } -                let mut content = None; -                if !content_vec.is_empty() { -                    content = Some(content_vec) +                let mut children = None; +                if !children_vec.is_empty() { +                    children = Some(children_vec)                  }                  Ok(Some(Self {                      event: Event::Start(e.into_owned()), -                    content, +                    children,                  }))              }              Event::End(_) => Ok(None),              e => Ok(Some(Self {                  event: e.into_owned(), -                content: None, +                children: None,              })),          }      } @@ -105,14 +115,14 @@ impl<'e> Element<'e> {      #[async_recursion]      pub async fn read_start<R: AsyncBufRead + Unpin + Send>(          reader: &mut Reader<R>, -    ) -> Result<Self> { +    ) -> Result<Self, JabberError> {          let mut buf = Vec::new();          let event = reader.read_event_into_async(&mut buf).await?;          match event {              Event::Start(e) => {                  return Ok(Self {                      event: Event::Start(e.into_owned()), -                    content: None, +                    children: None,                  })              }              e => Err(ElementError::NotAStart(e.into_owned()).into()), @@ -120,7 +130,31 @@ impl<'e> Element<'e> {      }  } +/// if there is only one child in the vec of children, will return that element +pub fn child<'p, 'e>(element: &'p Element<'e>) -> Result<&'p Element<'e>, ElementError<'static>> { +    if let Some(children) = &element.children { +        if children.len() == 1 { +            return Ok(&children[0]); +        } else { +            return Err(ElementError::MultipleChildren); +        } +    } +    Err(ElementError::NoChildren) +} + +/// returns reference to children +pub fn children<'p, 'e>( +    element: &'p Element<'e>, +) -> Result<&'p Vec<Element<'e>>, ElementError<'e>> { +    if let Some(children) = &element.children { +        return Ok(children); +    } +    Err(ElementError::NoChildren) +} +  #[derive(Debug)]  pub enum ElementError<'e> {      NotAStart(Event<'e>), +    NoChildren, +    MultipleChildren,  } diff --git a/src/stanza/sasl.rs b/src/stanza/sasl.rs index 1f77ffa..bbf3f41 100644 --- a/src/stanza/sasl.rs +++ b/src/stanza/sasl.rs @@ -1,8 +1,163 @@ -pub struct Auth { -    pub mechanism: String, -    pub sasl_data: Option<String>, +use quick_xml::{ +    events::{BytesStart, BytesText, Event}, +    name::QName, +}; + +use crate::error::SASLError; +use crate::JabberError; + +use super::Element; + +const XMLNS: &str = "urn:ietf:params:xml:ns:xmpp-sasl"; + +#[derive(Debug)] +pub struct Auth<'e> { +    pub mechanism: &'e str, +    pub sasl_data: &'e str, +} + +impl<'e> Auth<'e> { +    fn event(&self) -> Event<'e> { +        let mut start = BytesStart::new("auth"); +        start.push_attribute(("xmlns", XMLNS)); +        start.push_attribute(("mechanism", self.mechanism)); +        Event::Start(start) +    } + +    fn children(&self) -> Option<Vec<Element<'e>>> { +        let sasl = BytesText::from_escaped(self.sasl_data); +        let sasl = Element { +            event: Event::Text(sasl), +            children: None, +        }; +        Some(vec![sasl]) +    }  } +impl<'e> Into<Element<'e>> for Auth<'e> { +    fn into(self) -> Element<'e> { +        Element { +            event: self.event(), +            children: self.children(), +        } +    } +} + +#[derive(Debug)]  pub struct Challenge { -    pub sasl_data: String, +    pub sasl_data: Vec<u8>, +} + +impl<'e> TryFrom<&Element<'e>> for Challenge { +    type Error = JabberError; + +    fn try_from(element: &Element<'e>) -> Result<Challenge, Self::Error> { +        if let Event::Start(start) = &element.event { +            if start.name() == QName(b"challenge") { +                let sasl_data: &Element<'_> = super::child(element)?; +                if let Event::Text(sasl_data) = &sasl_data.event { +                    let s = sasl_data.clone(); +                    let s = s.into_inner(); +                    let s = s.to_vec(); +                    return Ok(Challenge { sasl_data: s }); +                } +            } +        } +        Err(SASLError::NoChallenge.into()) +    } +} + +// impl<'e> TryFrom<Element<'e>> for Challenge { +//     type Error = JabberError; + +//     fn try_from(element: Element<'e>) -> Result<Challenge, Self::Error> { +//         if let Event::Start(start) = &element.event { +//             if start.name() == QName(b"challenge") { +//                 println!("one"); +//                 if let Some(children) = element.children.as_deref() { +//                     if children.len() == 1 { +//                         let sasl_data = children.first().unwrap(); +//                         if let Event::Text(sasl_data) = &sasl_data.event { +//                             return Ok(Challenge { +//                                 sasl_data: sasl_data.clone().into_inner().to_vec(), +//                             }); +//                         } else { +//                             return Err(SASLError::NoChallenge.into()); +//                         } +//                     } else { +//                         return Err(SASLError::NoChallenge.into()); +//                     } +//                 } else { +//                     return Err(SASLError::NoChallenge.into()); +//                 } +//             } +//         } +//         Err(SASLError::NoChallenge.into()) +//     } +// } + +#[derive(Debug)] +pub struct Response<'e> { +    pub sasl_data: &'e str, +} + +impl<'e> Response<'e> { +    fn event(&self) -> Event<'e> { +        let mut start = BytesStart::new("response"); +        start.push_attribute(("xmlns", XMLNS)); +        Event::Start(start) +    } + +    fn children(&self) -> Option<Vec<Element<'e>>> { +        let sasl = BytesText::from_escaped(self.sasl_data); +        let sasl = Element { +            event: Event::Text(sasl), +            children: None, +        }; +        Some(vec![sasl]) +    } +} + +impl<'e> Into<Element<'e>> for Response<'e> { +    fn into(self) -> Element<'e> { +        Element { +            event: self.event(), +            children: self.children(), +        } +    } +} + +#[derive(Debug)] +pub struct Success { +    pub sasl_data: Option<Vec<u8>>, +} + +impl<'e> TryFrom<&Element<'e>> for Success { +    type Error = JabberError; + +    fn try_from(element: &Element<'e>) -> Result<Success, Self::Error> { +        match &element.event { +            Event::Start(start) => { +                if start.name() == QName(b"success") { +                    match super::child(element) { +                        Ok(sasl_data) => { +                            if let Event::Text(sasl_data) = &sasl_data.event { +                                return Ok(Success { +                                    sasl_data: Some(sasl_data.clone().into_inner().to_vec()), +                                }); +                            } +                        } +                        Err(_) => return Ok(Success { sasl_data: None }), +                    }; +                } +            } +            Event::Empty(empty) => { +                if empty.name() == QName(b"success") { +                    return Ok(Success { sasl_data: None }); +                } +            } +            _ => {} +        } +        Err(SASLError::NoSuccess.into()) +    }  } diff --git a/src/stanza/stream.rs b/src/stanza/stream.rs index 32f449d..66741b8 100644 --- a/src/stanza/stream.rs +++ b/src/stanza/stream.rs @@ -58,7 +58,7 @@ impl Stream {          }      } -    fn build(&self) -> BytesStart { +    fn event(&self) -> Event<'static> {          let mut start = BytesStart::new("stream:stream");          if let Some(from) = &self.from {              start.push_attribute(("from", from.to_string().as_str())); @@ -80,15 +80,15 @@ impl Stream {              XMLNS::Server => start.push_attribute(("xmlns", XMLNS::Server.into())),          }          start.push_attribute(("xmlns:stream", XMLNS_STREAM)); -        start +        Event::Start(start)      }  }  impl<'e> Into<Element<'e>> for Stream {      fn into(self) -> Element<'e> {          Element { -            event: Event::Start(self.build().to_owned()), -            content: None, +            event: self.event(), +            children: None,          }      }  } @@ -153,17 +153,17 @@ impl<'e> TryFrom<Element<'e>> for Vec<StreamFeature> {      fn try_from(features_element: Element) -> Result<Self> {          let mut features = Vec::new(); -        if let Some(content) = features_element.content { -            for feature_element in content { +        if let Some(children) = features_element.children { +            for feature_element in children {                  match feature_element.event {                      Event::Start(e) => match e.name() {                          QName(b"starttls") => features.push(StreamFeature::StartTls),                          QName(b"mechanisms") => {                              let mut mechanisms = Vec::new(); -                            if let Some(content) = feature_element.content { -                                for mechanism_element in content { -                                    if let Some(content) = mechanism_element.content { -                                        for mechanism_text in content { +                            if let Some(children) = feature_element.children { +                                for mechanism_element in children { +                                    if let Some(children) = mechanism_element.children { +                                        for mechanism_text in children {                                              match mechanism_text.event {                                                  Event::Text(e) => mechanisms                                                      .push(str::from_utf8(e.as_ref())?.to_owned()), | 
