From ba94ee66fafbabd63d6d1ed5edf435d4c46c6796 Mon Sep 17 00:00:00 2001 From: cel 🌸 Date: Fri, 20 Oct 2023 04:51:56 +0100 Subject: WIP: refactor to parse incoming stream as state machine --- src/client/encrypted.rs | 16 +- src/client/mod.rs | 10 +- src/client/unencrypted.rs | 199 ++++++++++---- src/jabber.rs | 18 +- src/jid.rs | 181 +++++++++++++ src/jid/mod.rs | 162 ------------ src/lib.rs | 25 +- src/stanza/bind.rs | 47 ---- src/stanza/iq.rs | 169 ------------ src/stanza/message.rs | 1 + src/stanza/mod.rs | 656 +--------------------------------------------- src/stanza/presence.rs | 1 + src/stanza/sasl.rs | 144 ---------- src/stanza/starttls.rs | 1 + src/stanza/stream.rs | 226 ++++------------ 15 files changed, 399 insertions(+), 1457 deletions(-) create mode 100644 src/jid.rs delete mode 100644 src/jid/mod.rs create mode 100644 src/stanza/message.rs create mode 100644 src/stanza/presence.rs create mode 100644 src/stanza/starttls.rs (limited to 'src') diff --git a/src/client/encrypted.rs b/src/client/encrypted.rs index 47b2b2c..263d5ff 100644 --- a/src/client/encrypted.rs +++ b/src/client/encrypted.rs @@ -2,36 +2,26 @@ use std::{collections::BTreeMap, str}; use quick_xml::{ events::{BytesDecl, Event}, - Reader, Writer, + NsReader, Writer, }; use rsasl::prelude::{Mechname, SASLClient}; use tokio::io::{BufReader, ReadHalf, WriteHalf}; use tokio::net::TcpStream; use tokio_native_tls::TlsStream; -use crate::stanza::{ - bind::Bind, - iq::IQ, - sasl::{Challenge, Success}, - Element, -}; -use crate::stanza::{ - sasl::{Auth, Response}, - stream::{Stream, StreamFeature}, -}; use crate::Jabber; use crate::JabberError; use crate::Result; pub struct JabberClient<'j> { - pub reader: Reader>>>, + pub reader: NsReader>>>, pub writer: Writer>>, jabber: &'j mut Jabber<'j>, } impl<'j> JabberClient<'j> { pub fn new( - reader: Reader>>>, + reader: NsReader>>>, writer: Writer>>, jabber: &'j mut Jabber<'j>, ) -> Self { diff --git a/src/client/mod.rs b/src/client/mod.rs index 280e0a1..01df4a4 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -1,22 +1,24 @@ -pub mod encrypted; +// pub mod encrypted; pub mod unencrypted; // use async_trait::async_trait; -use crate::stanza::stream::StreamFeature; +// use crate::stanza::stream::StreamFeature; use crate::JabberError; use crate::Result; pub enum JabberClientType<'j> { - Encrypted(encrypted::JabberClient<'j>), + // Encrypted(encrypted::JabberClient<'j>), Unencrypted(unencrypted::JabberClient<'j>), } impl<'j> JabberClientType<'j> { + /// ensures an encrypted jabber client pub async fn ensure_tls(self) -> Result> { match self { Self::Encrypted(c) => Ok(c), Self::Unencrypted(mut c) => { + c.start_stream().await?; let features = c.get_features().await?; if features.contains(&StreamFeature::StartTls) { Ok(c.starttls().await?) @@ -28,7 +30,7 @@ impl<'j> JabberClientType<'j> { } } -// TODO: jabber client trait over both client types +// TODO: jabber client trait over both client types using macro // #[async_trait] // pub trait JabberTrait { // async fn start_stream(&mut self) -> Result<()>; diff --git a/src/client/unencrypted.rs b/src/client/unencrypted.rs index 27b0a5f..4aa9c63 100644 --- a/src/client/unencrypted.rs +++ b/src/client/unencrypted.rs @@ -1,27 +1,30 @@ +use std::str; + use quick_xml::{ - events::{BytesDecl, BytesStart, Event}, + events::{BytesStart, Event}, name::QName, - Reader, Writer, + se, NsReader, Writer, }; use tokio::io::{BufReader, ReadHalf, WriteHalf}; use tokio::net::TcpStream; use tokio_native_tls::native_tls::TlsConnector; +use try_map::FallibleMapExt; -use crate::stanza::stream::StreamFeature; -use crate::stanza::Element; +use crate::error::JabberError; +use crate::stanza::stream::Stream; +use crate::stanza::DECLARATION; use crate::Jabber; use crate::Result; -use crate::{error::JabberError, stanza::stream::Stream}; pub struct JabberClient<'j> { - reader: Reader>>, + reader: NsReader>>, writer: Writer>, jabber: &'j mut Jabber<'j>, } impl<'j> JabberClient<'j> { pub fn new( - reader: Reader>>, + reader: NsReader>>, writer: Writer>, jabber: &'j mut Jabber<'j>, ) -> Self { @@ -34,60 +37,144 @@ impl<'j> JabberClient<'j> { pub async fn start_stream(&mut self) -> Result<()> { // client to server - let declaration = BytesDecl::new("1.0", None, None); + + // declaration + self.writer.write_event_async(DECLARATION).await?; + + // opening stream element let server = &self.jabber.server.to_owned().try_into()?; - let stream_element = - Stream::new_client(&self.jabber.jid, server, None, Some("en".to_string())); - self.writer - .write_event_async(Event::Decl(declaration)) - .await?; - let stream_element: Element<'_> = stream_element.into(); - stream_element.write_start(&mut self.writer).await?; - // server to client - let mut buf = Vec::new(); - self.reader.read_event_into_async(&mut buf).await?; - let _stream_response = Element::read_start(&mut self.reader).await?; - Ok(()) - } + let stream_element = Stream::new_client(None, server, None, "en"); + se::to_writer_with_root(&mut self.writer, "stream:stream", &stream_element); - pub async fn get_features(&mut self) -> Result> { - Element::read(&mut self.reader).await?.try_into() - } + // server to client - pub async fn starttls(mut self) -> Result> { - let mut starttls_element = BytesStart::new("starttls"); - starttls_element.push_attribute(("xmlns", "urn:ietf:params:xml:ns:xmpp-tls")); - self.writer - .write_event_async(Event::Empty(starttls_element)) - .await - .unwrap(); - let mut buf = Vec::new(); - match self.reader.read_event_into_async(&mut buf).await.unwrap() { - Event::Empty(e) => match e.name() { - QName(b"proceed") => { - let connector = TlsConnector::new().unwrap(); - let stream = self - .reader - .into_inner() - .into_inner() - .unsplit(self.writer.into_inner()); - if let Ok(tlsstream) = tokio_native_tls::TlsConnector::from(connector) - .connect(&self.jabber.server, stream) - .await - { - let (read, write) = tokio::io::split(tlsstream); - let reader = Reader::from_reader(BufReader::new(read)); - let writer = Writer::new(write); - let mut client = - super::encrypted::JabberClient::new(reader, writer, self.jabber); - client.start_stream().await?; - return Ok(client); + // may or may not send a declaration + let buf = Vec::new(); + let mut first_event = self.reader.read_resolved_event_into_async(&mut buf).await?; + match first_event { + (quick_xml::name::ResolveResult::Unbound, Event::Decl(e)) => { + if let Ok(version) = e.version() { + if version.as_ref() == b"1.0" { + first_event = self.reader.read_resolved_event_into_async(&mut buf).await? + } else { + // todo: error + todo!() } + } else { + first_event = self.reader.read_resolved_event_into_async(&mut buf).await? + } + } + _ => (), + } + + // receive stream element and validate + let stream_response: Stream; + match first_event { + (quick_xml::name::ResolveResult::Bound(ns), Event::Start(e)) => { + if ns.0 == crate::stanza::stream::XMLNS.as_bytes() { + // stream_response = Stream::new( + // e.try_get_attribute("from")?.try_map(|attribute| { + // str::from_utf8(attribute.value.as_ref())? + // .try_into()? + // .as_ref() + // })?, + // e.try_get_attribute("to")?.try_map(|attribute| { + // str::from_utf8(attribute.value.as_ref())? + // .try_into()? + // .as_ref() + // })?, + // e.try_get_attribute("id")?.try_map(|attribute| { + // str::from_utf8(attribute.value.as_ref())? + // .try_into()? + // .as_ref() + // })?, + // e.try_get_attribute("version")?.try_map(|attribute| { + // str::from_utf8(attribute.value.as_ref())? + // .try_into()? + // .as_ref() + // })?, + // e.try_get_attribute("lang")?.try_map(|attribute| { + // str::from_utf8(attribute.value.as_ref())? + // .try_into()? + // .as_ref() + // })?, + // ); + return Ok(()); + } else { + return Err(JabberError::BadStream); } - QName(_) => return Err(JabberError::TlsNegotiation), - }, - _ => return Err(JabberError::TlsNegotiation), + } + // TODO: errors for incorrect namespace + (quick_xml::name::ResolveResult::Unbound, Event::Decl(_)) => todo!(), + (quick_xml::name::ResolveResult::Unknown(_), Event::Start(_)) => todo!(), + (quick_xml::name::ResolveResult::Unknown(_), Event::End(_)) => todo!(), + (quick_xml::name::ResolveResult::Unknown(_), Event::Empty(_)) => todo!(), + (quick_xml::name::ResolveResult::Unknown(_), Event::Text(_)) => todo!(), + (quick_xml::name::ResolveResult::Unknown(_), Event::CData(_)) => todo!(), + (quick_xml::name::ResolveResult::Unknown(_), Event::Comment(_)) => todo!(), + (quick_xml::name::ResolveResult::Unknown(_), Event::Decl(_)) => todo!(), + (quick_xml::name::ResolveResult::Unknown(_), Event::PI(_)) => todo!(), + (quick_xml::name::ResolveResult::Unknown(_), Event::DocType(_)) => todo!(), + (quick_xml::name::ResolveResult::Unknown(_), Event::Eof) => todo!(), + (quick_xml::name::ResolveResult::Unbound, Event::Start(_)) => todo!(), + (quick_xml::name::ResolveResult::Unbound, Event::End(_)) => todo!(), + (quick_xml::name::ResolveResult::Unbound, Event::Empty(_)) => todo!(), + (quick_xml::name::ResolveResult::Unbound, Event::Text(_)) => todo!(), + (quick_xml::name::ResolveResult::Unbound, Event::CData(_)) => todo!(), + (quick_xml::name::ResolveResult::Unbound, Event::Comment(_)) => todo!(), + (quick_xml::name::ResolveResult::Unbound, Event::PI(_)) => todo!(), + (quick_xml::name::ResolveResult::Unbound, Event::DocType(_)) => todo!(), + (quick_xml::name::ResolveResult::Unbound, Event::Eof) => todo!(), + (quick_xml::name::ResolveResult::Bound(_), Event::End(_)) => todo!(), + (quick_xml::name::ResolveResult::Bound(_), Event::Empty(_)) => todo!(), + (quick_xml::name::ResolveResult::Bound(_), Event::Text(_)) => todo!(), + (quick_xml::name::ResolveResult::Bound(_), Event::CData(_)) => todo!(), + (quick_xml::name::ResolveResult::Bound(_), Event::Comment(_)) => todo!(), + (quick_xml::name::ResolveResult::Bound(_), Event::Decl(_)) => todo!(), + (quick_xml::name::ResolveResult::Bound(_), Event::PI(_)) => todo!(), + (quick_xml::name::ResolveResult::Bound(_), Event::DocType(_)) => todo!(), + (quick_xml::name::ResolveResult::Bound(_), Event::Eof) => todo!(), } - Err(JabberError::TlsNegotiation) } + + // pub async fn get_features(&mut self) -> Result> { + // Element::read(&mut self.reader).await?.try_into() + // } + + // pub async fn starttls(mut self) -> Result> { + // let mut starttls_element = BytesStart::new("starttls"); + // starttls_element.push_attribute(("xmlns", "urn:ietf:params:xml:ns:xmpp-tls")); + // self.writer + // .write_event_async(Event::Empty(starttls_element)) + // .await + // .unwrap(); + // let mut buf = Vec::new(); + // match self.reader.read_event_into_async(&mut buf).await.unwrap() { + // Event::Empty(e) => match e.name() { + // QName(b"proceed") => { + // let connector = TlsConnector::new().unwrap(); + // let stream = self + // .reader + // .into_inner() + // .into_inner() + // .unsplit(self.writer.into_inner()); + // if let Ok(tlsstream) = tokio_native_tls::TlsConnector::from(connector) + // .connect(&self.jabber.server, stream) + // .await + // { + // let (read, write) = tokio::io::split(tlsstream); + // let reader = Reader::from_reader(BufReader::new(read)); + // let writer = Writer::new(write); + // let mut client = + // super::encrypted::JabberClient::new(reader, writer, self.jabber); + // client.start_stream().await?; + // return Ok(client); + // } + // } + // QName(_) => return Err(JabberError::TlsNegotiation), + // }, + // _ => return Err(JabberError::TlsNegotiation), + // } + // Err(JabberError::TlsNegotiation) + // } } diff --git a/src/jabber.rs b/src/jabber.rs index 1a7eddb..d48eb9c 100644 --- a/src/jabber.rs +++ b/src/jabber.rs @@ -3,7 +3,7 @@ use std::net::{IpAddr, SocketAddr}; use std::str::FromStr; use std::sync::Arc; -use quick_xml::{Reader, Writer}; +use quick_xml::{NsReader, Writer}; use rsasl::prelude::SASLConfig; use tokio::io::BufReader; use tokio::net::TcpStream; @@ -22,7 +22,7 @@ pub struct Jabber<'j> { } impl<'j> Jabber<'j> { - pub fn new(jid: JID, password: String) -> Result { + pub fn user(jid: JID, password: String) -> Result { let server = jid.domainpart.clone(); let auth = SASLConfig::with_credentials(None, jid.localpart.clone().unwrap(), password)?; println!("auth: {:?}", auth); @@ -36,7 +36,7 @@ impl<'j> Jabber<'j> { pub async fn login(&'j mut self) -> Result> { let mut client = self.connect().await?.ensure_tls().await?; - println!("negotiation"); + client.start_stream().await?; client.negotiate().await?; Ok(client) } @@ -106,6 +106,7 @@ impl<'j> Jabber<'j> { socket_addrs } + /// establishes a connection to the server pub async fn connect(&'j mut self) -> Result { for (socket_addr, is_tls) in self.get_sockets().await { println!("trying {}", socket_addr); @@ -118,21 +119,18 @@ impl<'j> Jabber<'j> { .await { let (read, write) = tokio::io::split(stream); - let reader = Reader::from_reader(BufReader::new(read)); + let reader = NsReader::from_reader(BufReader::new(read)); let writer = Writer::new(write); - let mut client = client::encrypted::JabberClient::new(reader, writer, self); - client.start_stream().await?; + let client = client::encrypted::JabberClient::new(reader, writer, self); return Ok(JabberClientType::Encrypted(client)); } } false => { if let Ok(stream) = TcpStream::connect(socket_addr).await { let (read, write) = tokio::io::split(stream); - let reader = Reader::from_reader(BufReader::new(read)); + let reader = NsReader::from_reader(BufReader::new(read)); let writer = Writer::new(write); - let mut client = - client::unencrypted::JabberClient::new(reader, writer, self); - client.start_stream().await?; + let client = client::unencrypted::JabberClient::new(reader, writer, self); return Ok(JabberClientType::Unencrypted(client)); } } diff --git a/src/jid.rs b/src/jid.rs new file mode 100644 index 0000000..65738dc --- /dev/null +++ b/src/jid.rs @@ -0,0 +1,181 @@ +use std::str::FromStr; + +use serde::Serialize; + +#[derive(PartialEq, Debug, Clone)] +pub struct JID { + // TODO: validate localpart (length, char] + pub localpart: Option, + pub domainpart: String, + pub resourcepart: Option, +} + +impl Serialize for JID { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + serializer.serialize_str(&self.to_string()) + } +} + +pub enum JIDError { + NoResourcePart, + ParseError(ParseError), +} + +#[derive(Debug)] +pub enum ParseError { + Empty, + Malformed, +} + +impl JID { + pub fn new( + localpart: Option, + domainpart: String, + resourcepart: Option, + ) -> Self { + Self { + localpart, + domainpart: domainpart.parse().unwrap(), + resourcepart, + } + } + + pub fn as_bare(&self) -> Self { + Self { + localpart: self.localpart.clone(), + domainpart: self.domainpart.clone(), + resourcepart: None, + } + } + + pub fn as_full(&self) -> Result<&Self, JIDError> { + if let Some(_) = self.resourcepart { + Ok(&self) + } else { + Err(JIDError::NoResourcePart) + } + } +} + +impl FromStr for JID { + type Err = ParseError; + + fn from_str(s: &str) -> Result { + let split: Vec<&str> = s.split('@').collect(); + match split.len() { + 0 => Err(ParseError::Empty), + 1 => { + let split: Vec<&str> = split[0].split('/').collect(); + match split.len() { + 1 => Ok(JID::new(None, split[0].to_string(), None)), + 2 => Ok(JID::new( + None, + split[0].to_string(), + Some(split[1].to_string()), + )), + _ => Err(ParseError::Malformed), + } + } + 2 => { + let split2: Vec<&str> = split[1].split('/').collect(); + match split2.len() { + 1 => Ok(JID::new( + Some(split[0].to_string()), + split2[0].to_string(), + None, + )), + 2 => Ok(JID::new( + Some(split[0].to_string()), + split2[0].to_string(), + Some(split2[1].to_string()), + )), + _ => Err(ParseError::Malformed), + } + } + _ => Err(ParseError::Malformed), + } + } +} + +impl TryFrom for JID { + type Error = ParseError; + + fn try_from(value: String) -> Result { + value.parse() + } +} + +impl TryFrom<&str> for JID { + type Error = ParseError; + + fn try_from(value: &str) -> Result { + value.parse() + } +} + +impl std::fmt::Display for JID { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{}{}{}", + self.localpart.clone().map(|l| l + "@").unwrap_or_default(), + self.domainpart, + self.resourcepart + .clone() + .map(|r| "/".to_owned() + &r) + .unwrap_or_default() + ) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn jid_to_string() { + assert_eq!( + JID::new(Some("cel".into()), "blos.sm".into(), None).to_string(), + "cel@blos.sm".to_owned() + ); + } + + #[test] + fn parse_full_jid() { + assert_eq!( + "cel@blos.sm/greenhouse".parse::().unwrap(), + JID::new( + Some("cel".into()), + "blos.sm".into(), + Some("greenhouse".into()) + ) + ) + } + + #[test] + fn parse_bare_jid() { + assert_eq!( + "cel@blos.sm".parse::().unwrap(), + JID::new(Some("cel".into()), "blos.sm".into(), None) + ) + } + + #[test] + fn parse_domain_jid() { + assert_eq!( + "component.blos.sm".parse::().unwrap(), + JID::new(None, "component.blos.sm".into(), None) + ) + } + + #[test] + fn parse_full_domain_jid() { + assert_eq!( + "component.blos.sm/bot".parse::().unwrap(), + JID::new(None, "component.blos.sm".into(), Some("bot".into())) + ) + } +} diff --git a/src/jid/mod.rs b/src/jid/mod.rs deleted file mode 100644 index e13fed7..0000000 --- a/src/jid/mod.rs +++ /dev/null @@ -1,162 +0,0 @@ -use std::str::FromStr; - -#[derive(PartialEq, Debug, Clone)] -pub struct JID { - // TODO: validate localpart (length, char] - pub localpart: Option, - pub domainpart: String, - pub resourcepart: Option, -} - -pub enum JIDError { - NoResourcePart, - ParseError(ParseError), -} - -#[derive(Debug)] -pub enum ParseError { - Empty, - Malformed, -} - -impl JID { - pub fn new( - localpart: Option, - domainpart: String, - resourcepart: Option, - ) -> Self { - Self { - localpart, - domainpart: domainpart.parse().unwrap(), - resourcepart, - } - } - - pub fn as_bare(&self) -> Self { - Self { - localpart: self.localpart.clone(), - domainpart: self.domainpart.clone(), - resourcepart: None, - } - } - - pub fn as_full(&self) -> Result<&Self, JIDError> { - if let Some(_) = self.resourcepart { - Ok(&self) - } else { - Err(JIDError::NoResourcePart) - } - } -} - -impl FromStr for JID { - type Err = ParseError; - - fn from_str(s: &str) -> Result { - let split: Vec<&str> = s.split('@').collect(); - match split.len() { - 0 => Err(ParseError::Empty), - 1 => { - let split: Vec<&str> = split[0].split('/').collect(); - match split.len() { - 1 => Ok(JID::new(None, split[0].to_string(), None)), - 2 => Ok(JID::new( - None, - split[0].to_string(), - Some(split[1].to_string()), - )), - _ => Err(ParseError::Malformed), - } - } - 2 => { - let split2: Vec<&str> = split[1].split('/').collect(); - match split2.len() { - 1 => Ok(JID::new( - Some(split[0].to_string()), - split2[0].to_string(), - None, - )), - 2 => Ok(JID::new( - Some(split[0].to_string()), - split2[0].to_string(), - Some(split2[1].to_string()), - )), - _ => Err(ParseError::Malformed), - } - } - _ => Err(ParseError::Malformed), - } - } -} - -impl TryFrom for JID { - type Error = ParseError; - - fn try_from(value: String) -> Result { - value.parse() - } -} - -impl std::fmt::Display for JID { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!( - f, - "{}{}{}", - self.localpart.clone().map(|l| l + "@").unwrap_or_default(), - self.domainpart, - self.resourcepart - .clone() - .map(|r| "/".to_owned() + &r) - .unwrap_or_default() - ) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn jid_to_string() { - assert_eq!( - JID::new(Some("cel".into()), "blos.sm".into(), None).to_string(), - "cel@blos.sm".to_owned() - ); - } - - #[test] - fn parse_full_jid() { - assert_eq!( - "cel@blos.sm/greenhouse".parse::().unwrap(), - JID::new( - Some("cel".into()), - "blos.sm".into(), - Some("greenhouse".into()) - ) - ) - } - - #[test] - fn parse_bare_jid() { - assert_eq!( - "cel@blos.sm".parse::().unwrap(), - JID::new(Some("cel".into()), "blos.sm".into(), None) - ) - } - - #[test] - fn parse_domain_jid() { - assert_eq!( - "component.blos.sm".parse::().unwrap(), - JID::new(None, "component.blos.sm".into(), None) - ) - } - - #[test] - fn parse_full_domain_jid() { - assert_eq!( - "component.blos.sm/bot".parse::().unwrap(), - JID::new(None, "component.blos.sm".into(), Some("bot".into())) - ) - } -} diff --git a/src/lib.rs b/src/lib.rs index 8162ccc..86da83d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -8,7 +8,7 @@ pub mod jabber; pub mod jid; pub mod stanza; -pub use client::encrypted::JabberClient; +// pub use client::encrypted::JabberClient; pub use error::JabberError; pub use jabber::Jabber; pub use jid::JID; @@ -22,30 +22,9 @@ mod tests { use crate::Jabber; use crate::JID; - // #[tokio::test] - // async fn get_sockets() { - // let jabber = Jabber::new(JID::from_str("cel@blos.sm").unwrap(), "password".to_owned()); - // println!("{:?}", jabber.get_sockets().await) - // } - - // #[tokio::test] - // async fn connect() { - // Jabber::new(JID::from_str("cel@blos.sm").unwrap(), "password".to_owned()) - // .unwrap() - // .connect() - // .await - // .unwrap() - // .ensure_tls() - // .await - // .unwrap() - // .start_stream() - // .await - // .unwrap(); - // } - #[tokio::test] async fn login() { - Jabber::new( + Jabber::user( JID::from_str("test@blos.sm/clown").unwrap(), "slayed".to_owned(), ) diff --git a/src/stanza/bind.rs b/src/stanza/bind.rs index 939a716..8b13789 100644 --- a/src/stanza/bind.rs +++ b/src/stanza/bind.rs @@ -1,48 +1 @@ -use super::{Element, ElementParseError}; -use crate::{JabberError, JID}; -const XMLNS: &str = "urn:ietf:params:xml:ns:xmpp-bind"; - -pub struct Bind { - pub resource: Option, - pub jid: Option, -} - -impl From for Element { - fn from(bind: Bind) -> Self { - let bind_element = Element::new("bind", None, XMLNS); - bind_element.push_namespace_declaration((None, XMLNS)); - if let Some(resource) = bind.resource { - let resource_element = Element::new("resource", None, XMLNS); - resource_element.push_child(resource); - bind_element.push_child(resource_element) - } - if let Some(jid) = bind.jid { - let jid_element = Element::new("jid", None, XMLNS); - jid_element.push_child(jid); - bind_element.push_child(jid_element) - } - bind_element - } -} - -impl TryFrom for Bind { - type Error = JabberError; - - fn try_from(element: Element) -> Result { - if element.namespace() == XMLNS && element.localname() == "bind" { - let (resource, jid); - let child: &Element = element.child()?; - if child.namespace() == XMLNS { - match child.localname() { - "resource" => Bind::new(Some( - child - .text_content()? - .first() - .ok_or(ElementParseError::NoContent)?, - )), - } - } - } - } -} diff --git a/src/stanza/iq.rs b/src/stanza/iq.rs index 6c7dee3..8b13789 100644 --- a/src/stanza/iq.rs +++ b/src/stanza/iq.rs @@ -1,170 +1 @@ -use nanoid::nanoid; -use quick_xml::{ - events::{BytesStart, Event}, - name::QName, - Reader, Writer, -}; -use crate::{JabberClient, JabberError, JID}; - -use crate::Result; - -#[derive(Debug)] -pub struct IQ { - to: Option, - from: Option, - id: String, - r#type: IQType, - lang: Option, - child: Element<'static>, -} - -#[derive(Debug)] -enum IQType { - Get, - Set, - Result, - Error, -} - -impl IQ { - pub async fn set<'j, R: IntoElement<'static>>( - client: &mut JabberClient<'j>, - to: Option, - from: Option, - element: R, - ) -> Result> { - let id = nanoid!(); - let iq = IQ { - to, - from, - id: id.clone(), - r#type: IQType::Set, - lang: None, - child: Element::from(element), - }; - println!("{:?}", iq); - let iq = Element::from(iq); - println!("{:?}", iq); - iq.write(&mut client.writer).await?; - let result = Element::read(&mut client.reader).await?; - let iq = IQ::try_from(result)?; - if iq.id == id { - return Ok(iq.child); - } - Err(JabberError::IDMismatch) - } -} - -impl<'e> IntoElement<'e> for IQ { - fn event(&self) -> quick_xml::events::Event<'e> { - let mut start = BytesStart::new("iq"); - if let Some(to) = &self.to { - start.push_attribute(("to", to.to_string().as_str())); - } - if let Some(from) = &self.from { - start.push_attribute(("from", from.to_string().as_str())); - } - start.push_attribute(("id", self.id.as_str())); - match self.r#type { - IQType::Get => start.push_attribute(("type", "get")), - IQType::Set => start.push_attribute(("type", "set")), - IQType::Result => start.push_attribute(("type", "result")), - IQType::Error => start.push_attribute(("type", "error")), - } - if let Some(lang) = &self.lang { - start.push_attribute(("from", lang.to_string().as_str())); - } - - quick_xml::events::Event::Start(start) - } - - fn children(&self) -> Option>> { - Some(vec![self.child.clone()]) - } -} - -impl TryFrom> for IQ { - type Error = JabberError; - - fn try_from(element: Element<'static>) -> std::result::Result { - if let Event::Start(start) = &element.event { - if start.name() == QName(b"iq") { - let mut to: Option = None; - let mut from: Option = None; - let mut id = None; - let mut r#type = None; - let mut lang = None; - start - .attributes() - .into_iter() - .try_for_each(|attribute| -> Result<()> { - if let Ok(attribute) = attribute { - let buf: Vec = Vec::new(); - let reader = Reader::from_reader(buf); - match attribute.key { - QName(b"to") => { - to = Some( - attribute - .decode_and_unescape_value(&reader) - .or(Err(JabberError::Utf8Decode))? - .into_owned() - .try_into()?, - ) - } - QName(b"from") => { - from = Some( - attribute - .decode_and_unescape_value(&reader) - .or(Err(JabberError::Utf8Decode))? - .into_owned() - .try_into()?, - ) - } - QName(b"id") => { - id = Some( - attribute - .decode_and_unescape_value(&reader) - .or(Err(JabberError::Utf8Decode))? - .into_owned(), - ) - } - QName(b"type") => { - let value = attribute - .decode_and_unescape_value(&reader) - .or(Err(JabberError::Utf8Decode))?; - match value.as_ref() { - "get" => r#type = Some(IQType::Get), - "set" => r#type = Some(IQType::Set), - "result" => r#type = Some(IQType::Result), - "error" => r#type = Some(IQType::Error), - _ => return Err(JabberError::ParseError), - } - } - QName(b"lang") => { - lang = Some( - attribute - .decode_and_unescape_value(&reader) - .or(Err(JabberError::Utf8Decode))? - .into_owned(), - ) - } - _ => return Err(JabberError::UnknownAttribute), - } - } - Ok(()) - })?; - let iq = IQ { - to, - from, - id: id.ok_or(JabberError::NoID)?, - r#type: r#type.ok_or(JabberError::NoType)?, - lang, - child: element.child()?.to_owned(), - }; - return Ok(iq); - } - } - Err(JabberError::ParseError) - } -} diff --git a/src/stanza/message.rs b/src/stanza/message.rs new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/src/stanza/message.rs @@ -0,0 +1 @@ + diff --git a/src/stanza/mod.rs b/src/stanza/mod.rs index 13fc31e..c5a6da3 100644 --- a/src/stanza/mod.rs +++ b/src/stanza/mod.rs @@ -1,657 +1,11 @@ -// use quick_xml::events::BytesDecl; - pub mod bind; pub mod iq; +pub mod message; +pub mod presence; pub mod sasl; +pub mod starttls; pub mod stream; -use std::collections::BTreeMap; -use std::str; - -// const DECLARATION: BytesDecl<'_> = BytesDecl::new("1.0", None, None); -use async_recursion::async_recursion; -use quick_xml::events::{BytesEnd, BytesStart, BytesText, Event}; -use quick_xml::name::PrefixDeclaration; -use quick_xml::{Reader, Writer}; -use tokio::io::{AsyncBufRead, AsyncWrite}; - -use crate::{JabberError, Result}; - -#[derive(Clone, Debug)] -/// represents an xml element as a tree of nodes -pub struct Element { - /// element prefix - /// e.g. `foo` in ``. - prefix: Option, - /// element name - /// e.g. `bar` in ``. - localname: String, - /// qualifying namespace - /// an element must be qualified by a namespace - /// e.g. for `` in - /// ``` - /// - /// - /// - /// - /// zlib - /// lzw - /// - /// - /// - /// ``` - /// would be `"http://etherx.jabber.org/streams"` but for - /// ``` - /// - /// - /// - /// - /// zlib - /// lzw - /// - /// - /// - /// ``` - /// would be `"jabber:client"` - namespace: String, - /// all namespaces applied to element - /// e.g. for `` in - /// ``` - /// - /// - /// - /// zlib - /// lzw - /// - /// - /// ``` - /// would be `[(None, "urn:ietf:params:xml:ns:xmpp-bind")]` despite - /// `(Some("stream"), "http://etherx.jabber.org/streams")` also being available - // TODO: maybe not even needed, as can calculate when writing which namespaces need to be declared - // but then can't have unused namespace on element, confusing. - namespace_declarations: Box, String>>, - /// element attributes - attributes: Box>, - // children elements namespaces contain their parents' namespaces - children: Box>, -} - -#[derive(Clone, Debug)] -pub enum Node { - Element(Element), - Text(String), - Unknown, -} - -impl From for Node { - fn from(element: Element) -> Self { - Self::Element(element) - } -} - -impl From for Node { - fn from(text: S) -> Self { - Self::Text(text.to_string()) - } -} - -impl<'s> From<&Node> for Vec> { - fn from(node: &Node) -> Self { - match node { - Node::Element(e) => e.into(), - Node::Text(t) => vec![Event::Text(BytesText::new(t))], - Unknown => vec![], - } - } -} - -impl Element { - /// returns the fully qualified name - /// e.g. `foo:bar` in - /// ``. - pub fn name(&self) -> &str { - if let Some(prefix) = self.prefix { - format!("{}:{}", prefix, self.localname).as_str() - } else { - &self.localname - } - } - - /// returns the localname. - /// e.g. `bar` in `` - pub fn localname(&self) -> &str { - &self.localname - } - - /// returns the prefix. - /// e.g. `foo` in ``. returns None if there is - /// no prefix. - pub fn prefix(&self) -> Option<&str> { - self.prefix - } - - /// returns the namespace which applies to the current element, e.g. for - /// `` - /// it will be `foo` but for - /// `` - /// it will be `bar`. - pub fn namespace(&self) -> &str { - &self.namespace - } -} - -impl<'s> From<&Element> for Vec> { - fn from(element: &Element) -> Self { - let name = element.name(); - - let event = BytesStart::new(name); - - // namespace declarations - let namespace_declarations = element.namespace_declarations.iter().map(|declaration| { - let (prefix, namespace) = declaration; - match prefix { - Some(prefix) => return (format!("xmlns:{}", prefix).as_str(), *namespace), - None => return ("xmlns", *namespace), - } - }); - let event = event.with_attributes(namespace_declarations); - - // attributes - let event = event.with_attributes(element.attributes.into_iter()); - - match element.children.is_empty() { - true => return vec![Event::Empty(event)], - false => { - return { - let start: Vec> = vec![Event::Start(event)]; - let start_and_content: Vec> = start - .into_iter() - .chain({ - let u = element.children.iter().fold( - Vec::new(), - |acc: Vec>, child: &Node<'s>| { - acc.into_iter() - .chain(Into::>>::into(child).into_iter()) - .collect() - }, - ); - u - }) - .collect(); - let full: Vec> = start_and_content - .into_iter() - .chain(vec![Event::End(BytesEnd::new(name))]) - .collect(); - full - } - } - } - } -} - -impl Element { - /// if there is only one child in the vec of children, will return that element - pub fn child(&self) -> Result<&Node> { - if self.children.len() == 1 { - Ok(&self.children[0]) - } else if self.children.len() > 1 { - Err(ElementError::MultipleChildren.into()) - } else { - Err(ElementError::NoChildren.into()) - } - } - - /// returns reference to children - pub fn children(&self) -> Result<&Vec> { - if !self.children.is_empty() { - Ok(&self.children) - } else { - Err(ElementError::NoChildren.into()) - } - } - - /// returns text content, error if there is none - pub fn text_content(&self) -> Result> { - let mut text = Vec::new(); - for node in *self.children { - match node { - Node::Text(t) => text.push(t), - _ => {} - } - } - if text.is_empty() { - return Err(ElementError::NotText.into()); - } - Ok(text) - } - - /// returns whether or not the element is qualified by a namespace, either declared - /// by a parent, or itself. - fn namespace_qualified>( - &self, - namespace_context: &BTreeMap, S>, - ) -> bool { - // create a new local_namespaces combining that in the context and those declared within the element - let mut local_namespaces = *namespace_context.clone(); - self.namespace_declarations.iter().for_each(|prefix, declaration| local_namespaces.insert(prefix, declaration)); - - if let Some(namespace) = local_namespaces.get(self.prefix) { - if namespace != self.namespace { - return false; - } - } else { - return false; - }; - - for child in *self.children { - if child.namespace_qualified(&local_namespaces) == false { - return false; - } - } - - true - } - - - /// writes an element to a writer. the element's namespace must be qualified by the - /// context given in `local_namespaces` or the element's internal namespace declarations - pub async fn write, W: AsyncWrite + Unpin + Send>( - &self, - writer: &mut Writer, - local_namespaces: &BTreeMap, S>, - ) -> Result<()> { - // TODO: instead of having a check if namespace qualified function, have the namespace declarations be added if needed given the context when converting from `Element` to `Event`s - if self.namespace_qualified(local_namespaces) { - let events: Vec = self.into(); - for event in events { - writer.write_event_async(event).await? - } - Ok(()) - } else { - Err(ElementError::NamespaceNotQualified.into()) - } - } - - pub async fn write_start, W: AsyncWrite + Unpin + Send>( - &self, - writer: &mut Writer, - local_namespaces: &BTreeMap, S>, - ) -> Result<()> { - if self.namespace_qualified(local_namespaces) { - let mut event = BytesStart::new(self.name()); - - // namespace declarations - self.namespace_declarations.iter().for_each(|declaration| { - let (prefix, namespace) = declaration; - match prefix { - Some(prefix) => { - event.push_attribute((format!("xmlns:{}", prefix).as_str(), *namespace)) - } - None => event.push_attribute(("xmlns", *namespace)), - } - }); - - // attributes - let event = - event.with_attributes(self.attributes.iter().map(|(attr, value)| (*attr, *value))); - - writer.write_event_async(Event::Start(event)).await?; - - Ok(()) - } else { - Err(ElementError::NamespaceNotQualified.into()) - } - } - - pub async fn write_end( - &self, - writer: &mut Writer, - ) -> Result<()> { - let event = BytesEnd::new(self.name()); - writer.write_event_async(Event::End(event)).await?; - Ok(()) - } - - #[async_recursion] - pub async fn read, R: AsyncBufRead + Unpin + Send>( - reader: &mut Reader, - local_namespaces: &BTreeMap, S>, - ) -> Result { - let node = Node::read_recursive(reader, local_namespaces) - .await? - .ok_or(JabberError::UnexpectedEnd)?; - match node { - Node::Element(e) => Ok(e), - Node::Text(_) => Err(JabberError::UnexpectedText), - Node::Unknown => Err(JabberError::UnexpectedElement), - } - } - - pub async fn read_start, R: AsyncBufRead + Unpin + Send>( - reader: &mut Reader, - local_namespaces: &BTreeMap, S>, - ) -> Result { - let buf = Vec::new(); - let event = reader.read_event_into_async(&mut buf).await?; - match event { - Event::Start(e) => { - let prefix = e.name().prefix().map(|prefix| prefix.into_inner()); - let converted_prefix; - if let Some(raw_prefix) = prefix { - converted_prefix = Some(str::from_utf8(raw_prefix)?) - } - let prefix = converted_prefix; - - let localname = str::from_utf8(e.local_name().into_inner())?.to_owned(); - - let mut local_namespaces = local_namespaces.clone(); - let mut namespace_declarations = BTreeMap::new(); - let attributes = BTreeMap::new(); - - for attribute in e.attributes() { - let attribute = attribute?; - if let Some(prefix_declaration) = attribute.key.as_namespace_binding() { - match prefix_declaration { - PrefixDeclaration::Default => { - let value = str::from_utf8(attribute.value.as_ref())?; - if let Some(_) = namespace_declarations.insert(None, value) { - return Err(ElementParseError::DuplicateAttribute.into()); - }; - local_namespaces.insert(None, value); - } - PrefixDeclaration::Named(prefix) => { - let key = str::from_utf8(prefix)?; - let value = str::from_utf8(attribute.value.as_ref())?; - if let Some(_) = namespace_declarations.insert(Some(key), value) { - return Err(ElementParseError::DuplicateAttribute.into()); - }; - local_namespaces.insert(Some(key), value); - } - } - } else { - if let Some(_) = attributes.insert( - str::from_utf8(attribute.key.into_inner())?, - str::from_utf8(attribute.value.as_ref())?, - ) { - return Err(ElementParseError::DuplicateAttribute.into()); - }; - } - } - - let namespace = *local_namespaces - .get(&prefix) - .ok_or(ElementParseError::NoNamespace)?; - - let mut children = Vec::new(); - - Ok(Self { - prefix, - localname, - namespace, - namespace_declarations: Box::new(namespace_declarations), - attributes: Box::new(attributes), - children: Box::new(children), - }) - } - e => Err(ElementError::NotAStart(e).into()), - } - } -} - -impl Node { - #[async_recursion] - async fn read_recursive, R: AsyncBufRead + Unpin + Send>( - reader: &mut Reader, - local_namespaces: &BTreeMap, S>, - ) -> Result> { - let mut buf = Vec::new(); - let event = reader.read_event_into_async(&mut buf).await?; - match event { - Event::Empty(e) => { - let prefix = e.name().prefix().map(|prefix| prefix.into_inner()); - let converted_prefix; - if let Some(raw_prefix) = prefix { - converted_prefix = Some(str::from_utf8(raw_prefix)?) - } - let prefix = converted_prefix; - - let localname = str::from_utf8(e.local_name().into_inner())?.to_owned(); - - let mut local_namespaces = local_namespaces.clone(); - let mut namespace_declarations = BTreeMap::new(); - let attributes = BTreeMap::new(); - - for attribute in e.attributes() { - let attribute = attribute?; - if let Some(prefix_declaration) = attribute.key.as_namespace_binding() { - match prefix_declaration { - PrefixDeclaration::Default => { - let value = str::from_utf8(attribute.value.as_ref())?; - if let Some(_) = namespace_declarations.insert(None, value) { - return Err(ElementParseError::DuplicateAttribute.into()); - }; - local_namespaces.insert(None, value); - } - PrefixDeclaration::Named(prefix) => { - let key = str::from_utf8(prefix)?; - let value = str::from_utf8(attribute.value.as_ref())?; - if let Some(_) = namespace_declarations.insert(Some(key), value) { - return Err(ElementParseError::DuplicateAttribute.into()); - }; - local_namespaces.insert(Some(key), value); - } - } - } else { - if let Some(_) = attributes.insert( - str::from_utf8(attribute.key.into_inner())?, - str::from_utf8(attribute.value.as_ref())?, - ) { - return Err(ElementParseError::DuplicateAttribute.into()); - }; - } - } - - let namespace = *local_namespaces - .get(&prefix) - .ok_or(ElementParseError::NoNamespace)?; - - let mut children = Vec::new(); - - Ok(Some(Self::Element(Element { - prefix, - localname, - namespace, - namespace_declarations: Box::new(namespace_declarations), - attributes: Box::new(attributes), - children: Box::new(children), - }))) - } - Event::Start(e) => { - let prefix = e.name().prefix().map(|prefix| prefix.into_inner()); - let converted_prefix; - if let Some(raw_prefix) = prefix { - converted_prefix = Some(str::from_utf8(raw_prefix)?) - } - let prefix = converted_prefix; - - let localname = str::from_utf8(e.local_name().into_inner())?.to_owned(); - - let mut local_namespaces = local_namespaces.clone(); - let mut namespace_declarations = BTreeMap::new(); - let attributes = BTreeMap::new(); - - for attribute in e.attributes() { - let attribute = attribute?; - if let Some(prefix_declaration) = attribute.key.as_namespace_binding() { - match prefix_declaration { - PrefixDeclaration::Default => { - let value = str::from_utf8(attribute.value.as_ref())?; - if let Some(_) = namespace_declarations.insert(None, value) { - return Err(ElementParseError::DuplicateAttribute.into()); - }; - local_namespaces.insert(None, value); - } - PrefixDeclaration::Named(prefix) => { - let key = str::from_utf8(prefix)?; - let value = str::from_utf8(attribute.value.as_ref())?; - if let Some(_) = namespace_declarations.insert(Some(key), value) { - return Err(ElementParseError::DuplicateAttribute.into()); - }; - local_namespaces.insert(Some(key), value); - } - } - } else { - if let Some(_) = attributes.insert( - str::from_utf8(attribute.key.into_inner())?, - str::from_utf8(attribute.value.as_ref())?, - ) { - return Err(ElementParseError::DuplicateAttribute.into()); - }; - } - } - - let namespace = *local_namespaces - .get(&prefix) - .ok_or(ElementParseError::NoNamespace)?; - - let mut children = Vec::new(); - while let Some(child_node) = Node::read_recursive(reader, &local_namespaces).await? - { - children.push(child_node) - } - - let mut children = Vec::new(); - - Ok(Some(Self::Element(Element { - prefix, - localname, - namespace, - namespace_declarations: Box::new(namespace_declarations), - attributes: Box::new(attributes), - children: Box::new(children), - }))) - } - Event::End(_) => Ok(None), - Event::Text(e) => Ok(Some(Self::Text(e.unescape()?.as_ref().to_string()))), - e => Ok(Some(Self::Unknown)), - } - } - - fn namespace_qualified>( - &self, - namespace_context: &BTreeMap, S>, - ) -> bool { - match self { - Self::Element(e) => e.namespace_qualified(namespace_context), - _ => true, - } - } -} - -pub enum NodeBuilder { - Text(String), - Element(ElementBuilder), -} - -pub struct ElementBuilder { - localname: String, - prefix: Option, - namespace: String, - namespace_declarations: BTreeMap, String>, - attributes: BTreeMap, - children: Vec, -} - -impl ElementBuilder { - pub fn new(localname: S, prefix: Option, namespace: S) -> Self { - Self { - prefix, - localname, - namespace, - namespace_declarations: Box::new(BTreeMap::new()), - attributes: Box::new(BTreeMap::new()), - children: Box::new(Vec::new()), - } - } - - pub fn push_namespace_declaration( - &mut self, - (prefix, namespace): (Option, S), - ) -> Option { - self.namespace_declarations.insert(prefix, namespace) - } - - pub fn push_attribute(&mut self, (key, value): (S, S)) -> Option { - self.attributes.insert(key, value) - } - - pub fn push_child(&mut self, child: Node) { - self.children.push(child) - } - - /// checks if there is a namespace conflict within the element being built - pub fn namespace_conflict>( - &self - ) -> bool { - self.namespace_conflict_recursive(&BTreeMap::new()) - } - - fn namespace_conflict_recursive>( - &self, - parent_namespaces: &BTreeMap, S>, - ) -> bool { - // create a new local_namespaces combining that in the context and those declared within the element - let mut local_namespaces = *parent_namespaces.clone(); - self.namespace_declarations.iter().for_each(|prefix, declaration| local_namespaces.insert(prefix, declaration)); - - if let Some(namespace) = local_namespaces.get(self.prefix) { - if namespace != self.namespace { - return false; - } - } else { - return false; - }; - - for child in *self.children { - if child.namespace_conflict(&local_namespaces) == false { - return false; - } - } - - true - } - - // check for possible conflicts in namespace - pub fn build(self) -> Result { - for child in self.children { - match child { - Node::Element(e) => { - if !e.namespace_conflict() - } - } - } - Element { - prefix: self.prefix, - localname: self.localname, - namespace: self.namespace, - namespace_declarations: self.namespace_declarations, - attributes: self.attributes, - children: self.children, - } - } -} - -#[derive(Debug)] -pub enum ElementError<'e> { - NotAStart(Event<'e>), - NotText, - NoChildren, - NamespaceNotQualified, - MultipleChildren, -} +use quick_xml::events::{BytesDecl, Event}; -#[derive(Debug)] -pub enum ElementParseError { - DuplicateAttribute, - NoNamespace, -} +pub static DECLARATION: Event = Event::Decl(BytesDecl::new("1.0", None, None)); diff --git a/src/stanza/presence.rs b/src/stanza/presence.rs new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/src/stanza/presence.rs @@ -0,0 +1 @@ + diff --git a/src/stanza/sasl.rs b/src/stanza/sasl.rs index 20cd063..8b13789 100644 --- a/src/stanza/sasl.rs +++ b/src/stanza/sasl.rs @@ -1,145 +1 @@ -use quick_xml::{ - events::{BytesStart, BytesText, Event}, - name::QName, -}; -use crate::error::SASLError; -use crate::JabberError; - -use super::Element; - -const XMLNS: &str = "urn:ietf:params:xml:ns:xmpp-sasl"; - -#[derive(Debug)] -pub struct Auth<'e> { - pub mechanism: &'e str, - pub sasl_data: &'e str, -} - -impl<'e> IntoElement<'e> for Auth<'e> { - fn event(&self) -> Event<'e> { - let mut start = BytesStart::new("auth"); - start.push_attribute(("xmlns", XMLNS)); - start.push_attribute(("mechanism", self.mechanism)); - Event::Start(start) - } - - fn children(&self) -> Option>> { - let sasl = BytesText::from_escaped(self.sasl_data); - let sasl = Element { - event: Event::Text(sasl), - children: None, - }; - Some(vec![sasl]) - } -} - -#[derive(Debug)] -pub struct Challenge { - pub sasl_data: Vec, -} - -impl<'e> TryFrom<&Element<'e>> for Challenge { - type Error = JabberError; - - fn try_from(element: &Element<'e>) -> Result { - if let Event::Start(start) = &element.event { - if start.name() == QName(b"challenge") { - let sasl_data: &Element<'_> = element.child()?; - if let Event::Text(sasl_data) = &sasl_data.event { - let s = sasl_data.clone(); - let s = s.into_inner(); - let s = s.to_vec(); - return Ok(Challenge { sasl_data: s }); - } - } - } - Err(SASLError::NoChallenge.into()) - } -} - -// impl<'e> TryFrom> for Challenge { -// type Error = JabberError; - -// fn try_from(element: Element<'e>) -> Result { -// if let Event::Start(start) = &element.event { -// if start.name() == QName(b"challenge") { -// println!("one"); -// if let Some(children) = element.children.as_deref() { -// if children.len() == 1 { -// let sasl_data = children.first().unwrap(); -// if let Event::Text(sasl_data) = &sasl_data.event { -// return Ok(Challenge { -// sasl_data: sasl_data.clone().into_inner().to_vec(), -// }); -// } else { -// return Err(SASLError::NoChallenge.into()); -// } -// } else { -// return Err(SASLError::NoChallenge.into()); -// } -// } else { -// return Err(SASLError::NoChallenge.into()); -// } -// } -// } -// Err(SASLError::NoChallenge.into()) -// } -// } - -#[derive(Debug)] -pub struct Response<'e> { - pub sasl_data: &'e str, -} - -impl<'e> IntoElement<'e> for Response<'e> { - fn event(&self) -> Event<'e> { - let mut start = BytesStart::new("response"); - start.push_attribute(("xmlns", XMLNS)); - Event::Start(start) - } - - fn children(&self) -> Option>> { - let sasl = BytesText::from_escaped(self.sasl_data); - let sasl = Element { - event: Event::Text(sasl), - children: None, - }; - Some(vec![sasl]) - } -} - -#[derive(Debug)] -pub struct Success { - pub sasl_data: Option>, -} - -impl<'e> TryFrom<&Element<'e>> for Success { - type Error = JabberError; - - fn try_from(element: &Element<'e>) -> Result { - match &element.event { - Event::Start(start) => { - if start.name() == QName(b"success") { - match element.child() { - Ok(sasl_data) => { - if let Event::Text(sasl_data) = &sasl_data.event { - return Ok(Success { - sasl_data: Some(sasl_data.clone().into_inner().to_vec()), - }); - } - } - Err(_) => return Ok(Success { sasl_data: None }), - }; - } - } - Event::Empty(empty) => { - if empty.name() == QName(b"success") { - return Ok(Success { sasl_data: None }); - } - } - _ => {} - } - Err(SASLError::NoSuccess.into()) - } -} diff --git a/src/stanza/starttls.rs b/src/stanza/starttls.rs new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/src/stanza/starttls.rs @@ -0,0 +1 @@ + diff --git a/src/stanza/stream.rs b/src/stanza/stream.rs index f85166f..07f7e6e 100644 --- a/src/stanza/stream.rs +++ b/src/stanza/stream.rs @@ -1,190 +1,60 @@ -use std::str; - -use quick_xml::{ - events::{BytesStart, Event}, - name::QName, -}; - -use super::Element; -use crate::{JabberError, Result, JID}; - -const XMLNS_STREAM: &str = "http://etherx.jabber.org/streams"; -const VERSION: &str = "1.0"; - -enum XMLNS { - Client, - Server, -} - -impl From for &str { - fn from(xmlns: XMLNS) -> Self { - match xmlns { - XMLNS::Client => return "jabber:client", - XMLNS::Server => return "jabber:server", - } - } -} - -impl TryInto for &str { - type Error = JabberError; - - fn try_into(self) -> Result { - match self { - "jabber:client" => Ok(XMLNS::Client), - "jabber:server" => Ok(XMLNS::Server), - _ => Err(JabberError::UnknownNamespace), - } - } -} - -pub struct Stream { - from: Option, - id: Option, - to: Option, - version: Option, - lang: Option, - ns: XMLNS, +use serde::Serialize; + +use crate::JID; + +pub static XMLNS: &str = "http://etherx.jabber.org/streams"; +pub static XMLNS_CLIENT: &str = "jabber:client"; + +// MUST be qualified by stream namespace +#[derive(Serialize)] +pub struct Stream<'s> { + #[serde(rename = "@from")] + from: Option<&'s JID>, + #[serde(rename = "@to")] + to: Option<&'s JID>, + #[serde(rename = "@id")] + id: Option<&'s str>, + #[serde(rename = "@version")] + version: Option<&'s str>, + // TODO: lang enum + #[serde(rename = "@lang")] + lang: Option<&'s str>, + #[serde(rename = "@xmlns")] + xmlns: &'s str, + #[serde(rename = "@xmlns:stream")] + xmlns_stream: &'s str, } impl Stream { - pub fn new_client(from: &JID, to: &JID, id: Option, lang: Option) -> Self { + pub fn new( + from: Option<&JID>, + to: Option<&JID>, + id: Option<&str>, + version: Option<&str>, + lang: Option<&str>, + ) -> Self { Self { - from: Some(from.clone()), + from, + to, id, - to: Some(to.clone()), - version: Some(VERSION.to_owned()), + version, lang, - ns: XMLNS::Client, - } - } - - fn event(&self) -> Event<'static> { - let mut start = BytesStart::new("stream:stream"); - if let Some(from) = &self.from { - start.push_attribute(("from", from.to_string().as_str())); - } - if let Some(id) = &self.id { - start.push_attribute(("id", id.as_str())); - } - if let Some(to) = &self.to { - start.push_attribute(("to", to.to_string().as_str())); - } - if let Some(version) = &self.version { - start.push_attribute(("version", version.to_string().as_str())); - } - if let Some(lang) = &self.lang { - start.push_attribute(("xml:lang", lang.as_str())); - } - match &self.ns { - XMLNS::Client => start.push_attribute(("xmlns", XMLNS::Client.into())), - XMLNS::Server => start.push_attribute(("xmlns", XMLNS::Server.into())), - } - start.push_attribute(("xmlns:stream", XMLNS_STREAM)); - Event::Start(start) - } -} - -impl<'e> Into> for Stream { - fn into(self) -> Element<'e> { - Element { - event: self.event(), - children: None, + xmlns: XMLNS_CLIENT, + xmlns_stream: XMLNS, } } -} - -impl<'e> TryFrom> for Stream { - type Error = JabberError; - fn try_from(value: Element<'e>) -> Result { - let (mut from, mut id, mut to, mut version, mut lang, mut ns) = - (None, None, None, None, None, XMLNS::Client); - if let Event::Start(e) = value.event.as_ref() { - for attribute in e.attributes() { - let attribute = attribute?; - match attribute.key { - QName(b"from") => { - from = Some(str::from_utf8(&attribute.value)?.to_string().try_into()?); - } - QName(b"id") => { - id = Some(str::from_utf8(&attribute.value)?.to_owned()); - } - QName(b"to") => { - to = Some(str::from_utf8(&attribute.value)?.to_string().try_into()?); - } - QName(b"version") => { - version = Some(str::from_utf8(&attribute.value)?.to_owned()); - } - QName(b"lang") => { - lang = Some(str::from_utf8(&attribute.value)?.to_owned()); - } - QName(b"xmlns") => { - ns = str::from_utf8(&attribute.value)?.try_into()?; - } - _ => { - println!("unknown attribute: {:?}", attribute) - } - } - } - Ok(Stream { - from, - id, - to, - version, - lang, - ns, - }) - } else { - Err(JabberError::ParseError) - } - } -} - -#[derive(PartialEq, Debug)] -pub enum StreamFeature { - StartTls, - Sasl(Vec), - Bind, - Unknown, -} - -impl<'e> TryFrom> for Vec { - type Error = JabberError; - - fn try_from(features_element: Element) -> Result { - let mut features = Vec::new(); - if let Some(children) = features_element.children { - for feature_element in children { - match feature_element.event { - Event::Start(e) => match e.name() { - QName(b"starttls") => features.push(StreamFeature::StartTls), - QName(b"mechanisms") => { - let mut mechanisms = Vec::new(); - if let Some(children) = feature_element.children { - for mechanism_element in children { - if let Some(children) = mechanism_element.children { - for mechanism_text in children { - match mechanism_text.event { - Event::Text(e) => mechanisms - .push(str::from_utf8(e.as_ref())?.to_owned()), - _ => {} - } - } - } - } - } - features.push(StreamFeature::Sasl(mechanisms)) - } - _ => features.push(StreamFeature::Unknown), - }, - Event::Empty(e) => match e.name() { - QName(b"bind") => features.push(StreamFeature::Bind), - _ => features.push(StreamFeature::Unknown), - }, - _ => features.push(StreamFeature::Unknown), - } - } + /// For initial stream headers, the initiating entity SHOULD include the 'xml:lang' attribute. + /// For privacy, it is better to not set `from` when sending a client stanza over an unencrypted connection. + pub fn new_client(from: Option<&JID>, to: &JID, id: Option<&str>, lang: &str) -> Self { + Self { + from, + to: Some(to), + id, + version: Some("1.0"), + lang: Some(lang), + xmlns: XMLNS_CLIENT, + xmlns_stream: XMLNS, } - Ok(features) } } -- cgit