From 03764f8cedb3f0a55a61be0f0a59faaa6357a83a Mon Sep 17 00:00:00 2001 From: cel 🌸 Date: Wed, 4 Dec 2024 17:40:56 +0000 Subject: rename jabber to jabber_stream --- src/jabber.rs | 394 --------------------------------------------------- src/jabber_stream.rs | 394 +++++++++++++++++++++++++++++++++++++++++++++++++++ src/lib.rs | 4 +- 3 files changed, 396 insertions(+), 396 deletions(-) delete mode 100644 src/jabber.rs create mode 100644 src/jabber_stream.rs diff --git a/src/jabber.rs b/src/jabber.rs deleted file mode 100644 index 8ee45b5..0000000 --- a/src/jabber.rs +++ /dev/null @@ -1,394 +0,0 @@ -use std::pin::pin; -use std::str::{self, FromStr}; -use std::sync::Arc; - -use async_recursion::async_recursion; -use futures::StreamExt; -use peanuts::element::{FromContent, IntoElement}; -use peanuts::{Reader, Writer}; -use rsasl::prelude::{Mechname, SASLClient, SASLConfig}; -use tokio::io::{AsyncRead, AsyncWrite, ReadHalf, WriteHalf}; -use tokio_native_tls::native_tls::TlsConnector; -use tracing::{debug, instrument}; - -use crate::connection::{Tls, Unencrypted}; -use crate::error::Error; -use crate::stanza::bind::{Bind, BindType, FullJidType, ResourceType}; -use crate::stanza::client::iq::{Iq, IqType, Query}; -use crate::stanza::client::Stanza; -use crate::stanza::sasl::{Auth, Challenge, Mechanisms, Response, ServerResponse}; -use crate::stanza::starttls::{Proceed, StartTls}; -use crate::stanza::stream::{Feature, Features, Stream}; -use crate::stanza::XML_VERSION; -use crate::JID; -use crate::{Connection, Result}; - -// open stream (streams started) -pub struct JabberStream { - reader: Reader>, - writer: Writer>, -} - -impl futures::Stream for JabberStream { - type Item = Result; - - fn poll_next( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - pin!(self).reader.poll_next_unpin(cx).map(|content| { - content.map(|content| -> Result { - let stanza = content.map(|content| Stanza::from_content(content))?; - Ok(stanza?) - }) - }) - } -} - -impl JabberStream -where - S: AsyncRead + AsyncWrite + Unpin + Send + std::fmt::Debug, - JabberStream: std::fmt::Debug, -{ - #[instrument] - pub async fn sasl(mut self, mechanisms: Mechanisms, sasl_config: Arc) -> Result { - let sasl = SASLClient::new(sasl_config); - let mut offered_mechs: Vec<&Mechname> = Vec::new(); - for mechanism in &mechanisms.mechanisms { - offered_mechs.push(Mechname::parse(mechanism.as_bytes())?) - } - debug!("{:?}", offered_mechs); - let mut session = sasl.start_suggested(&offered_mechs)?; - let selected_mechanism = session.get_mechname().as_str().to_owned(); - debug!("selected mech: {:?}", selected_mechanism); - let mut data: Option> = None; - - if !session.are_we_first() { - // if not first mention the mechanism then get challenge data - // mention mechanism - let auth = Auth { - mechanism: selected_mechanism, - sasl_data: "=".to_string(), - }; - self.writer.write_full(&auth).await?; - // get challenge data - let challenge: Challenge = self.reader.read().await?; - debug!("challenge: {:?}", challenge); - data = Some((*challenge).as_bytes().to_vec()); - debug!("we didn't go first"); - } else { - // if first, mention mechanism and send data - let mut sasl_data = Vec::new(); - session.step64(None, &mut sasl_data).unwrap(); - let auth = Auth { - mechanism: selected_mechanism, - sasl_data: str::from_utf8(&sasl_data)?.to_string(), - }; - debug!("{:?}", auth); - self.writer.write_full(&auth).await?; - - let server_response: ServerResponse = self.reader.read().await?; - debug!("server_response: {:#?}", server_response); - match server_response { - ServerResponse::Challenge(challenge) => { - data = Some((*challenge).as_bytes().to_vec()) - } - ServerResponse::Success(success) => { - data = success.clone().map(|success| success.as_bytes().to_vec()) - } - ServerResponse::Failure(failure) => return Err(Error::Authentication(failure)), - } - debug!("we went first"); - } - - // stepping the authentication exchange to completion - if data != None { - debug!("data: {:?}", data); - let mut sasl_data = Vec::new(); - while { - // decide if need to send more data over - let state = session - .step64(data.as_deref(), &mut sasl_data) - .expect("step errored!"); - state.is_running() - } { - // While we aren't finished, receive more data from the other party - let response = Response::new(str::from_utf8(&sasl_data)?.to_string()); - debug!("response: {:?}", response); - let stdout = tokio::io::stdout(); - let mut writer = Writer::new(stdout); - writer.write_full(&response).await?; - self.writer.write_full(&response).await?; - debug!("response written"); - - let server_response: ServerResponse = self.reader.read().await?; - debug!("server_response: {:#?}", server_response); - match server_response { - ServerResponse::Challenge(challenge) => { - data = Some((*challenge).as_bytes().to_vec()) - } - ServerResponse::Success(success) => { - data = success.clone().map(|success| success.as_bytes().to_vec()) - } - ServerResponse::Failure(failure) => return Err(Error::Authentication(failure)), - } - } - } - let writer = self.writer.into_inner(); - let reader = self.reader.into_inner(); - let stream = reader.unsplit(writer); - Ok(stream) - } - - #[instrument] - pub async fn bind(mut self, jid: &mut JID) -> Result { - let iq_id = nanoid::nanoid!(); - if let Some(resource) = &jid.resourcepart { - let iq = Iq { - from: None, - id: iq_id.clone(), - to: None, - r#type: IqType::Set, - lang: None, - query: Some(Query::Bind(Bind { - r#type: Some(BindType::Resource(ResourceType(resource.to_string()))), - })), - errors: Vec::new(), - }; - self.writer.write_full(&iq).await?; - let result: Iq = self.reader.read().await?; - match result { - Iq { - from: _, - id, - to: _, - r#type: IqType::Result, - lang: _, - query: - Some(Query::Bind(Bind { - r#type: Some(BindType::Jid(FullJidType(new_jid))), - })), - errors: _, - } if id == iq_id => { - *jid = new_jid; - return Ok(self); - } - Iq { - from: _, - id, - to: _, - r#type: IqType::Error, - lang: _, - query: None, - errors, - } if id == iq_id => { - return Err(Error::ClientError( - errors.first().ok_or(Error::MissingError)?.clone(), - )) - } - _ => return Err(Error::UnexpectedElement(result.into_element())), - } - } else { - let iq = Iq { - from: None, - id: iq_id.clone(), - to: None, - r#type: IqType::Set, - lang: None, - query: Some(Query::Bind(Bind { r#type: None })), - errors: Vec::new(), - }; - self.writer.write_full(&iq).await?; - let result: Iq = self.reader.read().await?; - match result { - Iq { - from: _, - id, - to: _, - r#type: IqType::Result, - lang: _, - query: - Some(Query::Bind(Bind { - r#type: Some(BindType::Jid(FullJidType(new_jid))), - })), - errors: _, - } if id == iq_id => { - *jid = new_jid; - return Ok(self); - } - Iq { - from: _, - id, - to: _, - r#type: IqType::Error, - lang: _, - query: None, - errors, - } if id == iq_id => { - return Err(Error::ClientError( - errors.first().ok_or(Error::MissingError)?.clone(), - )) - } - _ => return Err(Error::UnexpectedElement(result.into_element())), - } - } - } - - #[instrument] - pub async fn start_stream(connection: S, server: &mut String) -> Result { - // client to server - let (reader, writer) = tokio::io::split(connection); - let mut reader = Reader::new(reader); - let mut writer = Writer::new(writer); - - // declaration - writer.write_declaration(XML_VERSION).await?; - - // opening stream element - let stream = Stream::new_client( - None, - JID::from_str(server.as_ref())?, - None, - "en".to_string(), - ); - writer.write_start(&stream).await?; - - // server to client - - // may or may not send a declaration - let _decl = reader.read_prolog().await?; - - // receive stream element and validate - let stream: Stream = reader.read_start().await?; - debug!("got stream: {:?}", stream); - if let Some(from) = stream.from { - *server = from.to_string(); - } - - Ok(Self { reader, writer }) - } - - #[instrument] - pub async fn get_features(mut self) -> Result<(Features, Self)> { - debug!("getting features"); - let features: Features = self.reader.read().await?; - debug!("got features: {:?}", features); - Ok((features, self)) - } - - pub fn into_inner(self) -> S { - self.reader.into_inner().unsplit(self.writer.into_inner()) - } - - pub async fn send_stanza(&mut self, stanza: &Stanza) -> Result<()> { - self.writer.write(stanza).await?; - Ok(()) - } -} - -impl JabberStream { - #[instrument] - pub async fn starttls(mut self, domain: impl AsRef + std::fmt::Debug) -> Result { - self.writer - .write_full(&StartTls { required: false }) - .await?; - let proceed: Proceed = self.reader.read().await?; - debug!("got proceed: {:?}", proceed); - let connector = TlsConnector::new().unwrap(); - let stream = self.reader.into_inner().unsplit(self.writer.into_inner()); - if let Ok(tls_stream) = tokio_native_tls::TlsConnector::from(connector) - .connect(domain.as_ref(), stream) - .await - { - return Ok(tls_stream); - } else { - return Err(Error::Connection); - } - } -} - -impl std::fmt::Debug for JabberStream { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("Jabber") - .field("connection", &"tls") - .finish() - } -} - -impl std::fmt::Debug for JabberStream { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("Jabber") - .field("connection", &"unencrypted") - .finish() - } -} - -#[cfg(test)] -mod tests { - use std::time::Duration; - - use super::*; - use crate::connection::Connection; - use test_log::test; - use tokio::time::sleep; - - #[test(tokio::test)] - async fn start_stream() { - // let connection = Connection::connect("blos.sm", None, None).await.unwrap(); - // match connection { - // Connection::Encrypted(mut c) => c.start_stream().await.unwrap(), - // Connection::Unencrypted(mut c) => c.start_stream().await.unwrap(), - // } - } - - #[test(tokio::test)] - async fn sasl() { - // let mut jabber = Connection::connect_user("test@blos.sm", "slayed".to_string()) - // .await - // .unwrap() - // .ensure_tls() - // .await - // .unwrap(); - // let text = str::from_utf8(jabber.reader.buffer.data()).unwrap(); - // println!("data: {}", text); - // jabber.start_stream().await.unwrap(); - - // let text = str::from_utf8(jabber.reader.buffer.data()).unwrap(); - // println!("data: {}", text); - // jabber.reader.read_buf().await.unwrap(); - // let text = str::from_utf8(jabber.reader.buffer.data()).unwrap(); - // println!("data: {}", text); - - // let features = jabber.get_features().await.unwrap(); - // let (sasl_config, feature) = ( - // jabber.auth.clone().unwrap(), - // features - // .features - // .iter() - // .find(|feature| matches!(feature, Feature::Sasl(_))) - // .unwrap(), - // ); - // match feature { - // Feature::StartTls(_start_tls) => todo!(), - // Feature::Sasl(mechanisms) => { - // jabber.sasl(mechanisms.clone(), sasl_config).await.unwrap(); - // } - // Feature::Bind => todo!(), - // Feature::Unknown => todo!(), - // } - } - - #[tokio::test] - async fn negotiate() { - // let _jabber = Connection::connect_user("test@blos.sm", "slayed".to_string()) - // .await - // .unwrap() - // .ensure_tls() - // .await - // .unwrap() - // .negotiate() - // .await - // .unwrap(); - // sleep(Duration::from_secs(5)).await - } -} diff --git a/src/jabber_stream.rs b/src/jabber_stream.rs new file mode 100644 index 0000000..8ee45b5 --- /dev/null +++ b/src/jabber_stream.rs @@ -0,0 +1,394 @@ +use std::pin::pin; +use std::str::{self, FromStr}; +use std::sync::Arc; + +use async_recursion::async_recursion; +use futures::StreamExt; +use peanuts::element::{FromContent, IntoElement}; +use peanuts::{Reader, Writer}; +use rsasl::prelude::{Mechname, SASLClient, SASLConfig}; +use tokio::io::{AsyncRead, AsyncWrite, ReadHalf, WriteHalf}; +use tokio_native_tls::native_tls::TlsConnector; +use tracing::{debug, instrument}; + +use crate::connection::{Tls, Unencrypted}; +use crate::error::Error; +use crate::stanza::bind::{Bind, BindType, FullJidType, ResourceType}; +use crate::stanza::client::iq::{Iq, IqType, Query}; +use crate::stanza::client::Stanza; +use crate::stanza::sasl::{Auth, Challenge, Mechanisms, Response, ServerResponse}; +use crate::stanza::starttls::{Proceed, StartTls}; +use crate::stanza::stream::{Feature, Features, Stream}; +use crate::stanza::XML_VERSION; +use crate::JID; +use crate::{Connection, Result}; + +// open stream (streams started) +pub struct JabberStream { + reader: Reader>, + writer: Writer>, +} + +impl futures::Stream for JabberStream { + type Item = Result; + + fn poll_next( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + pin!(self).reader.poll_next_unpin(cx).map(|content| { + content.map(|content| -> Result { + let stanza = content.map(|content| Stanza::from_content(content))?; + Ok(stanza?) + }) + }) + } +} + +impl JabberStream +where + S: AsyncRead + AsyncWrite + Unpin + Send + std::fmt::Debug, + JabberStream: std::fmt::Debug, +{ + #[instrument] + pub async fn sasl(mut self, mechanisms: Mechanisms, sasl_config: Arc) -> Result { + let sasl = SASLClient::new(sasl_config); + let mut offered_mechs: Vec<&Mechname> = Vec::new(); + for mechanism in &mechanisms.mechanisms { + offered_mechs.push(Mechname::parse(mechanism.as_bytes())?) + } + debug!("{:?}", offered_mechs); + let mut session = sasl.start_suggested(&offered_mechs)?; + let selected_mechanism = session.get_mechname().as_str().to_owned(); + debug!("selected mech: {:?}", selected_mechanism); + let mut data: Option> = None; + + if !session.are_we_first() { + // if not first mention the mechanism then get challenge data + // mention mechanism + let auth = Auth { + mechanism: selected_mechanism, + sasl_data: "=".to_string(), + }; + self.writer.write_full(&auth).await?; + // get challenge data + let challenge: Challenge = self.reader.read().await?; + debug!("challenge: {:?}", challenge); + data = Some((*challenge).as_bytes().to_vec()); + debug!("we didn't go first"); + } else { + // if first, mention mechanism and send data + let mut sasl_data = Vec::new(); + session.step64(None, &mut sasl_data).unwrap(); + let auth = Auth { + mechanism: selected_mechanism, + sasl_data: str::from_utf8(&sasl_data)?.to_string(), + }; + debug!("{:?}", auth); + self.writer.write_full(&auth).await?; + + let server_response: ServerResponse = self.reader.read().await?; + debug!("server_response: {:#?}", server_response); + match server_response { + ServerResponse::Challenge(challenge) => { + data = Some((*challenge).as_bytes().to_vec()) + } + ServerResponse::Success(success) => { + data = success.clone().map(|success| success.as_bytes().to_vec()) + } + ServerResponse::Failure(failure) => return Err(Error::Authentication(failure)), + } + debug!("we went first"); + } + + // stepping the authentication exchange to completion + if data != None { + debug!("data: {:?}", data); + let mut sasl_data = Vec::new(); + while { + // decide if need to send more data over + let state = session + .step64(data.as_deref(), &mut sasl_data) + .expect("step errored!"); + state.is_running() + } { + // While we aren't finished, receive more data from the other party + let response = Response::new(str::from_utf8(&sasl_data)?.to_string()); + debug!("response: {:?}", response); + let stdout = tokio::io::stdout(); + let mut writer = Writer::new(stdout); + writer.write_full(&response).await?; + self.writer.write_full(&response).await?; + debug!("response written"); + + let server_response: ServerResponse = self.reader.read().await?; + debug!("server_response: {:#?}", server_response); + match server_response { + ServerResponse::Challenge(challenge) => { + data = Some((*challenge).as_bytes().to_vec()) + } + ServerResponse::Success(success) => { + data = success.clone().map(|success| success.as_bytes().to_vec()) + } + ServerResponse::Failure(failure) => return Err(Error::Authentication(failure)), + } + } + } + let writer = self.writer.into_inner(); + let reader = self.reader.into_inner(); + let stream = reader.unsplit(writer); + Ok(stream) + } + + #[instrument] + pub async fn bind(mut self, jid: &mut JID) -> Result { + let iq_id = nanoid::nanoid!(); + if let Some(resource) = &jid.resourcepart { + let iq = Iq { + from: None, + id: iq_id.clone(), + to: None, + r#type: IqType::Set, + lang: None, + query: Some(Query::Bind(Bind { + r#type: Some(BindType::Resource(ResourceType(resource.to_string()))), + })), + errors: Vec::new(), + }; + self.writer.write_full(&iq).await?; + let result: Iq = self.reader.read().await?; + match result { + Iq { + from: _, + id, + to: _, + r#type: IqType::Result, + lang: _, + query: + Some(Query::Bind(Bind { + r#type: Some(BindType::Jid(FullJidType(new_jid))), + })), + errors: _, + } if id == iq_id => { + *jid = new_jid; + return Ok(self); + } + Iq { + from: _, + id, + to: _, + r#type: IqType::Error, + lang: _, + query: None, + errors, + } if id == iq_id => { + return Err(Error::ClientError( + errors.first().ok_or(Error::MissingError)?.clone(), + )) + } + _ => return Err(Error::UnexpectedElement(result.into_element())), + } + } else { + let iq = Iq { + from: None, + id: iq_id.clone(), + to: None, + r#type: IqType::Set, + lang: None, + query: Some(Query::Bind(Bind { r#type: None })), + errors: Vec::new(), + }; + self.writer.write_full(&iq).await?; + let result: Iq = self.reader.read().await?; + match result { + Iq { + from: _, + id, + to: _, + r#type: IqType::Result, + lang: _, + query: + Some(Query::Bind(Bind { + r#type: Some(BindType::Jid(FullJidType(new_jid))), + })), + errors: _, + } if id == iq_id => { + *jid = new_jid; + return Ok(self); + } + Iq { + from: _, + id, + to: _, + r#type: IqType::Error, + lang: _, + query: None, + errors, + } if id == iq_id => { + return Err(Error::ClientError( + errors.first().ok_or(Error::MissingError)?.clone(), + )) + } + _ => return Err(Error::UnexpectedElement(result.into_element())), + } + } + } + + #[instrument] + pub async fn start_stream(connection: S, server: &mut String) -> Result { + // client to server + let (reader, writer) = tokio::io::split(connection); + let mut reader = Reader::new(reader); + let mut writer = Writer::new(writer); + + // declaration + writer.write_declaration(XML_VERSION).await?; + + // opening stream element + let stream = Stream::new_client( + None, + JID::from_str(server.as_ref())?, + None, + "en".to_string(), + ); + writer.write_start(&stream).await?; + + // server to client + + // may or may not send a declaration + let _decl = reader.read_prolog().await?; + + // receive stream element and validate + let stream: Stream = reader.read_start().await?; + debug!("got stream: {:?}", stream); + if let Some(from) = stream.from { + *server = from.to_string(); + } + + Ok(Self { reader, writer }) + } + + #[instrument] + pub async fn get_features(mut self) -> Result<(Features, Self)> { + debug!("getting features"); + let features: Features = self.reader.read().await?; + debug!("got features: {:?}", features); + Ok((features, self)) + } + + pub fn into_inner(self) -> S { + self.reader.into_inner().unsplit(self.writer.into_inner()) + } + + pub async fn send_stanza(&mut self, stanza: &Stanza) -> Result<()> { + self.writer.write(stanza).await?; + Ok(()) + } +} + +impl JabberStream { + #[instrument] + pub async fn starttls(mut self, domain: impl AsRef + std::fmt::Debug) -> Result { + self.writer + .write_full(&StartTls { required: false }) + .await?; + let proceed: Proceed = self.reader.read().await?; + debug!("got proceed: {:?}", proceed); + let connector = TlsConnector::new().unwrap(); + let stream = self.reader.into_inner().unsplit(self.writer.into_inner()); + if let Ok(tls_stream) = tokio_native_tls::TlsConnector::from(connector) + .connect(domain.as_ref(), stream) + .await + { + return Ok(tls_stream); + } else { + return Err(Error::Connection); + } + } +} + +impl std::fmt::Debug for JabberStream { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Jabber") + .field("connection", &"tls") + .finish() + } +} + +impl std::fmt::Debug for JabberStream { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Jabber") + .field("connection", &"unencrypted") + .finish() + } +} + +#[cfg(test)] +mod tests { + use std::time::Duration; + + use super::*; + use crate::connection::Connection; + use test_log::test; + use tokio::time::sleep; + + #[test(tokio::test)] + async fn start_stream() { + // let connection = Connection::connect("blos.sm", None, None).await.unwrap(); + // match connection { + // Connection::Encrypted(mut c) => c.start_stream().await.unwrap(), + // Connection::Unencrypted(mut c) => c.start_stream().await.unwrap(), + // } + } + + #[test(tokio::test)] + async fn sasl() { + // let mut jabber = Connection::connect_user("test@blos.sm", "slayed".to_string()) + // .await + // .unwrap() + // .ensure_tls() + // .await + // .unwrap(); + // let text = str::from_utf8(jabber.reader.buffer.data()).unwrap(); + // println!("data: {}", text); + // jabber.start_stream().await.unwrap(); + + // let text = str::from_utf8(jabber.reader.buffer.data()).unwrap(); + // println!("data: {}", text); + // jabber.reader.read_buf().await.unwrap(); + // let text = str::from_utf8(jabber.reader.buffer.data()).unwrap(); + // println!("data: {}", text); + + // let features = jabber.get_features().await.unwrap(); + // let (sasl_config, feature) = ( + // jabber.auth.clone().unwrap(), + // features + // .features + // .iter() + // .find(|feature| matches!(feature, Feature::Sasl(_))) + // .unwrap(), + // ); + // match feature { + // Feature::StartTls(_start_tls) => todo!(), + // Feature::Sasl(mechanisms) => { + // jabber.sasl(mechanisms.clone(), sasl_config).await.unwrap(); + // } + // Feature::Bind => todo!(), + // Feature::Unknown => todo!(), + // } + } + + #[tokio::test] + async fn negotiate() { + // let _jabber = Connection::connect_user("test@blos.sm", "slayed".to_string()) + // .await + // .unwrap() + // .ensure_tls() + // .await + // .unwrap() + // .negotiate() + // .await + // .unwrap(); + // sleep(Duration::from_secs(5)).await + } +} diff --git a/src/lib.rs b/src/lib.rs index e55d3f5..43aa581 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,14 +5,14 @@ pub mod client; pub mod connection; pub mod error; -pub mod jabber; +pub mod jabber_stream; pub mod jid; pub mod stanza; pub use connection::Connection; use connection::Tls; pub use error::Error; -pub use jabber::JabberStream; +pub use jabber_stream::JabberStream; pub use jid::JID; pub type Result = std::result::Result; -- cgit