diff options
author | 2024-11-29 02:11:02 +0000 | |
---|---|---|
committer | 2024-11-29 02:11:02 +0000 | |
commit | b6593389069903cc4c85e40611296d8a240f718d (patch) | |
tree | ae4df92ea45cce5e5b904041a925263e8d629274 /src/jabber.rs | |
parent | 2dcbc9e1f4339993dd47b2659770a9cf4855b02d (diff) | |
download | luz-b6593389069903cc4c85e40611296d8a240f718d.tar.gz luz-b6593389069903cc4c85e40611296d8a240f718d.tar.bz2 luz-b6593389069903cc4c85e40611296d8a240f718d.zip |
implement sasl kinda
Diffstat (limited to 'src/jabber.rs')
-rw-r--r-- | src/jabber.rs | 220 |
1 files changed, 211 insertions, 9 deletions
diff --git a/src/jabber.rs b/src/jabber.rs index a56c65c..9e7f9d8 100644 --- a/src/jabber.rs +++ b/src/jabber.rs @@ -1,26 +1,26 @@ use std::str; use std::sync::Arc; +use async_recursion::async_recursion; use peanuts::element::{FromElement, IntoElement}; use peanuts::{Reader, Writer}; -use rsasl::prelude::SASLConfig; +use rsasl::prelude::{Mechname, SASLClient, SASLConfig}; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufReader, ReadHalf, WriteHalf}; +use tokio::time::timeout; use tokio_native_tls::native_tls::TlsConnector; use tracing::{debug, info, instrument, trace}; use trust_dns_resolver::proto::rr::domain::IntoLabel; use crate::connection::{Tls, Unencrypted}; use crate::error::Error; +use crate::stanza::sasl::{Auth, Challenge, Mechanisms, Response, ServerResponse}; use crate::stanza::starttls::{Proceed, StartTls}; -use crate::stanza::stream::{Features, Stream}; +use crate::stanza::stream::{Feature, Features, Stream}; use crate::stanza::XML_VERSION; -use crate::Result; use crate::JID; +use crate::{Connection, Result}; -pub struct Jabber<S> -where - S: AsyncRead + AsyncWrite + Unpin, -{ +pub struct Jabber<S> { reader: Reader<ReadHalf<S>>, writer: Writer<WriteHalf<S>>, jid: Option<JID>, @@ -56,7 +56,89 @@ where S: AsyncRead + AsyncWrite + Unpin + Send, Jabber<S>: std::fmt::Debug, { - // pub async fn negotiate(self) -> Result<Jabber<S>> {} + pub async fn sasl( + &mut self, + mechanisms: Mechanisms, + sasl_config: Arc<SASLConfig>, + ) -> Result<()> { + let sasl = SASLClient::new(sasl_config); + let mut offered_mechs: Vec<&Mechname> = Vec::new(); + for mechanism in &mechanisms.mechanisms { + offered_mechs.push(Mechname::parse(mechanism.as_bytes())?) + } + debug!("{:?}", offered_mechs); + let mut session = sasl.start_suggested(&offered_mechs)?; + let selected_mechanism = session.get_mechname().as_str().to_owned(); + debug!("selected mech: {:?}", selected_mechanism); + let mut data: Option<Vec<u8>> = None; + + if !session.are_we_first() { + // if not first mention the mechanism then get challenge data + // mention mechanism + let auth = Auth { + mechanism: selected_mechanism, + sasl_data: "=".to_string(), + }; + self.writer.write_full(&auth).await?; + // get challenge data + let challenge: Challenge = self.reader.read().await?; + debug!("challenge: {:?}", challenge); + data = Some((*challenge).as_bytes().to_vec()); + debug!("we didn't go first"); + } else { + // if first, mention mechanism and send data + let mut sasl_data = Vec::new(); + session.step64(None, &mut sasl_data).unwrap(); + let auth = Auth { + mechanism: selected_mechanism, + sasl_data: str::from_utf8(&sasl_data)?.to_string(), + }; + debug!("{:?}", auth); + self.writer.write_full(&auth).await?; + + let server_response: ServerResponse = self.reader.read().await?; + debug!("server_response: {:#?}", server_response); + match server_response { + ServerResponse::Challenge(challenge) => { + data = Some((*challenge).as_bytes().to_vec()) + } + ServerResponse::Success(success) => data = Some((*success).as_bytes().to_vec()), + } + debug!("we went first"); + } + + // stepping the authentication exchange to completion + if data != None { + debug!("data: {:?}", data); + let mut sasl_data = Vec::new(); + while { + // decide if need to send more data over + let state = session + .step64(data.as_deref(), &mut sasl_data) + .expect("step errored!"); + state.is_running() + } { + // While we aren't finished, receive more data from the other party + let response = Response::new(str::from_utf8(&sasl_data)?.to_string()); + debug!("response: {:?}", response); + self.writer.write_full(&response).await?; + + let server_response: ServerResponse = self.reader.read().await?; + debug!("server_response: {:#?}", server_response); + match server_response { + ServerResponse::Challenge(challenge) => { + data = Some((*challenge).as_bytes().to_vec()) + } + ServerResponse::Success(success) => data = Some((*success).as_bytes().to_vec()), + } + } + } + Ok(()) + } + + pub async fn bind(&mut self) -> Result<()> { + todo!() + } #[instrument] pub async fn start_stream(&mut self) -> Result<()> { @@ -76,6 +158,8 @@ where let decl = self.reader.read_prolog().await?; // receive stream element and validate + let text = str::from_utf8(self.reader.buffer.data()).unwrap(); + debug!("data: {}", text); let stream: Stream = self.reader.read_start().await?; debug!("got stream: {:?}", stream); if let Some(from) = stream.from { @@ -98,6 +182,87 @@ where } impl Jabber<Unencrypted> { + pub async fn negotiate<S: AsyncRead + AsyncWrite + Unpin>(mut self) -> Result<Jabber<Tls>> { + self.start_stream().await?; + // TODO: timeout + let features = self.get_features().await?.features; + if let Some(Feature::StartTls(_)) = features + .iter() + .find(|feature| matches!(feature, Feature::StartTls(_s))) + { + let jabber = self.starttls().await?; + let jabber = jabber.negotiate().await?; + return Ok(jabber); + } else { + // TODO: better error + return Err(Error::TlsRequired); + } + } + + #[async_recursion] + pub async fn negotiate_tls_optional(mut self) -> Result<Connection> { + self.start_stream().await?; + // TODO: timeout + let features = self.get_features().await?.features; + if let Some(Feature::StartTls(_)) = features + .iter() + .find(|feature| matches!(feature, Feature::StartTls(_s))) + { + let jabber = self.starttls().await?; + let jabber = jabber.negotiate().await?; + return Ok(Connection::Encrypted(jabber)); + } else if let (Some(sasl_config), Some(Feature::Sasl(mechanisms))) = ( + self.auth.clone(), + features + .iter() + .find(|feature| matches!(feature, Feature::Sasl(_))), + ) { + self.sasl(mechanisms.clone(), sasl_config).await?; + let jabber = self.negotiate_tls_optional().await?; + Ok(jabber) + } else if let Some(Feature::Bind) = features + .iter() + .find(|feature| matches!(feature, Feature::Bind)) + { + self.bind().await?; + Ok(Connection::Unencrypted(self)) + } else { + // TODO: better error + return Err(Error::Negotiation); + } + } +} + +impl Jabber<Tls> { + #[async_recursion] + pub async fn negotiate(mut self) -> Result<Jabber<Tls>> { + self.start_stream().await?; + let features = self.get_features().await?.features; + + if let (Some(sasl_config), Some(Feature::Sasl(mechanisms))) = ( + self.auth.clone(), + features + .iter() + .find(|feature| matches!(feature, Feature::Sasl(_))), + ) { + // TODO: avoid clone + self.sasl(mechanisms.clone(), sasl_config).await?; + let jabber = self.negotiate().await?; + Ok(jabber) + } else if let Some(Feature::Bind) = features + .iter() + .find(|feature| matches!(feature, Feature::Bind)) + { + self.bind().await?; + Ok(self) + } else { + // TODO: better error + return Err(Error::Negotiation); + } + } +} + +impl Jabber<Unencrypted> { pub async fn starttls(mut self) -> Result<Jabber<Tls>> { self.writer .write_full(&StartTls { required: false }) @@ -155,10 +320,47 @@ mod tests { #[test(tokio::test)] async fn start_stream() { - let connection = Connection::connect("blos.sm").await.unwrap(); + let connection = Connection::connect("blos.sm", None, None).await.unwrap(); match connection { Connection::Encrypted(mut c) => c.start_stream().await.unwrap(), Connection::Unencrypted(mut c) => c.start_stream().await.unwrap(), } } + + #[test(tokio::test)] + async fn sasl() { + let mut jabber = Connection::connect_user("test@blos.sm", "slayed".to_string()) + .await + .unwrap() + .ensure_tls() + .await + .unwrap(); + let text = str::from_utf8(jabber.reader.buffer.data()).unwrap(); + println!("data: {}", text); + jabber.start_stream().await.unwrap(); + + let text = str::from_utf8(jabber.reader.buffer.data()).unwrap(); + println!("data: {}", text); + jabber.reader.read_buf().await.unwrap(); + let text = str::from_utf8(jabber.reader.buffer.data()).unwrap(); + println!("data: {}", text); + + let features = jabber.get_features().await.unwrap(); + let (sasl_config, feature) = ( + jabber.auth.clone().unwrap(), + features + .features + .iter() + .find(|feature| matches!(feature, Feature::Sasl(_))) + .unwrap(), + ); + match feature { + Feature::StartTls(_start_tls) => todo!(), + Feature::Sasl(mechanisms) => { + jabber.sasl(mechanisms.clone(), sasl_config).await.unwrap(); + } + Feature::Bind => todo!(), + Feature::Unknown => todo!(), + } + } } |