From cd7bb95c0a31d187bfe25bad15043f0b33b111cf Mon Sep 17 00:00:00 2001 From: cel 🌸 Date: Wed, 2 Aug 2023 00:56:38 +0100 Subject: implement resource binding --- Cargo.toml | 1 + src/client/encrypted.rs | 55 +++++++++++++--- src/error.rs | 6 ++ src/lib.rs | 14 ++-- src/stanza/bind.rs | 111 +++++++++++++++++++++++++++++++ src/stanza/iq.rs | 171 ++++++++++++++++++++++++++++++++++++++++++++++++ src/stanza/mod.rs | 46 ++++++++----- src/stanza/sasl.rs | 27 ++------ src/stanza/stream.rs | 6 +- 9 files changed, 385 insertions(+), 52 deletions(-) create mode 100644 src/stanza/bind.rs create mode 100644 src/stanza/iq.rs diff --git a/Cargo.toml b/Cargo.toml index eb89659..49294dc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,6 +9,7 @@ edition = "2021" [dependencies] async-recursion = "1.0.4" async-trait = "0.1.68" +nanoid = "0.4.0" quick-xml = { git = "https://github.com/tafia/quick-xml.git", features = ["async-tokio"] } # TODO: remove unneeded features rsasl = { version = "2", default_features = true, features = ["provider_base64", "plain", "config_builder"] } diff --git a/src/client/encrypted.rs b/src/client/encrypted.rs index e8b7271..86aba13 100644 --- a/src/client/encrypted.rs +++ b/src/client/encrypted.rs @@ -11,19 +11,22 @@ use tokio::net::TcpStream; use tokio_native_tls::TlsStream; use crate::stanza::{ - sasl::{Auth, Response}, - stream::{Stream, StreamFeature}, -}; -use crate::stanza::{ + bind::Bind, + iq::IQ, sasl::{Challenge, Success}, Element, }; +use crate::stanza::{ + sasl::{Auth, Response}, + stream::{Stream, StreamFeature}, +}; use crate::Jabber; +use crate::JabberError; use crate::Result; pub struct JabberClient<'j> { - reader: Reader>>>, - writer: Writer>>, + pub reader: Reader>>>, + pub writer: Writer>>, jabber: &'j mut Jabber<'j>, } @@ -64,15 +67,19 @@ impl<'j> JabberClient<'j> { pub async fn negotiate(&mut self) -> Result<()> { loop { - println!("loop"); + println!("negotiate loop"); let features = self.get_features().await?; println!("features: {:?}", features); + match &features[0] { StreamFeature::Sasl(sasl) => { println!("sasl?"); self.sasl(&sasl).await?; } - StreamFeature::Bind => todo!(), + StreamFeature::Bind => { + self.bind().await?; + return Ok(()); + } x => println!("{:?}", x), } } @@ -165,4 +172,36 @@ impl<'j> JabberClient<'j> { self.start_stream().await?; Ok(()) } + + pub async fn bind(&mut self) -> Result<()> { + match &self.jabber.jid.resourcepart { + Some(resource) => { + println!("setting resource"); + let bind = Bind { + resource: Some(resource.clone()), + jid: None, + }; + let result: Bind = IQ::set(self, None, None, bind).await?.try_into()?; + if let Some(jid) = result.jid { + println!("{}", jid); + self.jabber.jid = jid; + return Ok(()); + } + } + None => { + println!("not setting resource"); + let bind = Bind { + resource: None, + jid: None, + }; + let result: Bind = IQ::set(self, None, None, bind).await?.try_into()?; + if let Some(jid) = result.jid { + println!("{}", jid); + self.jabber.jid = jid; + return Ok(()); + } + } + } + Err(JabberError::BindError) + } } diff --git a/src/error.rs b/src/error.rs index 17bfbef..a912840 100644 --- a/src/error.rs +++ b/src/error.rs @@ -17,8 +17,14 @@ pub enum JabberError { Utf8Decode, NoFeatures, UnknownNamespace, + UnknownAttribute, + NoID, + NoType, + IDMismatch, + BindError, ParseError, UnexpectedEnd, + UnexpectedElement, XML(quick_xml::Error), SASL(SASLError), Element(ElementError<'static>), diff --git a/src/lib.rs b/src/lib.rs index d27f0ba..95a228b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,4 +1,5 @@ #![allow(unused_must_use)] +#![feature(let_chains)] // TODO: logging (dropped errors) pub mod client; @@ -44,10 +45,13 @@ mod tests { #[tokio::test] async fn login() { - Jabber::new(JID::from_str("test@blos.sm").unwrap(), "slayed".to_owned()) - .unwrap() - .login() - .await - .unwrap(); + Jabber::new( + JID::from_str("test@blos.sm/clown").unwrap(), + "slayed".to_owned(), + ) + .unwrap() + .login() + .await + .unwrap(); } } diff --git a/src/stanza/bind.rs b/src/stanza/bind.rs new file mode 100644 index 0000000..f1bdc2d --- /dev/null +++ b/src/stanza/bind.rs @@ -0,0 +1,111 @@ +use quick_xml::{ + events::{BytesStart, BytesText, Event}, + name::QName, + Reader, +}; + +use super::{Element, IntoElement}; +use crate::{JabberError, JID}; + +const XMLNS: &str = "urn:ietf:params:xml:ns:xmpp-bind"; + +pub struct Bind { + pub resource: Option, + pub jid: Option, +} + +impl<'e> IntoElement<'e> for Bind { + fn event(&self) -> quick_xml::events::Event<'static> { + let mut bind_event = BytesStart::new("bind"); + bind_event.push_attribute(("xmlns", XMLNS)); + if self.resource.is_none() && self.jid.is_none() { + return Event::Empty(bind_event); + } else { + return Event::Start(bind_event); + } + } + + fn children(&self) -> Option>> { + if let Some(resource) = &self.resource { + let resource_event: BytesStart<'static> = BytesStart::new("resource"); + let resource_child: BytesText<'static> = BytesText::new(resource).into_owned(); + let resource_child: Element<'static> = Element { + event: Event::Text(resource_child), + children: None, + }; + let resource_element: Element<'static> = Element { + event: Event::Start(resource_event), + children: Some(vec![resource_child]), + }; + return Some(vec![resource_element]); + } else if let Some(jid) = &self.jid { + let jid_event = BytesStart::new("jid"); + let jid_child = BytesText::new(&jid.to_string()).into_owned(); + let jid_child = Element { + event: Event::Text(jid_child), + children: None, + }; + let jid_element = Element { + event: Event::Start(jid_event), + children: Some(vec![jid_child]), + }; + return Some(vec![jid_element]); + } + None + } +} + +impl TryFrom> for Bind { + type Error = JabberError; + + fn try_from(element: Element<'static>) -> Result { + if let Event::Start(start) = &element.event { + let buf: Vec = Vec::new(); + let reader = Reader::from_reader(buf); + if start.name() == QName(b"bind") + && start.try_get_attribute("xmlns")?.is_some_and(|attribute| { + attribute.decode_and_unescape_value(&reader).unwrap() == XMLNS + }) + { + let child: Element<'static> = element.child()?.clone(); + if let Event::Start(start) = &child.event { + match start.name() { + QName(b"resource") => { + let resource_text = child.child()?; + if let Event::Text(text) = &resource_text.event { + return Ok(Self { + resource: Some(text.unescape()?.into_owned()), + jid: None, + }); + } + } + QName(b"jid") => { + let jid_text = child.child()?; + if let Event::Text(text) = &jid_text.event { + return Ok(Self { + jid: Some(text.unescape()?.into_owned().try_into()?), + resource: None, + }); + } + } + _ => return Err(JabberError::UnexpectedElement), + } + } + } + } else if let Event::Empty(start) = &element.event { + let buf: Vec = Vec::new(); + let reader = Reader::from_reader(buf); + if start.name() == QName(b"bind") + && start.try_get_attribute("xmlns")?.is_some_and(|attribute| { + attribute.decode_and_unescape_value(&reader).unwrap() == XMLNS + }) + { + return Ok(Bind { + resource: None, + jid: None, + }); + } + } + Err(JabberError::UnexpectedElement) + } +} diff --git a/src/stanza/iq.rs b/src/stanza/iq.rs new file mode 100644 index 0000000..8a373b2 --- /dev/null +++ b/src/stanza/iq.rs @@ -0,0 +1,171 @@ +use nanoid::nanoid; +use quick_xml::{ + events::{BytesStart, Event}, + name::QName, + Reader, Writer, +}; + +use crate::{JabberClient, JabberError, JID}; + +use super::{Element, IntoElement}; +use crate::Result; + +#[derive(Debug)] +pub struct IQ { + to: Option, + from: Option, + id: String, + r#type: IQType, + lang: Option, + child: Element<'static>, +} + +#[derive(Debug)] +enum IQType { + Get, + Set, + Result, + Error, +} + +impl IQ { + pub async fn set<'j, R: IntoElement<'static>>( + client: &mut JabberClient<'j>, + to: Option, + from: Option, + element: R, + ) -> Result> { + let id = nanoid!(); + let iq = IQ { + to, + from, + id: id.clone(), + r#type: IQType::Set, + lang: None, + child: Element::from(element), + }; + println!("{:?}", iq); + let iq = Element::from(iq); + println!("{:?}", iq); + iq.write(&mut client.writer).await?; + let result = Element::read(&mut client.reader).await?; + let iq = IQ::try_from(result)?; + if iq.id == id { + return Ok(iq.child); + } + Err(JabberError::IDMismatch) + } +} + +impl<'e> IntoElement<'e> for IQ { + fn event(&self) -> quick_xml::events::Event<'e> { + let mut start = BytesStart::new("iq"); + if let Some(to) = &self.to { + start.push_attribute(("to", to.to_string().as_str())); + } + if let Some(from) = &self.from { + start.push_attribute(("from", from.to_string().as_str())); + } + start.push_attribute(("id", self.id.as_str())); + match self.r#type { + IQType::Get => start.push_attribute(("type", "get")), + IQType::Set => start.push_attribute(("type", "set")), + IQType::Result => start.push_attribute(("type", "result")), + IQType::Error => start.push_attribute(("type", "error")), + } + if let Some(lang) = &self.lang { + start.push_attribute(("from", lang.to_string().as_str())); + } + + quick_xml::events::Event::Start(start) + } + + fn children(&self) -> Option>> { + Some(vec![self.child.clone()]) + } +} + +impl TryFrom> for IQ { + type Error = JabberError; + + fn try_from(element: Element<'static>) -> std::result::Result { + if let Event::Start(start) = &element.event { + if start.name() == QName(b"iq") { + let mut to: Option = None; + let mut from: Option = None; + let mut id = None; + let mut r#type = None; + let mut lang = None; + start + .attributes() + .into_iter() + .try_for_each(|attribute| -> Result<()> { + if let Ok(attribute) = attribute { + let buf: Vec = Vec::new(); + let reader = Reader::from_reader(buf); + match attribute.key { + QName(b"to") => { + to = Some( + attribute + .decode_and_unescape_value(&reader) + .or(Err(JabberError::Utf8Decode))? + .into_owned() + .try_into()?, + ) + } + QName(b"from") => { + from = Some( + attribute + .decode_and_unescape_value(&reader) + .or(Err(JabberError::Utf8Decode))? + .into_owned() + .try_into()?, + ) + } + QName(b"id") => { + id = Some( + attribute + .decode_and_unescape_value(&reader) + .or(Err(JabberError::Utf8Decode))? + .into_owned(), + ) + } + QName(b"type") => { + let value = attribute + .decode_and_unescape_value(&reader) + .or(Err(JabberError::Utf8Decode))?; + match value.as_ref() { + "get" => r#type = Some(IQType::Get), + "set" => r#type = Some(IQType::Set), + "result" => r#type = Some(IQType::Result), + "error" => r#type = Some(IQType::Error), + _ => return Err(JabberError::ParseError), + } + } + QName(b"lang") => { + lang = Some( + attribute + .decode_and_unescape_value(&reader) + .or(Err(JabberError::Utf8Decode))? + .into_owned(), + ) + } + _ => return Err(JabberError::UnknownAttribute), + } + } + Ok(()) + })?; + let iq = IQ { + to, + from, + id: id.ok_or(JabberError::NoID)?, + r#type: r#type.ok_or(JabberError::NoType)?, + lang, + child: element.child()?.to_owned(), + }; + return Ok(iq); + } + } + Err(JabberError::ParseError) + } +} diff --git a/src/stanza/mod.rs b/src/stanza/mod.rs index c29b1a2..ad9e228 100644 --- a/src/stanza/mod.rs +++ b/src/stanza/mod.rs @@ -1,5 +1,7 @@ // use quick_xml::events::BytesDecl; +pub mod bind; +pub mod iq; pub mod sasl; pub mod stream; @@ -128,28 +130,40 @@ impl<'e> Element<'e> { e => Err(ElementError::NotAStart(e.into_owned()).into()), } } -} -/// if there is only one child in the vec of children, will return that element -pub fn child<'p, 'e>(element: &'p Element<'e>) -> Result<&'p Element<'e>, ElementError<'static>> { - if let Some(children) = &element.children { - if children.len() == 1 { - return Ok(&children[0]); - } else { - return Err(ElementError::MultipleChildren); + /// if there is only one child in the vec of children, will return that element + pub fn child<'p>(&'p self) -> Result<&'p Element<'e>, ElementError<'static>> { + if let Some(children) = &self.children { + if children.len() == 1 { + return Ok(&children[0]); + } else { + return Err(ElementError::MultipleChildren); + } + } + Err(ElementError::NoChildren) + } + + /// returns reference to children + pub fn children<'p>(&'p self) -> Result<&'p Vec>, ElementError<'e>> { + if let Some(children) = &self.children { + return Ok(children); } + Err(ElementError::NoChildren) } - Err(ElementError::NoChildren) } -/// returns reference to children -pub fn children<'p, 'e>( - element: &'p Element<'e>, -) -> Result<&'p Vec>, ElementError<'e>> { - if let Some(children) = &element.children { - return Ok(children); +pub trait IntoElement<'e> { + fn event(&self) -> Event<'e>; + fn children(&self) -> Option>>; +} + +impl<'e, T: IntoElement<'e>> From for Element<'e> { + fn from(value: T) -> Self { + Element { + event: value.event(), + children: value.children(), + } } - Err(ElementError::NoChildren) } #[derive(Debug)] diff --git a/src/stanza/sasl.rs b/src/stanza/sasl.rs index bbf3f41..50ffd83 100644 --- a/src/stanza/sasl.rs +++ b/src/stanza/sasl.rs @@ -7,6 +7,7 @@ use crate::error::SASLError; use crate::JabberError; use super::Element; +use super::IntoElement; const XMLNS: &str = "urn:ietf:params:xml:ns:xmpp-sasl"; @@ -16,7 +17,7 @@ pub struct Auth<'e> { pub sasl_data: &'e str, } -impl<'e> Auth<'e> { +impl<'e> IntoElement<'e> for Auth<'e> { fn event(&self) -> Event<'e> { let mut start = BytesStart::new("auth"); start.push_attribute(("xmlns", XMLNS)); @@ -34,15 +35,6 @@ impl<'e> Auth<'e> { } } -impl<'e> Into> for Auth<'e> { - fn into(self) -> Element<'e> { - Element { - event: self.event(), - children: self.children(), - } - } -} - #[derive(Debug)] pub struct Challenge { pub sasl_data: Vec, @@ -54,7 +46,7 @@ impl<'e> TryFrom<&Element<'e>> for Challenge { fn try_from(element: &Element<'e>) -> Result { if let Event::Start(start) = &element.event { if start.name() == QName(b"challenge") { - let sasl_data: &Element<'_> = super::child(element)?; + let sasl_data: &Element<'_> = element.child()?; if let Event::Text(sasl_data) = &sasl_data.event { let s = sasl_data.clone(); let s = s.into_inner(); @@ -101,7 +93,7 @@ pub struct Response<'e> { pub sasl_data: &'e str, } -impl<'e> Response<'e> { +impl<'e> IntoElement<'e> for Response<'e> { fn event(&self) -> Event<'e> { let mut start = BytesStart::new("response"); start.push_attribute(("xmlns", XMLNS)); @@ -118,15 +110,6 @@ impl<'e> Response<'e> { } } -impl<'e> Into> for Response<'e> { - fn into(self) -> Element<'e> { - Element { - event: self.event(), - children: self.children(), - } - } -} - #[derive(Debug)] pub struct Success { pub sasl_data: Option>, @@ -139,7 +122,7 @@ impl<'e> TryFrom<&Element<'e>> for Success { match &element.event { Event::Start(start) => { if start.name() == QName(b"success") { - match super::child(element) { + match element.child() { Ok(sasl_data) => { if let Event::Text(sasl_data) = &sasl_data.event { return Ok(Success { diff --git a/src/stanza/stream.rs b/src/stanza/stream.rs index 66741b8..f85166f 100644 --- a/src/stanza/stream.rs +++ b/src/stanza/stream.rs @@ -175,7 +175,11 @@ impl<'e> TryFrom> for Vec { } features.push(StreamFeature::Sasl(mechanisms)) } - _ => {} + _ => features.push(StreamFeature::Unknown), + }, + Event::Empty(e) => match e.name() { + QName(b"bind") => features.push(StreamFeature::Bind), + _ => features.push(StreamFeature::Unknown), }, _ => features.push(StreamFeature::Unknown), } -- cgit