From 322b2a3b46348ec1c5acbc538de93310c9030b96 Mon Sep 17 00:00:00 2001 From: cel 🌸 Date: Wed, 12 Jul 2023 21:11:20 +0100 Subject: reimplement sasl (with SCRAM!) --- src/stanza/sasl.rs | 163 +++++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 159 insertions(+), 4 deletions(-) (limited to 'src/stanza/sasl.rs') diff --git a/src/stanza/sasl.rs b/src/stanza/sasl.rs index 1f77ffa..bbf3f41 100644 --- a/src/stanza/sasl.rs +++ b/src/stanza/sasl.rs @@ -1,8 +1,163 @@ -pub struct Auth { - pub mechanism: String, - pub sasl_data: Option, +use quick_xml::{ + events::{BytesStart, BytesText, Event}, + name::QName, +}; + +use crate::error::SASLError; +use crate::JabberError; + +use super::Element; + +const XMLNS: &str = "urn:ietf:params:xml:ns:xmpp-sasl"; + +#[derive(Debug)] +pub struct Auth<'e> { + pub mechanism: &'e str, + pub sasl_data: &'e str, +} + +impl<'e> Auth<'e> { + fn event(&self) -> Event<'e> { + let mut start = BytesStart::new("auth"); + start.push_attribute(("xmlns", XMLNS)); + start.push_attribute(("mechanism", self.mechanism)); + Event::Start(start) + } + + fn children(&self) -> Option>> { + let sasl = BytesText::from_escaped(self.sasl_data); + let sasl = Element { + event: Event::Text(sasl), + children: None, + }; + Some(vec![sasl]) + } } +impl<'e> Into> for Auth<'e> { + fn into(self) -> Element<'e> { + Element { + event: self.event(), + children: self.children(), + } + } +} + +#[derive(Debug)] pub struct Challenge { - pub sasl_data: String, + pub sasl_data: Vec, +} + +impl<'e> TryFrom<&Element<'e>> for Challenge { + type Error = JabberError; + + fn try_from(element: &Element<'e>) -> Result { + if let Event::Start(start) = &element.event { + if start.name() == QName(b"challenge") { + let sasl_data: &Element<'_> = super::child(element)?; + if let Event::Text(sasl_data) = &sasl_data.event { + let s = sasl_data.clone(); + let s = s.into_inner(); + let s = s.to_vec(); + return Ok(Challenge { sasl_data: s }); + } + } + } + Err(SASLError::NoChallenge.into()) + } +} + +// impl<'e> TryFrom> for Challenge { +// type Error = JabberError; + +// fn try_from(element: Element<'e>) -> Result { +// if let Event::Start(start) = &element.event { +// if start.name() == QName(b"challenge") { +// println!("one"); +// if let Some(children) = element.children.as_deref() { +// if children.len() == 1 { +// let sasl_data = children.first().unwrap(); +// if let Event::Text(sasl_data) = &sasl_data.event { +// return Ok(Challenge { +// sasl_data: sasl_data.clone().into_inner().to_vec(), +// }); +// } else { +// return Err(SASLError::NoChallenge.into()); +// } +// } else { +// return Err(SASLError::NoChallenge.into()); +// } +// } else { +// return Err(SASLError::NoChallenge.into()); +// } +// } +// } +// Err(SASLError::NoChallenge.into()) +// } +// } + +#[derive(Debug)] +pub struct Response<'e> { + pub sasl_data: &'e str, +} + +impl<'e> Response<'e> { + fn event(&self) -> Event<'e> { + let mut start = BytesStart::new("response"); + start.push_attribute(("xmlns", XMLNS)); + Event::Start(start) + } + + fn children(&self) -> Option>> { + let sasl = BytesText::from_escaped(self.sasl_data); + let sasl = Element { + event: Event::Text(sasl), + children: None, + }; + Some(vec![sasl]) + } +} + +impl<'e> Into> for Response<'e> { + fn into(self) -> Element<'e> { + Element { + event: self.event(), + children: self.children(), + } + } +} + +#[derive(Debug)] +pub struct Success { + pub sasl_data: Option>, +} + +impl<'e> TryFrom<&Element<'e>> for Success { + type Error = JabberError; + + fn try_from(element: &Element<'e>) -> Result { + match &element.event { + Event::Start(start) => { + if start.name() == QName(b"success") { + match super::child(element) { + Ok(sasl_data) => { + if let Event::Text(sasl_data) = &sasl_data.event { + return Ok(Success { + sasl_data: Some(sasl_data.clone().into_inner().to_vec()), + }); + } + } + Err(_) => return Ok(Success { sasl_data: None }), + }; + } + } + Event::Empty(empty) => { + if empty.name() == QName(b"success") { + return Ok(Success { sasl_data: None }); + } + } + _ => {} + } + Err(SASLError::NoSuccess.into()) + } } -- cgit