summaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/client/encrypted.rs216
-rw-r--r--src/client/mod.rs17
-rw-r--r--src/client/unencrypted.rs76
-rw-r--r--src/element.rs108
-rw-r--r--src/error.rs31
-rw-r--r--src/jabber.rs16
-rw-r--r--src/jid/mod.rs2
-rw-r--r--src/lib.rs1
-rw-r--r--src/stanza/mod.rs4
-rw-r--r--src/stanza/sasl.rs24
-rw-r--r--src/stanza/stream.rs197
11 files changed, 383 insertions, 309 deletions
diff --git a/src/client/encrypted.rs b/src/client/encrypted.rs
index a4bf0d1..76f600c 100644
--- a/src/client/encrypted.rs
+++ b/src/client/encrypted.rs
@@ -1,35 +1,26 @@
-use std::str;
-
use quick_xml::{
- de::Deserializer,
- events::{BytesDecl, BytesStart, Event},
- name::QName,
- se::Serializer,
+ events::{BytesDecl, Event},
Reader, Writer,
};
-use rsasl::prelude::{Mechname, SASLClient};
-use serde::{Deserialize, Serialize};
-use tokio::io::{AsyncWriteExt, BufReader, ReadHalf, WriteHalf};
+use tokio::io::{BufReader, ReadHalf, WriteHalf};
use tokio::net::TcpStream;
use tokio_native_tls::TlsStream;
-use crate::stanza::{
- sasl::{Auth, Challenge, Mechanisms},
- stream::{StreamFeature, StreamFeatures},
-};
+use crate::element::Element;
+use crate::stanza::stream::{Stream, StreamFeature};
use crate::Jabber;
use crate::Result;
pub struct JabberClient<'j> {
reader: Reader<BufReader<ReadHalf<TlsStream<TcpStream>>>>,
- writer: WriteHalf<TlsStream<TcpStream>>,
+ writer: Writer<WriteHalf<TlsStream<TcpStream>>>,
jabber: &'j mut Jabber<'j>,
}
impl<'j> JabberClient<'j> {
pub fn new(
reader: Reader<BufReader<ReadHalf<TlsStream<TcpStream>>>>,
- writer: WriteHalf<TlsStream<TcpStream>>,
+ writer: Writer<WriteHalf<TlsStream<TcpStream>>>,
jabber: &'j mut Jabber<'j>,
) -> Self {
Self {
@@ -40,90 +31,29 @@ impl<'j> JabberClient<'j> {
}
pub async fn start_stream(&mut self) -> Result<()> {
+ // client to server
let declaration = BytesDecl::new("1.0", None, None);
- let mut stream_element = BytesStart::new("stream:stream");
- stream_element.push_attribute(("from".as_bytes(), self.jabber.jid.to_string().as_bytes()));
- stream_element.push_attribute(("to".as_bytes(), self.jabber.server.as_bytes()));
- stream_element.push_attribute(("version", "1.0"));
- stream_element.push_attribute(("xml:lang", "en"));
- stream_element.push_attribute(("xmlns", "jabber:client"));
- stream_element.push_attribute(("xmlns:stream", "http://etherx.jabber.org/streams"));
- let mut writer = Writer::new(&mut self.writer);
- writer.write_event_async(Event::Decl(declaration)).await;
- writer.write_event_async(Event::Start(stream_element)).await;
+ let server = &self.jabber.server.to_owned().try_into()?;
+ let stream_element =
+ Stream::new_client(&self.jabber.jid, server, None, Some("en".to_string()));
+ self.writer
+ .write_event_async(Event::Decl(declaration))
+ .await;
+ let stream_element: Element<'_> = stream_element.into();
+ stream_element.write_start(&mut self.writer).await?;
+ // server to client
let mut buf = Vec::new();
- loop {
- match self.reader.read_event_into_async(&mut buf).await.unwrap() {
- Event::Start(e) => {
- println!("{:?}", e);
- break;
- }
- e => println!("decl: {:?}", e),
- };
- }
+ self.reader.read_event_into_async(&mut buf).await?;
+ let _stream_response = Element::read_start(&mut self.reader).await?;
Ok(())
}
- pub async fn get_node<'a>(&mut self) -> Result<String> {
- let mut buf = Vec::new();
- let mut txt = Vec::new();
- let mut qname_set = false;
- let mut qname: Option<Vec<u8>> = None;
- loop {
- match self.reader.read_event_into_async(&mut buf).await? {
- Event::Start(e) => {
- if !qname_set {
- qname = Some(e.name().into_inner().to_owned());
- qname_set = true;
- }
- txt.push(b'<');
- txt = txt
- .into_iter()
- .chain(buf.to_owned())
- .chain(vec![b'>'])
- .collect();
- }
- Event::End(e) => {
- let mut end = false;
- if e.name() == QName(qname.as_deref().unwrap()) {
- end = true;
- }
- txt.push(b'<');
- txt = txt
- .into_iter()
- .chain(buf.to_owned())
- .chain(vec![b'>'])
- .collect();
- if end {
- break;
- }
- }
- Event::Text(_e) => {
- txt = txt.into_iter().chain(buf.to_owned()).collect();
- }
- _ => {
- txt.push(b'<');
- txt = txt
- .into_iter()
- .chain(buf.to_owned())
- .chain(vec![b'>'])
- .collect();
- }
- }
- buf.clear();
+ pub async fn get_features(&mut self) -> Result<Option<Vec<StreamFeature>>> {
+ if let Some(features) = Element::read(&mut self.reader).await? {
+ Ok(Some(features.try_into()?))
+ } else {
+ Ok(None)
}
- println!("{:?}", txt);
- let decoded = str::from_utf8(&txt)?.to_owned();
- println!("{:?}", decoded);
- Ok(decoded)
- }
-
- pub async fn get_features(&mut self) -> Result<Vec<StreamFeature>> {
- let node = self.get_node().await?;
- let mut deserializer = Deserializer::from_str(&node);
- let features = StreamFeatures::deserialize(&mut deserializer).unwrap();
- println!("{:?}", features);
- Ok(features.features)
}
pub async fn negotiate(&mut self) -> Result<()> {
@@ -131,98 +61,14 @@ impl<'j> JabberClient<'j> {
println!("loop");
let features = &self.get_features().await?;
println!("{:?}", features);
- match &features[0] {
- StreamFeature::Sasl(sasl) => {
- println!("{:?}", sasl);
- self.sasl(&sasl).await?;
- }
- StreamFeature::Bind => todo!(),
- x => println!("{:?}", x),
- }
+ // match &features[0] {
+ // StreamFeature::Sasl(sasl) => {
+ // println!("{:?}", sasl);
+ // todo!()
+ // }
+ // StreamFeature::Bind => todo!(),
+ // x => println!("{:?}", x),
+ // }
}
}
-
- pub async fn sasl(&mut self, mechanisms: &Mechanisms) -> Result<()> {
- println!("{:?}", mechanisms);
- let sasl = SASLClient::new(self.jabber.auth.clone());
- let mut offered_mechs: Vec<&Mechname> = Vec::new();
- for mechanism in &mechanisms.mechanisms {
- offered_mechs.push(Mechname::parse(&mechanism.mechanism.as_bytes())?)
- }
- println!("{:?}", offered_mechs);
- let mut session = sasl.start_suggested(&offered_mechs)?;
- let selected_mechanism = session.get_mechname().as_str().to_owned();
- println!("selected mech: {:?}", selected_mechanism);
- let mut data: Option<Vec<u8>> = None;
- if !session.are_we_first() {
- // if not first mention the mechanism then get challenge data
- // mention mechanism
- let auth = Auth {
- ns: "urn:ietf:params:xml:ns:xmpp-sasl".to_owned(),
- mechanism: selected_mechanism.clone(),
- sasl_data: Some("=".to_owned()),
- };
- let mut buffer = String::new();
- let ser = Serializer::new(&mut buffer);
- auth.serialize(ser).unwrap();
- self.writer.write_all(buffer.as_bytes());
- // get challenge data
- let node = self.get_node().await?;
- let mut deserializer = Deserializer::from_str(&node);
- let challenge = Challenge::deserialize(&mut deserializer).unwrap();
- println!("challenge: {:?}", challenge);
- data = Some(challenge.sasl_data.as_bytes().to_owned());
- println!("we didn't go first");
- } else {
- // if first, mention mechanism and send data
- let mut sasl_data = Vec::new();
- session.step64(None, &mut sasl_data).unwrap();
- let auth = Auth {
- ns: "urn:ietf:params:xml:ns:xmpp-sasl".to_owned(),
- mechanism: selected_mechanism.clone(),
- sasl_data: Some(str::from_utf8(&sasl_data).unwrap().to_owned()),
- };
- let mut buffer = String::new();
- let ser = Serializer::new(&mut buffer);
- auth.serialize(ser).unwrap();
- println!("node: {:?}", buffer);
- self.writer.write_all(buffer.as_bytes()).await;
- println!("we went first");
- // get challenge data
- // TODO: check if needed
- // let node = self.get_node().await?;
- // println!("node: {:?}", node);
- // let mut deserializer = Deserializer::from_str(&node);
- // let challenge = Challenge::deserialize(&mut deserializer).unwrap();
- // println!("challenge: {:?}", challenge);
- // data = Some(challenge.sasl_data.as_bytes().to_owned());
- }
-
- // stepping the authentication exchange to completion
- let mut sasl_data = Vec::new();
- while {
- // decide if need to send more data over
- let state = session
- .step64(data.as_deref(), &mut sasl_data)
- .expect("step errored!");
- state.is_running()
- } {
- // While we aren't finished, receive more data from the other party
- let auth = Auth {
- ns: "urn:ietf:params:xml:ns:xmpp-sasl".to_owned(),
- mechanism: selected_mechanism.clone(),
- sasl_data: Some(str::from_utf8(&sasl_data).unwrap().to_owned()),
- };
- let mut buffer = String::new();
- let ser = Serializer::new(&mut buffer);
- auth.serialize(ser).unwrap();
- self.writer.write_all(buffer.as_bytes());
- let node = self.get_node().await?;
- let mut deserializer = Deserializer::from_str(&node);
- let challenge = Challenge::deserialize(&mut deserializer).unwrap();
- data = Some(challenge.sasl_data.as_bytes().to_owned());
- }
- self.start_stream().await?;
- Ok(())
- }
}
diff --git a/src/client/mod.rs b/src/client/mod.rs
index fe3dd34..d545923 100644
--- a/src/client/mod.rs
+++ b/src/client/mod.rs
@@ -15,17 +15,16 @@ pub enum JabberClientType<'j> {
impl<'j> JabberClientType<'j> {
pub async fn ensure_tls(self) -> Result<encrypted::JabberClient<'j>> {
match self {
- Self::Encrypted(mut c) => {
- c.start_stream();
- Ok(c)
- }
+ Self::Encrypted(c) => Ok(c),
Self::Unencrypted(mut c) => {
- c.start_stream().await?;
- let features = c.get_features().await?;
- if features.contains(&StreamFeature::StartTls) {
- Ok(c.starttls().await?)
+ if let Some(features) = c.get_features().await? {
+ if features.contains(&StreamFeature::StartTls) {
+ Ok(c.starttls().await?)
+ } else {
+ Err(JabberError::StartTlsUnavailable)
+ }
} else {
- Err(JabberError::StartTlsUnavailable)
+ Err(JabberError::NoFeatures)
}
}
}
diff --git a/src/client/unencrypted.rs b/src/client/unencrypted.rs
index d4225d3..ce534c7 100644
--- a/src/client/unencrypted.rs
+++ b/src/client/unencrypted.rs
@@ -1,19 +1,19 @@
use std::str;
use quick_xml::{
- de::Deserializer,
events::{BytesDecl, BytesStart, Event},
name::QName,
Reader, Writer,
};
-use serde::Deserialize;
use tokio::io::{BufReader, ReadHalf, WriteHalf};
use tokio::net::TcpStream;
use tokio_native_tls::native_tls::TlsConnector;
+use crate::element::Element;
+use crate::stanza::stream::StreamFeature;
+use crate::Jabber;
use crate::Result;
-use crate::{error::JabberError, stanza::stream::StreamFeature};
-use crate::{stanza::stream::StreamFeatures, Jabber};
+use crate::{error::JabberError, stanza::stream::Stream};
pub struct JabberClient<'j> {
reader: Reader<BufReader<ReadHalf<TcpStream>>>,
@@ -35,63 +35,30 @@ impl<'j> JabberClient<'j> {
}
pub async fn start_stream(&mut self) -> Result<()> {
+ // client to server
let declaration = BytesDecl::new("1.0", None, None);
- let mut stream_element = BytesStart::new("stream:stream");
- stream_element.push_attribute(("from".as_bytes(), self.jabber.jid.to_string().as_bytes()));
- stream_element.push_attribute(("to".as_bytes(), self.jabber.server.as_bytes()));
- stream_element.push_attribute(("version", "1.0"));
- stream_element.push_attribute(("xml:lang", "en"));
- stream_element.push_attribute(("xmlns", "jabber:client"));
- stream_element.push_attribute(("xmlns:stream", "http://etherx.jabber.org/streams"));
+ let server = &self.jabber.server.to_owned().try_into()?;
+ let stream_element =
+ Stream::new_client(&self.jabber.jid, server, None, Some("en".to_string()));
self.writer
.write_event_async(Event::Decl(declaration))
- .await;
- self.writer
- .write_event_async(Event::Start(stream_element))
- .await
- .unwrap();
+ .await?;
+ let stream_element: Element<'_> = stream_element.into();
+ stream_element.write_start(&mut self.writer).await?;
+ // server to client
let mut buf = Vec::new();
- loop {
- match self.reader.read_event_into_async(&mut buf).await.unwrap() {
- Event::Start(e) => {
- println!("{:?}", e);
- break;
- }
- Event::Decl(e) => println!("decl: {:?}", e),
- _ => return Err(JabberError::BadStream),
- }
- }
+ self.reader.read_event_into_async(&mut buf).await?;
+ let _stream_response = Element::read_start(&mut self.reader).await?;
Ok(())
}
- pub async fn get_features(&mut self) -> Result<Vec<StreamFeature>> {
- let mut buf = Vec::new();
- let mut txt = Vec::new();
- let mut loop_end = false;
- while !loop_end {
- match self.reader.read_event_into_async(&mut buf).await.unwrap() {
- Event::End(e) => {
- if e.name() == QName(b"stream:features") {
- loop_end = true;
- }
- }
- _ => (),
- }
- txt.push(b'<');
- txt = txt
- .into_iter()
- .chain(buf.to_owned())
- .chain(vec![b'>'])
- .collect();
- buf.clear();
+ pub async fn get_features(&mut self) -> Result<Option<Vec<StreamFeature>>> {
+ if let Some(features) = Element::read(&mut self.reader).await? {
+ println!("{:?}", features);
+ Ok(Some(features.try_into()?))
+ } else {
+ Ok(None)
}
- println!("{:?}", txt);
- let decoded = str::from_utf8(&txt).unwrap();
- println!("decoded: {:?}", decoded);
- let mut deserializer = Deserializer::from_str(decoded);
- let features = StreamFeatures::deserialize(&mut deserializer).unwrap();
- println!("{:?}", features);
- Ok(features.features)
}
pub async fn starttls(mut self) -> Result<super::encrypted::JabberClient<'j>> {
@@ -115,8 +82,9 @@ impl<'j> JabberClient<'j> {
.connect(&self.jabber.server, stream)
.await
{
- let (read, writer) = tokio::io::split(tlsstream);
+ let (read, write) = tokio::io::split(tlsstream);
let reader = Reader::from_reader(BufReader::new(read));
+ let writer = Writer::new(write);
let mut client =
super::encrypted::JabberClient::new(reader, writer, self.jabber);
client.start_stream().await?;
diff --git a/src/element.rs b/src/element.rs
new file mode 100644
index 0000000..21b1a3e
--- /dev/null
+++ b/src/element.rs
@@ -0,0 +1,108 @@
+use async_recursion::async_recursion;
+use quick_xml::events::Event;
+use quick_xml::{Reader, Writer};
+use tokio::io::{AsyncBufRead, AsyncWrite};
+
+use crate::Result;
+
+#[derive(Debug)]
+pub struct Element<'e> {
+ pub event: Event<'e>,
+ pub content: Option<Vec<Element<'e>>>,
+}
+
+// TODO: make method
+#[async_recursion]
+pub async fn write<'e: 'async_recursion, W: AsyncWrite + Unpin + Send>(
+ element: Element<'e>,
+ writer: &mut Writer<W>,
+) -> Result<()> {
+ match element.event {
+ Event::Start(e) => {
+ writer.write_event_async(Event::Start(e.clone())).await?;
+ if let Some(content) = element.content {
+ for e in content {
+ write(e, writer).await?;
+ }
+ }
+ writer.write_event_async(Event::End(e.to_end())).await?;
+ return Ok(());
+ }
+ e => Ok(writer.write_event_async(e).await?),
+ }
+}
+
+impl<'e> Element<'e> {
+ pub async fn write_start<W: AsyncWrite + Unpin + Send>(
+ &self,
+ writer: &mut Writer<W>,
+ ) -> Result<()> {
+ match self.event.as_ref() {
+ Event::Start(e) => Ok(writer.write_event_async(Event::Start(e.clone())).await?),
+ e => Err(ElementError::NotAStart(e.clone().into_owned()).into()),
+ }
+ }
+
+ pub async fn write_end<W: AsyncWrite + Unpin + Send>(
+ &self,
+ writer: &mut Writer<W>,
+ ) -> Result<()> {
+ match self.event.as_ref() {
+ Event::Start(e) => Ok(writer
+ .write_event_async(Event::End(e.clone().to_end()))
+ .await?),
+ e => Err(ElementError::NotAStart(e.clone().into_owned()).into()),
+ }
+ }
+
+ #[async_recursion]
+ pub async fn read<R: AsyncBufRead + Unpin + Send>(
+ reader: &mut Reader<R>,
+ ) -> Result<Option<Self>> {
+ let mut buf = Vec::new();
+ let event = reader.read_event_into_async(&mut buf).await?;
+ match event {
+ Event::Start(e) => {
+ let mut content_vec = Vec::new();
+ while let Some(sub_element) = Element::read(reader).await? {
+ content_vec.push(sub_element)
+ }
+ let mut content = None;
+ if !content_vec.is_empty() {
+ content = Some(content_vec)
+ }
+ Ok(Some(Self {
+ event: Event::Start(e.into_owned()),
+ content,
+ }))
+ }
+ Event::End(_) => Ok(None),
+ e => Ok(Some(Self {
+ event: e.into_owned(),
+ content: None,
+ })),
+ }
+ }
+
+ #[async_recursion]
+ pub async fn read_start<R: AsyncBufRead + Unpin + Send>(
+ reader: &mut Reader<R>,
+ ) -> Result<Self> {
+ let mut buf = Vec::new();
+ let event = reader.read_event_into_async(&mut buf).await?;
+ match event {
+ Event::Start(e) => {
+ return Ok(Self {
+ event: Event::Start(e.into_owned()),
+ content: None,
+ })
+ }
+ e => Err(ElementError::NotAStart(e.into_owned()).into()),
+ }
+ }
+}
+
+#[derive(Debug)]
+pub enum ElementError<'e> {
+ NotAStart(Event<'e>),
+}
diff --git a/src/error.rs b/src/error.rs
index 20ebc3e..37be7fa 100644
--- a/src/error.rs
+++ b/src/error.rs
@@ -1,7 +1,13 @@
use std::str::Utf8Error;
+use quick_xml::events::attributes::AttrError;
use rsasl::mechname::MechanismNameError;
+use crate::{
+ element::{self, ElementError},
+ jid::ParseError,
+};
+
#[derive(Debug)]
pub enum JabberError {
Connection,
@@ -9,8 +15,13 @@ pub enum JabberError {
StartTlsUnavailable,
TlsNegotiation,
Utf8Decode,
+ NoFeatures,
+ UnknownNamespace,
+ ParseError,
XML(quick_xml::Error),
SASL(SASLError),
+ Element(ElementError<'static>),
+ JID(ParseError),
}
#[derive(Debug)]
@@ -32,7 +43,7 @@ impl From<MechanismNameError> for JabberError {
}
impl From<Utf8Error> for JabberError {
- fn from(e: Utf8Error) -> Self {
+ fn from(_e: Utf8Error) -> Self {
Self::Utf8Decode
}
}
@@ -42,3 +53,21 @@ impl From<quick_xml::Error> for JabberError {
Self::XML(e)
}
}
+
+impl From<element::ElementError<'static>> for JabberError {
+ fn from(e: element::ElementError<'static>) -> Self {
+ Self::Element(e)
+ }
+}
+
+impl From<AttrError> for JabberError {
+ fn from(e: AttrError) -> Self {
+ Self::XML(e.into())
+ }
+}
+
+impl From<ParseError> for JabberError {
+ fn from(e: ParseError) -> Self {
+ Self::JID(e)
+ }
+}
diff --git a/src/jabber.rs b/src/jabber.rs
index a1b2a2f..a48751c 100644
--- a/src/jabber.rs
+++ b/src/jabber.rs
@@ -117,11 +117,12 @@ impl<'j> Jabber<'j> {
.connect(&self.server, socket)
.await
{
- let (read, writer) = tokio::io::split(stream);
+ let (read, write) = tokio::io::split(stream);
let reader = Reader::from_reader(BufReader::new(read));
- return Ok(JabberClientType::Encrypted(
- client::encrypted::JabberClient::new(reader, writer, self),
- ));
+ let writer = Writer::new(write);
+ let mut client = client::encrypted::JabberClient::new(reader, writer, self);
+ client.start_stream().await?;
+ return Ok(JabberClientType::Encrypted(client));
}
}
false => {
@@ -129,9 +130,10 @@ impl<'j> Jabber<'j> {
let (read, write) = tokio::io::split(stream);
let reader = Reader::from_reader(BufReader::new(read));
let writer = Writer::new(write);
- return Ok(JabberClientType::Unencrypted(
- client::unencrypted::JabberClient::new(reader, writer, self),
- ));
+ let mut client =
+ client::unencrypted::JabberClient::new(reader, writer, self);
+ client.start_stream().await?;
+ return Ok(JabberClientType::Unencrypted(client));
}
}
}
diff --git a/src/jid/mod.rs b/src/jid/mod.rs
index b2a03ea..e13fed7 100644
--- a/src/jid/mod.rs
+++ b/src/jid/mod.rs
@@ -1,6 +1,6 @@
use std::str::FromStr;
-#[derive(PartialEq, Debug)]
+#[derive(PartialEq, Debug, Clone)]
pub struct JID {
// TODO: validate localpart (length, char]
pub localpart: Option<String>,
diff --git a/src/lib.rs b/src/lib.rs
index d27f0ba..89e69c9 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -2,6 +2,7 @@
// TODO: logging (dropped errors)
pub mod client;
+pub mod element;
pub mod error;
pub mod jabber;
pub mod jid;
diff --git a/src/stanza/mod.rs b/src/stanza/mod.rs
index 4eaa4c2..02ea277 100644
--- a/src/stanza/mod.rs
+++ b/src/stanza/mod.rs
@@ -1,2 +1,6 @@
+// use quick_xml::events::BytesDecl;
+
pub mod sasl;
pub mod stream;
+
+// const DECLARATION: BytesDecl<'_> = BytesDecl::new("1.0", None, None);
diff --git a/src/stanza/sasl.rs b/src/stanza/sasl.rs
index c0e41ab..1f77ffa 100644
--- a/src/stanza/sasl.rs
+++ b/src/stanza/sasl.rs
@@ -1,32 +1,8 @@
-use serde::{Deserialize, Serialize};
-
-#[derive(Deserialize, PartialEq, Debug)]
-pub struct Mechanisms {
- #[serde(rename = "$value")]
- pub mechanisms: Vec<Mechanism>,
-}
-
-#[derive(Deserialize, PartialEq, Debug)]
-pub struct Mechanism {
- #[serde(rename = "$text")]
- pub mechanism: String,
-}
-
-#[derive(Serialize, Debug)]
-#[serde(rename = "auth")]
pub struct Auth {
- #[serde(rename = "@xmlns")]
- pub ns: String,
- #[serde(rename = "@mechanism")]
pub mechanism: String,
- #[serde(rename = "$text")]
pub sasl_data: Option<String>,
}
-#[derive(Deserialize, Debug)]
pub struct Challenge {
- #[serde(rename = "@xmlns")]
- pub ns: String,
- #[serde(rename = "$text")]
pub sasl_data: String,
}
diff --git a/src/stanza/stream.rs b/src/stanza/stream.rs
index 4c0addd..f0fb6a1 100644
--- a/src/stanza/stream.rs
+++ b/src/stanza/stream.rs
@@ -1,41 +1,182 @@
-use serde::{Deserialize, Serialize};
+use std::str;
-use super::sasl::Mechanisms;
+use quick_xml::{
+ events::{BytesStart, Event},
+ name::QName,
+};
-#[derive(Serialize, Deserialize)]
-#[serde(rename = "stream:stream")]
-struct Stream {
- #[serde(rename = "@from")]
- from: Option<String>,
- #[serde(rename = "@id")]
+use crate::{element::Element, JabberError, Result, JID};
+
+const XMLNS_STREAM: &str = "http://etherx.jabber.org/streams";
+const VERSION: &str = "1.0";
+
+enum XMLNS {
+ Client,
+ Server,
+}
+
+impl From<XMLNS> for &str {
+ fn from(xmlns: XMLNS) -> Self {
+ match xmlns {
+ XMLNS::Client => return "jabber:client",
+ XMLNS::Server => return "jabber:server",
+ }
+ }
+}
+
+impl TryInto<XMLNS> for &str {
+ type Error = JabberError;
+
+ fn try_into(self) -> Result<XMLNS> {
+ match self {
+ "jabber:client" => Ok(XMLNS::Client),
+ "jabber:server" => Ok(XMLNS::Server),
+ _ => Err(JabberError::UnknownNamespace),
+ }
+ }
+}
+
+pub struct Stream {
+ from: Option<JID>,
id: Option<String>,
- #[serde(rename = "@to")]
- to: Option<String>,
- #[serde(rename = "@version")]
- version: Option<f32>,
- #[serde(rename = "@xml:lang")]
+ to: Option<JID>,
+ version: Option<String>,
lang: Option<String>,
- #[serde(rename = "@xmlns")]
- namespace: Option<String>,
- #[serde(rename = "@xmlns:stream")]
- stream_namespace: Option<String>,
+ _ns: XMLNS,
}
-#[derive(Deserialize, Debug)]
-#[serde(rename = "stream:features")]
-pub struct StreamFeatures {
- #[serde(rename = "$value")]
- pub features: Vec<StreamFeature>,
+impl Stream {
+ pub fn new_client(from: &JID, to: &JID, id: Option<String>, lang: Option<String>) -> Self {
+ Self {
+ from: Some(from.clone()),
+ id,
+ to: Some(to.clone()),
+ version: Some(VERSION.to_owned()),
+ lang,
+ _ns: XMLNS::Client,
+ }
+ }
+
+ fn build(&self) -> BytesStart {
+ let mut start = BytesStart::new("stream:stream");
+ if let Some(from) = &self.from {
+ start.push_attribute(("from", from.to_string().as_str()));
+ }
+ if let Some(id) = &self.id {
+ start.push_attribute(("id", id.as_str()));
+ }
+ if let Some(to) = &self.to {
+ start.push_attribute(("to", to.to_string().as_str()));
+ }
+ if let Some(version) = &self.version {
+ start.push_attribute(("version", version.to_string().as_str()));
+ }
+ if let Some(lang) = &self.lang {
+ start.push_attribute(("xml:lang", lang.as_str()));
+ }
+ start.push_attribute(("xmlns", XMLNS::Client.into()));
+ start.push_attribute(("xmlns:stream", XMLNS_STREAM));
+ start
+ }
+}
+
+impl<'e> Into<Element<'e>> for Stream {
+ fn into(self) -> Element<'e> {
+ Element {
+ event: Event::Start(self.build().to_owned()),
+ content: None,
+ }
+ }
}
-#[derive(Deserialize, PartialEq, Debug)]
+impl<'e> TryFrom<Element<'e>> for Stream {
+ type Error = JabberError;
+
+ fn try_from(value: Element<'e>) -> Result<Stream> {
+ let (mut from, mut id, mut to, mut version, mut lang, mut ns) =
+ (None, None, None, None, None, XMLNS::Client);
+ if let Event::Start(e) = value.event.as_ref() {
+ for attribute in e.attributes() {
+ let attribute = attribute?;
+ match attribute.key {
+ QName(b"from") => {
+ from = Some(str::from_utf8(&attribute.value)?.to_string().try_into()?);
+ }
+ QName(b"id") => {
+ id = Some(str::from_utf8(&attribute.value)?.to_owned());
+ }
+ QName(b"to") => {
+ to = Some(str::from_utf8(&attribute.value)?.to_string().try_into()?);
+ }
+ QName(b"version") => {
+ version = Some(str::from_utf8(&attribute.value)?.to_owned());
+ }
+ QName(b"lang") => {
+ lang = Some(str::from_utf8(&attribute.value)?.to_owned());
+ }
+ QName(b"xmlns") => {
+ ns = str::from_utf8(&attribute.value)?.try_into()?;
+ }
+ _ => {
+ println!("unknown attribute: {:?}", attribute)
+ }
+ }
+ }
+ Ok(Stream {
+ from,
+ id,
+ to,
+ version,
+ lang,
+ _ns: ns,
+ })
+ } else {
+ Err(JabberError::ParseError)
+ }
+ }
+}
+
+#[derive(PartialEq, Debug)]
pub enum StreamFeature {
- #[serde(rename = "starttls")]
StartTls,
- // TODO: other stream features
- #[serde(rename = "mechanisms")]
- Sasl(Mechanisms),
+ Sasl(Vec<String>),
Bind,
- #[serde(other)]
Unknown,
}
+
+impl<'e> TryFrom<Element<'e>> for Vec<StreamFeature> {
+ type Error = JabberError;
+
+ fn try_from(features_element: Element) -> Result<Self> {
+ let mut features = Vec::new();
+ if let Some(content) = features_element.content {
+ for feature_element in content {
+ match feature_element.event {
+ Event::Start(e) => match e.name() {
+ QName(b"starttls") => features.push(StreamFeature::StartTls),
+ QName(b"mechanisms") => {
+ let mut mechanisms = Vec::new();
+ if let Some(content) = feature_element.content {
+ for mechanism_element in content {
+ if let Some(content) = mechanism_element.content {
+ for mechanism_text in content {
+ match mechanism_text.event {
+ Event::Text(e) => mechanisms
+ .push(str::from_utf8(e.as_ref())?.to_owned()),
+ _ => {}
+ }
+ }
+ }
+ }
+ }
+ features.push(StreamFeature::Sasl(mechanisms))
+ }
+ _ => {}
+ },
+ _ => features.push(StreamFeature::Unknown),
+ }
+ }
+ }
+ Ok(features)
+ }
+}