summaryrefslogtreecommitdiffstats
path: root/src/client
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--src/client/encrypted.rs130
-rw-r--r--src/client/mod.rs11
-rw-r--r--src/client/unencrypted.rs8
3 files changed, 118 insertions, 31 deletions
diff --git a/src/client/encrypted.rs b/src/client/encrypted.rs
index 898dc23..e8b7271 100644
--- a/src/client/encrypted.rs
+++ b/src/client/encrypted.rs
@@ -1,13 +1,23 @@
+use std::str;
+
use quick_xml::{
events::{BytesDecl, Event},
+ name::QName,
Reader, Writer,
};
+use rsasl::prelude::{Mechname, SASLClient};
use tokio::io::{BufReader, ReadHalf, WriteHalf};
use tokio::net::TcpStream;
use tokio_native_tls::TlsStream;
-use crate::stanza::stream::{Stream, StreamFeature};
-use crate::stanza::Element;
+use crate::stanza::{
+ sasl::{Auth, Response},
+ stream::{Stream, StreamFeature},
+};
+use crate::stanza::{
+ sasl::{Challenge, Success},
+ Element,
+};
use crate::Jabber;
use crate::Result;
@@ -48,27 +58,111 @@ impl<'j> JabberClient<'j> {
Ok(())
}
- pub async fn get_features(&mut self) -> Result<Option<Vec<StreamFeature>>> {
- if let Some(features) = Element::read(&mut self.reader).await? {
- Ok(Some(features.try_into()?))
- } else {
- Ok(None)
- }
+ pub async fn get_features(&mut self) -> Result<Vec<StreamFeature>> {
+ Element::read(&mut self.reader).await?.try_into()
}
pub async fn negotiate(&mut self) -> Result<()> {
loop {
println!("loop");
- let features = &self.get_features().await?;
- println!("{:?}", features);
- // match &features[0] {
- // StreamFeature::Sasl(sasl) => {
- // println!("{:?}", sasl);
- // todo!()
- // }
- // StreamFeature::Bind => todo!(),
- // x => println!("{:?}", x),
- // }
+ let features = self.get_features().await?;
+ println!("features: {:?}", features);
+ match &features[0] {
+ StreamFeature::Sasl(sasl) => {
+ println!("sasl?");
+ self.sasl(&sasl).await?;
+ }
+ StreamFeature::Bind => todo!(),
+ x => println!("{:?}", x),
+ }
+ }
+ }
+
+ pub async fn sasl(&mut self, mechanisms: &Vec<String>) -> Result<()> {
+ println!("{:?}", mechanisms);
+ let sasl = SASLClient::new(self.jabber.auth.clone());
+ let mut offered_mechs: Vec<&Mechname> = Vec::new();
+ for mechanism in mechanisms {
+ offered_mechs.push(Mechname::parse(mechanism.as_bytes())?)
}
+ println!("{:?}", offered_mechs);
+ let mut session = sasl.start_suggested(&offered_mechs)?;
+ let selected_mechanism = session.get_mechname().as_str().to_owned();
+ println!("selected mech: {:?}", selected_mechanism);
+ let mut data: Option<Vec<u8>> = None;
+ if !session.are_we_first() {
+ // if not first mention the mechanism then get challenge data
+ // mention mechanism
+ let auth = Auth {
+ mechanism: selected_mechanism.as_str(),
+ sasl_data: "=",
+ };
+ Into::<Element>::into(auth).write(&mut self.writer).await?;
+ // get challenge data
+ let challenge = &Element::read(&mut self.reader).await?;
+ let challenge: Challenge = challenge.try_into()?;
+ println!("challenge: {:?}", challenge);
+ data = Some(challenge.sasl_data.to_owned());
+ println!("we didn't go first");
+ } else {
+ // if first, mention mechanism and send data
+ let mut sasl_data = Vec::new();
+ session.step64(None, &mut sasl_data).unwrap();
+ let auth = Auth {
+ mechanism: selected_mechanism.as_str(),
+ sasl_data: str::from_utf8(&sasl_data)?,
+ };
+ println!("{:?}", auth);
+ Into::<Element>::into(auth).write(&mut self.writer).await?;
+
+ let server_response = Element::read(&mut self.reader).await?;
+ println!("server_response: {:#?}", server_response);
+ match TryInto::<Challenge>::try_into(&server_response) {
+ Ok(challenge) => data = Some(challenge.sasl_data.to_owned()),
+ Err(_) => {
+ let success = TryInto::<Success>::try_into(&server_response)?;
+ if let Some(sasl_data) = success.sasl_data {
+ data = Some(sasl_data.to_owned())
+ }
+ }
+ }
+ println!("we went first");
+ }
+
+ // stepping the authentication exchange to completion
+ if data != None {
+ println!("data: {:?}", data);
+ let mut sasl_data = Vec::new();
+ while {
+ // decide if need to send more data over
+ let state = session
+ .step64(data.as_deref(), &mut sasl_data)
+ .expect("step errored!");
+ state.is_running()
+ } {
+ // While we aren't finished, receive more data from the other party
+ let response = Response {
+ sasl_data: str::from_utf8(&sasl_data)?,
+ };
+ println!("response: {:?}", response);
+ Into::<Element>::into(response)
+ .write(&mut self.writer)
+ .await?;
+
+ let server_response = Element::read(&mut self.reader).await?;
+ println!("server_response: {:?}", server_response);
+ match TryInto::<Challenge>::try_into(&server_response) {
+ Ok(challenge) => data = Some(challenge.sasl_data.to_owned()),
+ Err(_) => {
+ let success = TryInto::<Success>::try_into(&server_response)?;
+ if let Some(sasl_data) = success.sasl_data {
+ data = Some(sasl_data.to_owned())
+ }
+ }
+ }
+ }
+ }
+ self.start_stream().await?;
+ Ok(())
}
}
diff --git a/src/client/mod.rs b/src/client/mod.rs
index d545923..280e0a1 100644
--- a/src/client/mod.rs
+++ b/src/client/mod.rs
@@ -17,14 +17,11 @@ impl<'j> JabberClientType<'j> {
match self {
Self::Encrypted(c) => Ok(c),
Self::Unencrypted(mut c) => {
- if let Some(features) = c.get_features().await? {
- if features.contains(&StreamFeature::StartTls) {
- Ok(c.starttls().await?)
- } else {
- Err(JabberError::StartTlsUnavailable)
- }
+ let features = c.get_features().await?;
+ if features.contains(&StreamFeature::StartTls) {
+ Ok(c.starttls().await?)
} else {
- Err(JabberError::NoFeatures)
+ Err(JabberError::StartTlsUnavailable)
}
}
}
diff --git a/src/client/unencrypted.rs b/src/client/unencrypted.rs
index dcd10c6..27b0a5f 100644
--- a/src/client/unencrypted.rs
+++ b/src/client/unencrypted.rs
@@ -50,12 +50,8 @@ impl<'j> JabberClient<'j> {
Ok(())
}
- pub async fn get_features(&mut self) -> Result<Option<Vec<StreamFeature>>> {
- if let Some(features) = Element::read(&mut self.reader).await? {
- Ok(Some(features.try_into()?))
- } else {
- Ok(None)
- }
+ pub async fn get_features(&mut self) -> Result<Vec<StreamFeature>> {
+ Element::read(&mut self.reader).await?.try_into()
}
pub async fn starttls(mut self) -> Result<super::encrypted::JabberClient<'j>> {