aboutsummaryrefslogtreecommitdiffstats
path: root/jabber/src/jabber_stream.rs
diff options
context:
space:
mode:
authorLibravatar cel 🌸 <cel@bunny.garden>2024-12-04 18:18:37 +0000
committerLibravatar cel 🌸 <cel@bunny.garden>2024-12-04 18:18:37 +0000
commit1b91ff690488b65b552c90bd5392b9a300c8c981 (patch)
tree9c290f69b26eba0393d7bbc05ba29c28ea74a26e /jabber/src/jabber_stream.rs
parent03764f8cedb3f0a55a61be0f0a59faaa6357a83a (diff)
downloadluz-1b91ff690488b65b552c90bd5392b9a300c8c981.tar.gz
luz-1b91ff690488b65b552c90bd5392b9a300c8c981.tar.bz2
luz-1b91ff690488b65b552c90bd5392b9a300c8c981.zip
use cargo workspace
Diffstat (limited to 'jabber/src/jabber_stream.rs')
-rw-r--r--jabber/src/jabber_stream.rs393
1 files changed, 393 insertions, 0 deletions
diff --git a/jabber/src/jabber_stream.rs b/jabber/src/jabber_stream.rs
new file mode 100644
index 0000000..dd0dcbf
--- /dev/null
+++ b/jabber/src/jabber_stream.rs
@@ -0,0 +1,393 @@
+use std::pin::pin;
+use std::str::{self, FromStr};
+use std::sync::Arc;
+
+use futures::StreamExt;
+use jid::JID;
+use peanuts::element::{FromContent, IntoElement};
+use peanuts::{Reader, Writer};
+use rsasl::prelude::{Mechname, SASLClient, SASLConfig};
+use stanza::bind::{Bind, BindType, FullJidType, ResourceType};
+use stanza::client::iq::{Iq, IqType, Query};
+use stanza::client::Stanza;
+use stanza::sasl::{Auth, Challenge, Mechanisms, Response, ServerResponse};
+use stanza::starttls::{Proceed, StartTls};
+use stanza::stream::{Features, Stream};
+use stanza::XML_VERSION;
+use tokio::io::{AsyncRead, AsyncWrite, ReadHalf, WriteHalf};
+use tokio_native_tls::native_tls::TlsConnector;
+use tracing::{debug, instrument};
+
+use crate::connection::{Tls, Unencrypted};
+use crate::error::Error;
+use crate::Result;
+
+// open stream (streams started)
+pub struct JabberStream<S> {
+ reader: Reader<ReadHalf<S>>,
+ writer: Writer<WriteHalf<S>>,
+}
+
+impl<S: AsyncRead> futures::Stream for JabberStream<S> {
+ type Item = Result<Stanza>;
+
+ fn poll_next(
+ self: std::pin::Pin<&mut Self>,
+ cx: &mut std::task::Context<'_>,
+ ) -> std::task::Poll<Option<Self::Item>> {
+ pin!(self).reader.poll_next_unpin(cx).map(|content| {
+ content.map(|content| -> Result<Stanza> {
+ let stanza = content.map(|content| Stanza::from_content(content))?;
+ Ok(stanza?)
+ })
+ })
+ }
+}
+
+impl<S> JabberStream<S>
+where
+ S: AsyncRead + AsyncWrite + Unpin + Send + std::fmt::Debug,
+ JabberStream<S>: std::fmt::Debug,
+{
+ #[instrument]
+ pub async fn sasl(mut self, mechanisms: Mechanisms, sasl_config: Arc<SASLConfig>) -> Result<S> {
+ let sasl = SASLClient::new(sasl_config);
+ let mut offered_mechs: Vec<&Mechname> = Vec::new();
+ for mechanism in &mechanisms.mechanisms {
+ offered_mechs.push(Mechname::parse(mechanism.as_bytes())?)
+ }
+ debug!("{:?}", offered_mechs);
+ let mut session = sasl.start_suggested(&offered_mechs)?;
+ let selected_mechanism = session.get_mechname().as_str().to_owned();
+ debug!("selected mech: {:?}", selected_mechanism);
+ let mut data: Option<Vec<u8>> = None;
+
+ if !session.are_we_first() {
+ // if not first mention the mechanism then get challenge data
+ // mention mechanism
+ let auth = Auth {
+ mechanism: selected_mechanism,
+ sasl_data: "=".to_string(),
+ };
+ self.writer.write_full(&auth).await?;
+ // get challenge data
+ let challenge: Challenge = self.reader.read().await?;
+ debug!("challenge: {:?}", challenge);
+ data = Some((*challenge).as_bytes().to_vec());
+ debug!("we didn't go first");
+ } else {
+ // if first, mention mechanism and send data
+ let mut sasl_data = Vec::new();
+ session.step64(None, &mut sasl_data).unwrap();
+ let auth = Auth {
+ mechanism: selected_mechanism,
+ sasl_data: str::from_utf8(&sasl_data)?.to_string(),
+ };
+ debug!("{:?}", auth);
+ self.writer.write_full(&auth).await?;
+
+ let server_response: ServerResponse = self.reader.read().await?;
+ debug!("server_response: {:#?}", server_response);
+ match server_response {
+ ServerResponse::Challenge(challenge) => {
+ data = Some((*challenge).as_bytes().to_vec())
+ }
+ ServerResponse::Success(success) => {
+ data = success.clone().map(|success| success.as_bytes().to_vec())
+ }
+ ServerResponse::Failure(failure) => return Err(Error::Authentication(failure)),
+ }
+ debug!("we went first");
+ }
+
+ // stepping the authentication exchange to completion
+ if data != None {
+ debug!("data: {:?}", data);
+ let mut sasl_data = Vec::new();
+ while {
+ // decide if need to send more data over
+ let state = session
+ .step64(data.as_deref(), &mut sasl_data)
+ .expect("step errored!");
+ state.is_running()
+ } {
+ // While we aren't finished, receive more data from the other party
+ let response = Response::new(str::from_utf8(&sasl_data)?.to_string());
+ debug!("response: {:?}", response);
+ let stdout = tokio::io::stdout();
+ let mut writer = Writer::new(stdout);
+ writer.write_full(&response).await?;
+ self.writer.write_full(&response).await?;
+ debug!("response written");
+
+ let server_response: ServerResponse = self.reader.read().await?;
+ debug!("server_response: {:#?}", server_response);
+ match server_response {
+ ServerResponse::Challenge(challenge) => {
+ data = Some((*challenge).as_bytes().to_vec())
+ }
+ ServerResponse::Success(success) => {
+ data = success.clone().map(|success| success.as_bytes().to_vec())
+ }
+ ServerResponse::Failure(failure) => return Err(Error::Authentication(failure)),
+ }
+ }
+ }
+ let writer = self.writer.into_inner();
+ let reader = self.reader.into_inner();
+ let stream = reader.unsplit(writer);
+ Ok(stream)
+ }
+
+ #[instrument]
+ pub async fn bind(mut self, jid: &mut JID) -> Result<Self> {
+ let iq_id = nanoid::nanoid!();
+ if let Some(resource) = &jid.resourcepart {
+ let iq = Iq {
+ from: None,
+ id: iq_id.clone(),
+ to: None,
+ r#type: IqType::Set,
+ lang: None,
+ query: Some(Query::Bind(Bind {
+ r#type: Some(BindType::Resource(ResourceType(resource.to_string()))),
+ })),
+ errors: Vec::new(),
+ };
+ self.writer.write_full(&iq).await?;
+ let result: Iq = self.reader.read().await?;
+ match result {
+ Iq {
+ from: _,
+ id,
+ to: _,
+ r#type: IqType::Result,
+ lang: _,
+ query:
+ Some(Query::Bind(Bind {
+ r#type: Some(BindType::Jid(FullJidType(new_jid))),
+ })),
+ errors: _,
+ } if id == iq_id => {
+ *jid = new_jid;
+ return Ok(self);
+ }
+ Iq {
+ from: _,
+ id,
+ to: _,
+ r#type: IqType::Error,
+ lang: _,
+ query: None,
+ errors,
+ } if id == iq_id => {
+ return Err(Error::ClientError(
+ errors.first().ok_or(Error::MissingError)?.clone(),
+ ))
+ }
+ _ => return Err(Error::UnexpectedElement(result.into_element())),
+ }
+ } else {
+ let iq = Iq {
+ from: None,
+ id: iq_id.clone(),
+ to: None,
+ r#type: IqType::Set,
+ lang: None,
+ query: Some(Query::Bind(Bind { r#type: None })),
+ errors: Vec::new(),
+ };
+ self.writer.write_full(&iq).await?;
+ let result: Iq = self.reader.read().await?;
+ match result {
+ Iq {
+ from: _,
+ id,
+ to: _,
+ r#type: IqType::Result,
+ lang: _,
+ query:
+ Some(Query::Bind(Bind {
+ r#type: Some(BindType::Jid(FullJidType(new_jid))),
+ })),
+ errors: _,
+ } if id == iq_id => {
+ *jid = new_jid;
+ return Ok(self);
+ }
+ Iq {
+ from: _,
+ id,
+ to: _,
+ r#type: IqType::Error,
+ lang: _,
+ query: None,
+ errors,
+ } if id == iq_id => {
+ return Err(Error::ClientError(
+ errors.first().ok_or(Error::MissingError)?.clone(),
+ ))
+ }
+ _ => return Err(Error::UnexpectedElement(result.into_element())),
+ }
+ }
+ }
+
+ #[instrument]
+ pub async fn start_stream(connection: S, server: &mut String) -> Result<Self> {
+ // client to server
+ let (reader, writer) = tokio::io::split(connection);
+ let mut reader = Reader::new(reader);
+ let mut writer = Writer::new(writer);
+
+ // declaration
+ writer.write_declaration(XML_VERSION).await?;
+
+ // opening stream element
+ let stream = Stream::new_client(
+ None,
+ JID::from_str(server.as_ref())?,
+ None,
+ "en".to_string(),
+ );
+ writer.write_start(&stream).await?;
+
+ // server to client
+
+ // may or may not send a declaration
+ let _decl = reader.read_prolog().await?;
+
+ // receive stream element and validate
+ let stream: Stream = reader.read_start().await?;
+ debug!("got stream: {:?}", stream);
+ if let Some(from) = stream.from {
+ *server = from.to_string();
+ }
+
+ Ok(Self { reader, writer })
+ }
+
+ #[instrument]
+ pub async fn get_features(mut self) -> Result<(Features, Self)> {
+ debug!("getting features");
+ let features: Features = self.reader.read().await?;
+ debug!("got features: {:?}", features);
+ Ok((features, self))
+ }
+
+ pub fn into_inner(self) -> S {
+ self.reader.into_inner().unsplit(self.writer.into_inner())
+ }
+
+ pub async fn send_stanza(&mut self, stanza: &Stanza) -> Result<()> {
+ self.writer.write(stanza).await?;
+ Ok(())
+ }
+}
+
+impl JabberStream<Unencrypted> {
+ #[instrument]
+ pub async fn starttls(mut self, domain: impl AsRef<str> + std::fmt::Debug) -> Result<Tls> {
+ self.writer
+ .write_full(&StartTls { required: false })
+ .await?;
+ let proceed: Proceed = self.reader.read().await?;
+ debug!("got proceed: {:?}", proceed);
+ let connector = TlsConnector::new().unwrap();
+ let stream = self.reader.into_inner().unsplit(self.writer.into_inner());
+ if let Ok(tls_stream) = tokio_native_tls::TlsConnector::from(connector)
+ .connect(domain.as_ref(), stream)
+ .await
+ {
+ return Ok(tls_stream);
+ } else {
+ return Err(Error::Connection);
+ }
+ }
+}
+
+impl std::fmt::Debug for JabberStream<Tls> {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ f.debug_struct("Jabber")
+ .field("connection", &"tls")
+ .finish()
+ }
+}
+
+impl std::fmt::Debug for JabberStream<Unencrypted> {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ f.debug_struct("Jabber")
+ .field("connection", &"unencrypted")
+ .finish()
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use std::time::Duration;
+
+ use super::*;
+ use crate::connection::Connection;
+ use test_log::test;
+ use tokio::time::sleep;
+
+ #[test(tokio::test)]
+ async fn start_stream() {
+ // let connection = Connection::connect("blos.sm", None, None).await.unwrap();
+ // match connection {
+ // Connection::Encrypted(mut c) => c.start_stream().await.unwrap(),
+ // Connection::Unencrypted(mut c) => c.start_stream().await.unwrap(),
+ // }
+ }
+
+ #[test(tokio::test)]
+ async fn sasl() {
+ // let mut jabber = Connection::connect_user("test@blos.sm", "slayed".to_string())
+ // .await
+ // .unwrap()
+ // .ensure_tls()
+ // .await
+ // .unwrap();
+ // let text = str::from_utf8(jabber.reader.buffer.data()).unwrap();
+ // println!("data: {}", text);
+ // jabber.start_stream().await.unwrap();
+
+ // let text = str::from_utf8(jabber.reader.buffer.data()).unwrap();
+ // println!("data: {}", text);
+ // jabber.reader.read_buf().await.unwrap();
+ // let text = str::from_utf8(jabber.reader.buffer.data()).unwrap();
+ // println!("data: {}", text);
+
+ // let features = jabber.get_features().await.unwrap();
+ // let (sasl_config, feature) = (
+ // jabber.auth.clone().unwrap(),
+ // features
+ // .features
+ // .iter()
+ // .find(|feature| matches!(feature, Feature::Sasl(_)))
+ // .unwrap(),
+ // );
+ // match feature {
+ // Feature::StartTls(_start_tls) => todo!(),
+ // Feature::Sasl(mechanisms) => {
+ // jabber.sasl(mechanisms.clone(), sasl_config).await.unwrap();
+ // }
+ // Feature::Bind => todo!(),
+ // Feature::Unknown => todo!(),
+ // }
+ }
+
+ #[tokio::test]
+ async fn negotiate() {
+ // let _jabber = Connection::connect_user("test@blos.sm", "slayed".to_string())
+ // .await
+ // .unwrap()
+ // .ensure_tls()
+ // .await
+ // .unwrap()
+ // .negotiate()
+ // .await
+ // .unwrap();
+ // sleep(Duration::from_secs(5)).await
+ }
+}