summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorLibravatar cel 🌸 <cel@blos.sm>2023-07-12 21:11:20 +0100
committerLibravatar cel 🌸 <cel@blos.sm>2023-07-12 21:11:20 +0100
commit322b2a3b46348ec1c5acbc538de93310c9030b96 (patch)
treee447920e2414c4d3d99ce021785f0fe8103d378a
parentc9683935f1e94a701be3e6efe0634dbc63c861de (diff)
downloadluz-322b2a3b46348ec1c5acbc538de93310c9030b96.tar.gz
luz-322b2a3b46348ec1c5acbc538de93310c9030b96.tar.bz2
luz-322b2a3b46348ec1c5acbc538de93310c9030b96.zip
reimplement sasl (with SCRAM!)
-rw-r--r--Cargo.toml2
-rw-r--r--TODO.md2
-rw-r--r--src/client/encrypted.rs130
-rw-r--r--src/client/mod.rs11
-rw-r--r--src/client/unencrypted.rs8
-rw-r--r--src/error.rs13
-rw-r--r--src/jabber.rs2
-rw-r--r--src/stanza/mod.rs74
-rw-r--r--src/stanza/sasl.rs163
-rw-r--r--src/stanza/stream.rs20
10 files changed, 356 insertions, 69 deletions
diff --git a/Cargo.toml b/Cargo.toml
index 153f648..eb89659 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -11,7 +11,7 @@ async-recursion = "1.0.4"
async-trait = "0.1.68"
quick-xml = { git = "https://github.com/tafia/quick-xml.git", features = ["async-tokio"] }
# TODO: remove unneeded features
-rsasl = { version = "2", default_features = false, features = ["provider_base64", "plain", "config_builder"] }
+rsasl = { version = "2", default_features = true, features = ["provider_base64", "plain", "config_builder"] }
tokio = { version = "1.28", features = ["full"] }
tokio-native-tls = "0.3.1"
trust-dns-resolver = "0.22.0"
diff --git a/TODO.md b/TODO.md
index 068be75..22d656a 100644
--- a/TODO.md
+++ b/TODO.md
@@ -7,3 +7,5 @@
[ ] remove unwraps
[ ] proper error types
[ ] stream error type
+[ ] change stanzas from owned to borrowed types with lifetimes
+[ ] Into<Element> trait with event() and content() functions
diff --git a/src/client/encrypted.rs b/src/client/encrypted.rs
index 898dc23..e8b7271 100644
--- a/src/client/encrypted.rs
+++ b/src/client/encrypted.rs
@@ -1,13 +1,23 @@
+use std::str;
+
use quick_xml::{
events::{BytesDecl, Event},
+ name::QName,
Reader, Writer,
};
+use rsasl::prelude::{Mechname, SASLClient};
use tokio::io::{BufReader, ReadHalf, WriteHalf};
use tokio::net::TcpStream;
use tokio_native_tls::TlsStream;
-use crate::stanza::stream::{Stream, StreamFeature};
-use crate::stanza::Element;
+use crate::stanza::{
+ sasl::{Auth, Response},
+ stream::{Stream, StreamFeature},
+};
+use crate::stanza::{
+ sasl::{Challenge, Success},
+ Element,
+};
use crate::Jabber;
use crate::Result;
@@ -48,27 +58,111 @@ impl<'j> JabberClient<'j> {
Ok(())
}
- pub async fn get_features(&mut self) -> Result<Option<Vec<StreamFeature>>> {
- if let Some(features) = Element::read(&mut self.reader).await? {
- Ok(Some(features.try_into()?))
- } else {
- Ok(None)
- }
+ pub async fn get_features(&mut self) -> Result<Vec<StreamFeature>> {
+ Element::read(&mut self.reader).await?.try_into()
}
pub async fn negotiate(&mut self) -> Result<()> {
loop {
println!("loop");
- let features = &self.get_features().await?;
- println!("{:?}", features);
- // match &features[0] {
- // StreamFeature::Sasl(sasl) => {
- // println!("{:?}", sasl);
- // todo!()
- // }
- // StreamFeature::Bind => todo!(),
- // x => println!("{:?}", x),
- // }
+ let features = self.get_features().await?;
+ println!("features: {:?}", features);
+ match &features[0] {
+ StreamFeature::Sasl(sasl) => {
+ println!("sasl?");
+ self.sasl(&sasl).await?;
+ }
+ StreamFeature::Bind => todo!(),
+ x => println!("{:?}", x),
+ }
+ }
+ }
+
+ pub async fn sasl(&mut self, mechanisms: &Vec<String>) -> Result<()> {
+ println!("{:?}", mechanisms);
+ let sasl = SASLClient::new(self.jabber.auth.clone());
+ let mut offered_mechs: Vec<&Mechname> = Vec::new();
+ for mechanism in mechanisms {
+ offered_mechs.push(Mechname::parse(mechanism.as_bytes())?)
}
+ println!("{:?}", offered_mechs);
+ let mut session = sasl.start_suggested(&offered_mechs)?;
+ let selected_mechanism = session.get_mechname().as_str().to_owned();
+ println!("selected mech: {:?}", selected_mechanism);
+ let mut data: Option<Vec<u8>> = None;
+ if !session.are_we_first() {
+ // if not first mention the mechanism then get challenge data
+ // mention mechanism
+ let auth = Auth {
+ mechanism: selected_mechanism.as_str(),
+ sasl_data: "=",
+ };
+ Into::<Element>::into(auth).write(&mut self.writer).await?;
+ // get challenge data
+ let challenge = &Element::read(&mut self.reader).await?;
+ let challenge: Challenge = challenge.try_into()?;
+ println!("challenge: {:?}", challenge);
+ data = Some(challenge.sasl_data.to_owned());
+ println!("we didn't go first");
+ } else {
+ // if first, mention mechanism and send data
+ let mut sasl_data = Vec::new();
+ session.step64(None, &mut sasl_data).unwrap();
+ let auth = Auth {
+ mechanism: selected_mechanism.as_str(),
+ sasl_data: str::from_utf8(&sasl_data)?,
+ };
+ println!("{:?}", auth);
+ Into::<Element>::into(auth).write(&mut self.writer).await?;
+
+ let server_response = Element::read(&mut self.reader).await?;
+ println!("server_response: {:#?}", server_response);
+ match TryInto::<Challenge>::try_into(&server_response) {
+ Ok(challenge) => data = Some(challenge.sasl_data.to_owned()),
+ Err(_) => {
+ let success = TryInto::<Success>::try_into(&server_response)?;
+ if let Some(sasl_data) = success.sasl_data {
+ data = Some(sasl_data.to_owned())
+ }
+ }
+ }
+ println!("we went first");
+ }
+
+ // stepping the authentication exchange to completion
+ if data != None {
+ println!("data: {:?}", data);
+ let mut sasl_data = Vec::new();
+ while {
+ // decide if need to send more data over
+ let state = session
+ .step64(data.as_deref(), &mut sasl_data)
+ .expect("step errored!");
+ state.is_running()
+ } {
+ // While we aren't finished, receive more data from the other party
+ let response = Response {
+ sasl_data: str::from_utf8(&sasl_data)?,
+ };
+ println!("response: {:?}", response);
+ Into::<Element>::into(response)
+ .write(&mut self.writer)
+ .await?;
+
+ let server_response = Element::read(&mut self.reader).await?;
+ println!("server_response: {:?}", server_response);
+ match TryInto::<Challenge>::try_into(&server_response) {
+ Ok(challenge) => data = Some(challenge.sasl_data.to_owned()),
+ Err(_) => {
+ let success = TryInto::<Success>::try_into(&server_response)?;
+ if let Some(sasl_data) = success.sasl_data {
+ data = Some(sasl_data.to_owned())
+ }
+ }
+ }
+ }
+ }
+ self.start_stream().await?;
+ Ok(())
}
}
diff --git a/src/client/mod.rs b/src/client/mod.rs
index d545923..280e0a1 100644
--- a/src/client/mod.rs
+++ b/src/client/mod.rs
@@ -17,14 +17,11 @@ impl<'j> JabberClientType<'j> {
match self {
Self::Encrypted(c) => Ok(c),
Self::Unencrypted(mut c) => {
- if let Some(features) = c.get_features().await? {
- if features.contains(&StreamFeature::StartTls) {
- Ok(c.starttls().await?)
- } else {
- Err(JabberError::StartTlsUnavailable)
- }
+ let features = c.get_features().await?;
+ if features.contains(&StreamFeature::StartTls) {
+ Ok(c.starttls().await?)
} else {
- Err(JabberError::NoFeatures)
+ Err(JabberError::StartTlsUnavailable)
}
}
}
diff --git a/src/client/unencrypted.rs b/src/client/unencrypted.rs
index dcd10c6..27b0a5f 100644
--- a/src/client/unencrypted.rs
+++ b/src/client/unencrypted.rs
@@ -50,12 +50,8 @@ impl<'j> JabberClient<'j> {
Ok(())
}
- pub async fn get_features(&mut self) -> Result<Option<Vec<StreamFeature>>> {
- if let Some(features) = Element::read(&mut self.reader).await? {
- Ok(Some(features.try_into()?))
- } else {
- Ok(None)
- }
+ pub async fn get_features(&mut self) -> Result<Vec<StreamFeature>> {
+ Element::read(&mut self.reader).await?.try_into()
}
pub async fn starttls(mut self) -> Result<super::encrypted::JabberClient<'j>> {
diff --git a/src/error.rs b/src/error.rs
index 7f704e5..17bfbef 100644
--- a/src/error.rs
+++ b/src/error.rs
@@ -18,6 +18,7 @@ pub enum JabberError {
NoFeatures,
UnknownNamespace,
ParseError,
+ UnexpectedEnd,
XML(quick_xml::Error),
SASL(SASLError),
Element(ElementError<'static>),
@@ -28,6 +29,8 @@ pub enum JabberError {
pub enum SASLError {
SASL(rsasl::prelude::SASLError),
MechanismName(MechanismNameError),
+ NoChallenge,
+ NoSuccess,
}
impl From<rsasl::prelude::SASLError> for JabberError {
@@ -37,8 +40,14 @@ impl From<rsasl::prelude::SASLError> for JabberError {
}
impl From<MechanismNameError> for JabberError {
- fn from(value: MechanismNameError) -> Self {
- Self::SASL(SASLError::MechanismName(value))
+ fn from(e: MechanismNameError) -> Self {
+ Self::SASL(SASLError::MechanismName(e))
+ }
+}
+
+impl From<SASLError> for JabberError {
+ fn from(e: SASLError) -> Self {
+ Self::SASL(e)
}
}
diff --git a/src/jabber.rs b/src/jabber.rs
index a48751c..1a7eddb 100644
--- a/src/jabber.rs
+++ b/src/jabber.rs
@@ -24,7 +24,7 @@ pub struct Jabber<'j> {
impl<'j> Jabber<'j> {
pub fn new(jid: JID, password: String) -> Result<Self> {
let server = jid.domainpart.clone();
- let auth = SASLConfig::with_credentials(None, jid.as_bare().to_string(), password)?;
+ let auth = SASLConfig::with_credentials(None, jid.localpart.clone().unwrap(), password)?;
println!("auth: {:?}", auth);
Ok(Self {
jid,
diff --git a/src/stanza/mod.rs b/src/stanza/mod.rs
index 16f3bdd..c29b1a2 100644
--- a/src/stanza/mod.rs
+++ b/src/stanza/mod.rs
@@ -9,12 +9,12 @@ use quick_xml::events::Event;
use quick_xml::{Reader, Writer};
use tokio::io::{AsyncBufRead, AsyncWrite};
-use crate::Result;
+use crate::JabberError;
-#[derive(Debug)]
+#[derive(Clone, Debug)]
pub struct Element<'e> {
pub event: Event<'e>,
- pub content: Option<Vec<Element<'e>>>,
+ pub children: Option<Vec<Element<'e>>>,
}
impl<'e: 'async_recursion, 'async_recursion> Element<'e> {
@@ -23,7 +23,7 @@ impl<'e: 'async_recursion, 'async_recursion> Element<'e> {
writer: &'life0 mut Writer<W>,
) -> ::core::pin::Pin<
Box<
- dyn ::core::future::Future<Output = Result<()>>
+ dyn ::core::future::Future<Output = Result<(), JabberError>>
+ 'async_recursion
+ ::core::marker::Send,
>,
@@ -36,9 +36,9 @@ impl<'e: 'async_recursion, 'async_recursion> Element<'e> {
match &self.event {
Event::Start(e) => {
writer.write_event_async(Event::Start(e.clone())).await?;
- if let Some(content) = &self.content {
- for _e in content {
- self.write(writer).await?;
+ if let Some(children) = &self.children {
+ for e in children {
+ e.write(writer).await?;
}
}
writer.write_event_async(Event::End(e.to_end())).await?;
@@ -54,7 +54,7 @@ impl<'e> Element<'e> {
pub async fn write_start<W: AsyncWrite + Unpin + Send>(
&self,
writer: &mut Writer<W>,
- ) -> Result<()> {
+ ) -> Result<(), JabberError> {
match self.event.as_ref() {
Event::Start(e) => Ok(writer.write_event_async(Event::Start(e.clone())).await?),
e => Err(ElementError::NotAStart(e.clone().into_owned()).into()),
@@ -64,7 +64,7 @@ impl<'e> Element<'e> {
pub async fn write_end<W: AsyncWrite + Unpin + Send>(
&self,
writer: &mut Writer<W>,
- ) -> Result<()> {
+ ) -> Result<(), JabberError> {
match self.event.as_ref() {
Event::Start(e) => Ok(writer
.write_event_async(Event::End(e.clone().to_end()))
@@ -76,28 +76,38 @@ impl<'e> Element<'e> {
#[async_recursion]
pub async fn read<R: AsyncBufRead + Unpin + Send>(
reader: &mut Reader<R>,
- ) -> Result<Option<Self>> {
+ ) -> Result<Self, JabberError> {
+ let element = Self::read_recursive(reader)
+ .await?
+ .ok_or(JabberError::UnexpectedEnd);
+ element
+ }
+
+ #[async_recursion]
+ async fn read_recursive<R: AsyncBufRead + Unpin + Send>(
+ reader: &mut Reader<R>,
+ ) -> Result<Option<Self>, JabberError> {
let mut buf = Vec::new();
let event = reader.read_event_into_async(&mut buf).await?;
match event {
Event::Start(e) => {
- let mut content_vec = Vec::new();
- while let Some(sub_element) = Element::read(reader).await? {
- content_vec.push(sub_element)
+ let mut children_vec = Vec::new();
+ while let Some(sub_element) = Element::read_recursive(reader).await? {
+ children_vec.push(sub_element)
}
- let mut content = None;
- if !content_vec.is_empty() {
- content = Some(content_vec)
+ let mut children = None;
+ if !children_vec.is_empty() {
+ children = Some(children_vec)
}
Ok(Some(Self {
event: Event::Start(e.into_owned()),
- content,
+ children,
}))
}
Event::End(_) => Ok(None),
e => Ok(Some(Self {
event: e.into_owned(),
- content: None,
+ children: None,
})),
}
}
@@ -105,14 +115,14 @@ impl<'e> Element<'e> {
#[async_recursion]
pub async fn read_start<R: AsyncBufRead + Unpin + Send>(
reader: &mut Reader<R>,
- ) -> Result<Self> {
+ ) -> Result<Self, JabberError> {
let mut buf = Vec::new();
let event = reader.read_event_into_async(&mut buf).await?;
match event {
Event::Start(e) => {
return Ok(Self {
event: Event::Start(e.into_owned()),
- content: None,
+ children: None,
})
}
e => Err(ElementError::NotAStart(e.into_owned()).into()),
@@ -120,7 +130,31 @@ impl<'e> Element<'e> {
}
}
+/// if there is only one child in the vec of children, will return that element
+pub fn child<'p, 'e>(element: &'p Element<'e>) -> Result<&'p Element<'e>, ElementError<'static>> {
+ if let Some(children) = &element.children {
+ if children.len() == 1 {
+ return Ok(&children[0]);
+ } else {
+ return Err(ElementError::MultipleChildren);
+ }
+ }
+ Err(ElementError::NoChildren)
+}
+
+/// returns reference to children
+pub fn children<'p, 'e>(
+ element: &'p Element<'e>,
+) -> Result<&'p Vec<Element<'e>>, ElementError<'e>> {
+ if let Some(children) = &element.children {
+ return Ok(children);
+ }
+ Err(ElementError::NoChildren)
+}
+
#[derive(Debug)]
pub enum ElementError<'e> {
NotAStart(Event<'e>),
+ NoChildren,
+ MultipleChildren,
}
diff --git a/src/stanza/sasl.rs b/src/stanza/sasl.rs
index 1f77ffa..bbf3f41 100644
--- a/src/stanza/sasl.rs
+++ b/src/stanza/sasl.rs
@@ -1,8 +1,163 @@
-pub struct Auth {
- pub mechanism: String,
- pub sasl_data: Option<String>,
+use quick_xml::{
+ events::{BytesStart, BytesText, Event},
+ name::QName,
+};
+
+use crate::error::SASLError;
+use crate::JabberError;
+
+use super::Element;
+
+const XMLNS: &str = "urn:ietf:params:xml:ns:xmpp-sasl";
+
+#[derive(Debug)]
+pub struct Auth<'e> {
+ pub mechanism: &'e str,
+ pub sasl_data: &'e str,
+}
+
+impl<'e> Auth<'e> {
+ fn event(&self) -> Event<'e> {
+ let mut start = BytesStart::new("auth");
+ start.push_attribute(("xmlns", XMLNS));
+ start.push_attribute(("mechanism", self.mechanism));
+ Event::Start(start)
+ }
+
+ fn children(&self) -> Option<Vec<Element<'e>>> {
+ let sasl = BytesText::from_escaped(self.sasl_data);
+ let sasl = Element {
+ event: Event::Text(sasl),
+ children: None,
+ };
+ Some(vec![sasl])
+ }
}
+impl<'e> Into<Element<'e>> for Auth<'e> {
+ fn into(self) -> Element<'e> {
+ Element {
+ event: self.event(),
+ children: self.children(),
+ }
+ }
+}
+
+#[derive(Debug)]
pub struct Challenge {
- pub sasl_data: String,
+ pub sasl_data: Vec<u8>,
+}
+
+impl<'e> TryFrom<&Element<'e>> for Challenge {
+ type Error = JabberError;
+
+ fn try_from(element: &Element<'e>) -> Result<Challenge, Self::Error> {
+ if let Event::Start(start) = &element.event {
+ if start.name() == QName(b"challenge") {
+ let sasl_data: &Element<'_> = super::child(element)?;
+ if let Event::Text(sasl_data) = &sasl_data.event {
+ let s = sasl_data.clone();
+ let s = s.into_inner();
+ let s = s.to_vec();
+ return Ok(Challenge { sasl_data: s });
+ }
+ }
+ }
+ Err(SASLError::NoChallenge.into())
+ }
+}
+
+// impl<'e> TryFrom<Element<'e>> for Challenge {
+// type Error = JabberError;
+
+// fn try_from(element: Element<'e>) -> Result<Challenge, Self::Error> {
+// if let Event::Start(start) = &element.event {
+// if start.name() == QName(b"challenge") {
+// println!("one");
+// if let Some(children) = element.children.as_deref() {
+// if children.len() == 1 {
+// let sasl_data = children.first().unwrap();
+// if let Event::Text(sasl_data) = &sasl_data.event {
+// return Ok(Challenge {
+// sasl_data: sasl_data.clone().into_inner().to_vec(),
+// });
+// } else {
+// return Err(SASLError::NoChallenge.into());
+// }
+// } else {
+// return Err(SASLError::NoChallenge.into());
+// }
+// } else {
+// return Err(SASLError::NoChallenge.into());
+// }
+// }
+// }
+// Err(SASLError::NoChallenge.into())
+// }
+// }
+
+#[derive(Debug)]
+pub struct Response<'e> {
+ pub sasl_data: &'e str,
+}
+
+impl<'e> Response<'e> {
+ fn event(&self) -> Event<'e> {
+ let mut start = BytesStart::new("response");
+ start.push_attribute(("xmlns", XMLNS));
+ Event::Start(start)
+ }
+
+ fn children(&self) -> Option<Vec<Element<'e>>> {
+ let sasl = BytesText::from_escaped(self.sasl_data);
+ let sasl = Element {
+ event: Event::Text(sasl),
+ children: None,
+ };
+ Some(vec![sasl])
+ }
+}
+
+impl<'e> Into<Element<'e>> for Response<'e> {
+ fn into(self) -> Element<'e> {
+ Element {
+ event: self.event(),
+ children: self.children(),
+ }
+ }
+}
+
+#[derive(Debug)]
+pub struct Success {
+ pub sasl_data: Option<Vec<u8>>,
+}
+
+impl<'e> TryFrom<&Element<'e>> for Success {
+ type Error = JabberError;
+
+ fn try_from(element: &Element<'e>) -> Result<Success, Self::Error> {
+ match &element.event {
+ Event::Start(start) => {
+ if start.name() == QName(b"success") {
+ match super::child(element) {
+ Ok(sasl_data) => {
+ if let Event::Text(sasl_data) = &sasl_data.event {
+ return Ok(Success {
+ sasl_data: Some(sasl_data.clone().into_inner().to_vec()),
+ });
+ }
+ }
+ Err(_) => return Ok(Success { sasl_data: None }),
+ };
+ }
+ }
+ Event::Empty(empty) => {
+ if empty.name() == QName(b"success") {
+ return Ok(Success { sasl_data: None });
+ }
+ }
+ _ => {}
+ }
+ Err(SASLError::NoSuccess.into())
+ }
}
diff --git a/src/stanza/stream.rs b/src/stanza/stream.rs
index 32f449d..66741b8 100644
--- a/src/stanza/stream.rs
+++ b/src/stanza/stream.rs
@@ -58,7 +58,7 @@ impl Stream {
}
}
- fn build(&self) -> BytesStart {
+ fn event(&self) -> Event<'static> {
let mut start = BytesStart::new("stream:stream");
if let Some(from) = &self.from {
start.push_attribute(("from", from.to_string().as_str()));
@@ -80,15 +80,15 @@ impl Stream {
XMLNS::Server => start.push_attribute(("xmlns", XMLNS::Server.into())),
}
start.push_attribute(("xmlns:stream", XMLNS_STREAM));
- start
+ Event::Start(start)
}
}
impl<'e> Into<Element<'e>> for Stream {
fn into(self) -> Element<'e> {
Element {
- event: Event::Start(self.build().to_owned()),
- content: None,
+ event: self.event(),
+ children: None,
}
}
}
@@ -153,17 +153,17 @@ impl<'e> TryFrom<Element<'e>> for Vec<StreamFeature> {
fn try_from(features_element: Element) -> Result<Self> {
let mut features = Vec::new();
- if let Some(content) = features_element.content {
- for feature_element in content {
+ if let Some(children) = features_element.children {
+ for feature_element in children {
match feature_element.event {
Event::Start(e) => match e.name() {
QName(b"starttls") => features.push(StreamFeature::StartTls),
QName(b"mechanisms") => {
let mut mechanisms = Vec::new();
- if let Some(content) = feature_element.content {
- for mechanism_element in content {
- if let Some(content) = mechanism_element.content {
- for mechanism_text in content {
+ if let Some(children) = feature_element.children {
+ for mechanism_element in children {
+ if let Some(children) = mechanism_element.children {
+ for mechanism_text in children {
match mechanism_text.event {
Event::Text(e) => mechanisms
.push(str::from_utf8(e.as_ref())?.to_owned()),