diff options
author | 2024-12-04 18:18:37 +0000 | |
---|---|---|
committer | 2024-12-04 18:18:37 +0000 | |
commit | 1b91ff690488b65b552c90bd5392b9a300c8c981 (patch) | |
tree | 9c290f69b26eba0393d7bbc05ba29c28ea74a26e /jabber/src/jabber_stream.rs | |
parent | 03764f8cedb3f0a55a61be0f0a59faaa6357a83a (diff) | |
download | luz-1b91ff690488b65b552c90bd5392b9a300c8c981.tar.gz luz-1b91ff690488b65b552c90bd5392b9a300c8c981.tar.bz2 luz-1b91ff690488b65b552c90bd5392b9a300c8c981.zip |
use cargo workspace
Diffstat (limited to 'jabber/src/jabber_stream.rs')
-rw-r--r-- | jabber/src/jabber_stream.rs | 393 |
1 files changed, 393 insertions, 0 deletions
diff --git a/jabber/src/jabber_stream.rs b/jabber/src/jabber_stream.rs new file mode 100644 index 0000000..dd0dcbf --- /dev/null +++ b/jabber/src/jabber_stream.rs @@ -0,0 +1,393 @@ +use std::pin::pin; +use std::str::{self, FromStr}; +use std::sync::Arc; + +use futures::StreamExt; +use jid::JID; +use peanuts::element::{FromContent, IntoElement}; +use peanuts::{Reader, Writer}; +use rsasl::prelude::{Mechname, SASLClient, SASLConfig}; +use stanza::bind::{Bind, BindType, FullJidType, ResourceType}; +use stanza::client::iq::{Iq, IqType, Query}; +use stanza::client::Stanza; +use stanza::sasl::{Auth, Challenge, Mechanisms, Response, ServerResponse}; +use stanza::starttls::{Proceed, StartTls}; +use stanza::stream::{Features, Stream}; +use stanza::XML_VERSION; +use tokio::io::{AsyncRead, AsyncWrite, ReadHalf, WriteHalf}; +use tokio_native_tls::native_tls::TlsConnector; +use tracing::{debug, instrument}; + +use crate::connection::{Tls, Unencrypted}; +use crate::error::Error; +use crate::Result; + +// open stream (streams started) +pub struct JabberStream<S> { + reader: Reader<ReadHalf<S>>, + writer: Writer<WriteHalf<S>>, +} + +impl<S: AsyncRead> futures::Stream for JabberStream<S> { + type Item = Result<Stanza>; + + fn poll_next( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll<Option<Self::Item>> { + pin!(self).reader.poll_next_unpin(cx).map(|content| { + content.map(|content| -> Result<Stanza> { + let stanza = content.map(|content| Stanza::from_content(content))?; + Ok(stanza?) + }) + }) + } +} + +impl<S> JabberStream<S> +where + S: AsyncRead + AsyncWrite + Unpin + Send + std::fmt::Debug, + JabberStream<S>: std::fmt::Debug, +{ + #[instrument] + pub async fn sasl(mut self, mechanisms: Mechanisms, sasl_config: Arc<SASLConfig>) -> Result<S> { + let sasl = SASLClient::new(sasl_config); + let mut offered_mechs: Vec<&Mechname> = Vec::new(); + for mechanism in &mechanisms.mechanisms { + offered_mechs.push(Mechname::parse(mechanism.as_bytes())?) + } + debug!("{:?}", offered_mechs); + let mut session = sasl.start_suggested(&offered_mechs)?; + let selected_mechanism = session.get_mechname().as_str().to_owned(); + debug!("selected mech: {:?}", selected_mechanism); + let mut data: Option<Vec<u8>> = None; + + if !session.are_we_first() { + // if not first mention the mechanism then get challenge data + // mention mechanism + let auth = Auth { + mechanism: selected_mechanism, + sasl_data: "=".to_string(), + }; + self.writer.write_full(&auth).await?; + // get challenge data + let challenge: Challenge = self.reader.read().await?; + debug!("challenge: {:?}", challenge); + data = Some((*challenge).as_bytes().to_vec()); + debug!("we didn't go first"); + } else { + // if first, mention mechanism and send data + let mut sasl_data = Vec::new(); + session.step64(None, &mut sasl_data).unwrap(); + let auth = Auth { + mechanism: selected_mechanism, + sasl_data: str::from_utf8(&sasl_data)?.to_string(), + }; + debug!("{:?}", auth); + self.writer.write_full(&auth).await?; + + let server_response: ServerResponse = self.reader.read().await?; + debug!("server_response: {:#?}", server_response); + match server_response { + ServerResponse::Challenge(challenge) => { + data = Some((*challenge).as_bytes().to_vec()) + } + ServerResponse::Success(success) => { + data = success.clone().map(|success| success.as_bytes().to_vec()) + } + ServerResponse::Failure(failure) => return Err(Error::Authentication(failure)), + } + debug!("we went first"); + } + + // stepping the authentication exchange to completion + if data != None { + debug!("data: {:?}", data); + let mut sasl_data = Vec::new(); + while { + // decide if need to send more data over + let state = session + .step64(data.as_deref(), &mut sasl_data) + .expect("step errored!"); + state.is_running() + } { + // While we aren't finished, receive more data from the other party + let response = Response::new(str::from_utf8(&sasl_data)?.to_string()); + debug!("response: {:?}", response); + let stdout = tokio::io::stdout(); + let mut writer = Writer::new(stdout); + writer.write_full(&response).await?; + self.writer.write_full(&response).await?; + debug!("response written"); + + let server_response: ServerResponse = self.reader.read().await?; + debug!("server_response: {:#?}", server_response); + match server_response { + ServerResponse::Challenge(challenge) => { + data = Some((*challenge).as_bytes().to_vec()) + } + ServerResponse::Success(success) => { + data = success.clone().map(|success| success.as_bytes().to_vec()) + } + ServerResponse::Failure(failure) => return Err(Error::Authentication(failure)), + } + } + } + let writer = self.writer.into_inner(); + let reader = self.reader.into_inner(); + let stream = reader.unsplit(writer); + Ok(stream) + } + + #[instrument] + pub async fn bind(mut self, jid: &mut JID) -> Result<Self> { + let iq_id = nanoid::nanoid!(); + if let Some(resource) = &jid.resourcepart { + let iq = Iq { + from: None, + id: iq_id.clone(), + to: None, + r#type: IqType::Set, + lang: None, + query: Some(Query::Bind(Bind { + r#type: Some(BindType::Resource(ResourceType(resource.to_string()))), + })), + errors: Vec::new(), + }; + self.writer.write_full(&iq).await?; + let result: Iq = self.reader.read().await?; + match result { + Iq { + from: _, + id, + to: _, + r#type: IqType::Result, + lang: _, + query: + Some(Query::Bind(Bind { + r#type: Some(BindType::Jid(FullJidType(new_jid))), + })), + errors: _, + } if id == iq_id => { + *jid = new_jid; + return Ok(self); + } + Iq { + from: _, + id, + to: _, + r#type: IqType::Error, + lang: _, + query: None, + errors, + } if id == iq_id => { + return Err(Error::ClientError( + errors.first().ok_or(Error::MissingError)?.clone(), + )) + } + _ => return Err(Error::UnexpectedElement(result.into_element())), + } + } else { + let iq = Iq { + from: None, + id: iq_id.clone(), + to: None, + r#type: IqType::Set, + lang: None, + query: Some(Query::Bind(Bind { r#type: None })), + errors: Vec::new(), + }; + self.writer.write_full(&iq).await?; + let result: Iq = self.reader.read().await?; + match result { + Iq { + from: _, + id, + to: _, + r#type: IqType::Result, + lang: _, + query: + Some(Query::Bind(Bind { + r#type: Some(BindType::Jid(FullJidType(new_jid))), + })), + errors: _, + } if id == iq_id => { + *jid = new_jid; + return Ok(self); + } + Iq { + from: _, + id, + to: _, + r#type: IqType::Error, + lang: _, + query: None, + errors, + } if id == iq_id => { + return Err(Error::ClientError( + errors.first().ok_or(Error::MissingError)?.clone(), + )) + } + _ => return Err(Error::UnexpectedElement(result.into_element())), + } + } + } + + #[instrument] + pub async fn start_stream(connection: S, server: &mut String) -> Result<Self> { + // client to server + let (reader, writer) = tokio::io::split(connection); + let mut reader = Reader::new(reader); + let mut writer = Writer::new(writer); + + // declaration + writer.write_declaration(XML_VERSION).await?; + + // opening stream element + let stream = Stream::new_client( + None, + JID::from_str(server.as_ref())?, + None, + "en".to_string(), + ); + writer.write_start(&stream).await?; + + // server to client + + // may or may not send a declaration + let _decl = reader.read_prolog().await?; + + // receive stream element and validate + let stream: Stream = reader.read_start().await?; + debug!("got stream: {:?}", stream); + if let Some(from) = stream.from { + *server = from.to_string(); + } + + Ok(Self { reader, writer }) + } + + #[instrument] + pub async fn get_features(mut self) -> Result<(Features, Self)> { + debug!("getting features"); + let features: Features = self.reader.read().await?; + debug!("got features: {:?}", features); + Ok((features, self)) + } + + pub fn into_inner(self) -> S { + self.reader.into_inner().unsplit(self.writer.into_inner()) + } + + pub async fn send_stanza(&mut self, stanza: &Stanza) -> Result<()> { + self.writer.write(stanza).await?; + Ok(()) + } +} + +impl JabberStream<Unencrypted> { + #[instrument] + pub async fn starttls(mut self, domain: impl AsRef<str> + std::fmt::Debug) -> Result<Tls> { + self.writer + .write_full(&StartTls { required: false }) + .await?; + let proceed: Proceed = self.reader.read().await?; + debug!("got proceed: {:?}", proceed); + let connector = TlsConnector::new().unwrap(); + let stream = self.reader.into_inner().unsplit(self.writer.into_inner()); + if let Ok(tls_stream) = tokio_native_tls::TlsConnector::from(connector) + .connect(domain.as_ref(), stream) + .await + { + return Ok(tls_stream); + } else { + return Err(Error::Connection); + } + } +} + +impl std::fmt::Debug for JabberStream<Tls> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Jabber") + .field("connection", &"tls") + .finish() + } +} + +impl std::fmt::Debug for JabberStream<Unencrypted> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Jabber") + .field("connection", &"unencrypted") + .finish() + } +} + +#[cfg(test)] +mod tests { + use std::time::Duration; + + use super::*; + use crate::connection::Connection; + use test_log::test; + use tokio::time::sleep; + + #[test(tokio::test)] + async fn start_stream() { + // let connection = Connection::connect("blos.sm", None, None).await.unwrap(); + // match connection { + // Connection::Encrypted(mut c) => c.start_stream().await.unwrap(), + // Connection::Unencrypted(mut c) => c.start_stream().await.unwrap(), + // } + } + + #[test(tokio::test)] + async fn sasl() { + // let mut jabber = Connection::connect_user("test@blos.sm", "slayed".to_string()) + // .await + // .unwrap() + // .ensure_tls() + // .await + // .unwrap(); + // let text = str::from_utf8(jabber.reader.buffer.data()).unwrap(); + // println!("data: {}", text); + // jabber.start_stream().await.unwrap(); + + // let text = str::from_utf8(jabber.reader.buffer.data()).unwrap(); + // println!("data: {}", text); + // jabber.reader.read_buf().await.unwrap(); + // let text = str::from_utf8(jabber.reader.buffer.data()).unwrap(); + // println!("data: {}", text); + + // let features = jabber.get_features().await.unwrap(); + // let (sasl_config, feature) = ( + // jabber.auth.clone().unwrap(), + // features + // .features + // .iter() + // .find(|feature| matches!(feature, Feature::Sasl(_))) + // .unwrap(), + // ); + // match feature { + // Feature::StartTls(_start_tls) => todo!(), + // Feature::Sasl(mechanisms) => { + // jabber.sasl(mechanisms.clone(), sasl_config).await.unwrap(); + // } + // Feature::Bind => todo!(), + // Feature::Unknown => todo!(), + // } + } + + #[tokio::test] + async fn negotiate() { + // let _jabber = Connection::connect_user("test@blos.sm", "slayed".to_string()) + // .await + // .unwrap() + // .ensure_tls() + // .await + // .unwrap() + // .negotiate() + // .await + // .unwrap(); + // sleep(Duration::from_secs(5)).await + } +} |