diff options
Diffstat (limited to 'jabber/src/jabber_stream.rs')
-rw-r--r-- | jabber/src/jabber_stream.rs | 482 |
1 files changed, 0 insertions, 482 deletions
diff --git a/jabber/src/jabber_stream.rs b/jabber/src/jabber_stream.rs deleted file mode 100644 index 302350d..0000000 --- a/jabber/src/jabber_stream.rs +++ /dev/null @@ -1,482 +0,0 @@ -use std::str::{self, FromStr}; -use std::sync::Arc; - -use jid::JID; -use peanuts::element::IntoElement; -use peanuts::{Reader, Writer}; -use rsasl::prelude::{Mechname, SASLClient, SASLConfig}; -use stanza::bind::{Bind, BindType, FullJidType, ResourceType}; -use stanza::client::iq::{Iq, IqType, Query}; -use stanza::client::Stanza; -use stanza::sasl::{Auth, Challenge, Mechanisms, Response, ServerResponse}; -use stanza::starttls::{Proceed, StartTls}; -use stanza::stream::{Features, Stream}; -use stanza::XML_VERSION; -use tokio::io::{AsyncRead, AsyncWrite, ReadHalf, WriteHalf}; -use tokio_native_tls::native_tls::TlsConnector; -use tracing::{debug, instrument}; - -use crate::connection::{Tls, Unencrypted}; -use crate::error::Error; -use crate::Result; - -pub mod bound_stream; - -// open stream (streams started) -pub struct JabberStream<S> { - reader: JabberReader<S>, - writer: JabberWriter<S>, -} - -impl<S> JabberStream<S> { - fn split(self) -> (JabberReader<S>, JabberWriter<S>) { - let reader = self.reader; - let writer = self.writer; - (reader, writer) - } -} - -pub struct JabberReader<S>(Reader<ReadHalf<S>>); - -impl<S> JabberReader<S> { - // TODO: consider taking a readhalf and creating peanuts::Reader here, only one inner - fn new(reader: Reader<ReadHalf<S>>) -> Self { - Self(reader) - } - - fn unsplit(self, writer: JabberWriter<S>) -> JabberStream<S> { - JabberStream { - reader: self, - writer, - } - } - - fn into_inner(self) -> Reader<ReadHalf<S>> { - self.0 - } -} - -impl<S> JabberReader<S> -where - S: AsyncRead + Unpin, -{ - pub async fn try_close(&mut self) -> Result<()> { - self.read_end_tag().await?; - Ok(()) - } -} - -impl<S> std::ops::Deref for JabberReader<S> { - type Target = Reader<ReadHalf<S>>; - - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -impl<S> std::ops::DerefMut for JabberReader<S> { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.0 - } -} - -pub struct JabberWriter<S>(Writer<WriteHalf<S>>); - -impl<S> JabberWriter<S> { - fn new(writer: Writer<WriteHalf<S>>) -> Self { - Self(writer) - } - - fn unsplit(self, reader: JabberReader<S>) -> JabberStream<S> { - JabberStream { - reader, - writer: self, - } - } - - fn into_inner(self) -> Writer<WriteHalf<S>> { - self.0 - } -} - -impl<S> JabberWriter<S> -where - S: AsyncWrite + Unpin + Send, -{ - pub async fn try_close(&mut self) -> Result<()> { - self.write_end().await?; - Ok(()) - } -} - -impl<S> std::ops::Deref for JabberWriter<S> { - type Target = Writer<WriteHalf<S>>; - - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -impl<S> std::ops::DerefMut for JabberWriter<S> { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.0 - } -} - -impl<S> JabberStream<S> -where - S: AsyncRead + AsyncWrite + Unpin + Send + std::fmt::Debug, - JabberStream<S>: std::fmt::Debug, -{ - #[instrument] - pub async fn sasl(mut self, mechanisms: Mechanisms, sasl_config: Arc<SASLConfig>) -> Result<S> { - let sasl = SASLClient::new(sasl_config); - let mut offered_mechs: Vec<&Mechname> = Vec::new(); - for mechanism in &mechanisms.mechanisms { - offered_mechs - .push(Mechname::parse(mechanism.as_bytes()).map_err(|e| Error::SASL(e.into()))?) - } - debug!("{:?}", offered_mechs); - let mut session = sasl - .start_suggested(&offered_mechs) - .map_err(|e| Error::SASL(e.into()))?; - let selected_mechanism = session.get_mechname().as_str().to_owned(); - debug!("selected mech: {:?}", selected_mechanism); - let mut data: Option<Vec<u8>>; - - if !session.are_we_first() { - // if not first mention the mechanism then get challenge data - // mention mechanism - let auth = Auth { - mechanism: selected_mechanism, - sasl_data: "=".to_string(), - }; - self.writer.write_full(&auth).await?; - // get challenge data - let challenge: Challenge = self.reader.read().await?; - debug!("challenge: {:?}", challenge); - data = Some((*challenge).as_bytes().to_vec()); - debug!("we didn't go first"); - } else { - // if first, mention mechanism and send data - let mut sasl_data = Vec::new(); - session.step64(None, &mut sasl_data).unwrap(); - let auth = Auth { - mechanism: selected_mechanism, - sasl_data: str::from_utf8(&sasl_data)?.to_string(), - }; - debug!("{:?}", auth); - self.writer.write_full(&auth).await?; - - let server_response: ServerResponse = self.reader.read().await?; - debug!("server_response: {:#?}", server_response); - match server_response { - ServerResponse::Challenge(challenge) => { - data = Some((*challenge).as_bytes().to_vec()) - } - ServerResponse::Success(success) => { - data = success.clone().map(|success| success.as_bytes().to_vec()) - } - ServerResponse::Failure(failure) => return Err(Error::SASL(failure.into())), - } - debug!("we went first"); - } - - // stepping the authentication exchange to completion - if data != None { - debug!("data: {:?}", data); - let mut sasl_data = Vec::new(); - while { - // decide if need to send more data over - let state = session - .step64(data.as_deref(), &mut sasl_data) - .expect("step errored!"); - state.is_running() - } { - // While we aren't finished, receive more data from the other party - let response = Response::new(str::from_utf8(&sasl_data)?.to_string()); - debug!("response: {:?}", response); - self.writer.write_full(&response).await?; - debug!("response written"); - - let server_response: ServerResponse = self.reader.read().await?; - debug!("server_response: {:#?}", server_response); - match server_response { - ServerResponse::Challenge(challenge) => { - data = Some((*challenge).as_bytes().to_vec()) - } - ServerResponse::Success(success) => { - data = success.clone().map(|success| success.as_bytes().to_vec()) - } - ServerResponse::Failure(failure) => return Err(Error::SASL(failure.into())), - } - } - } - let writer = self.writer.into_inner().into_inner(); - let reader = self.reader.into_inner().into_inner(); - let stream = reader.unsplit(writer); - Ok(stream) - } - - #[instrument] - pub async fn bind(mut self, jid: &mut JID) -> Result<Self> { - let iq_id = nanoid::nanoid!(); - if let Some(resource) = &jid.resourcepart { - let iq = Iq { - from: None, - id: iq_id.clone(), - to: None, - r#type: IqType::Set, - lang: None, - query: Some(Query::Bind(Bind { - r#type: Some(BindType::Resource(ResourceType(resource.to_string()))), - })), - errors: Vec::new(), - }; - self.writer.write_full(&iq).await?; - let result: Iq = self.reader.read().await?; - match result { - Iq { - from: _, - id, - to: _, - r#type: IqType::Result, - lang: _, - query: - Some(Query::Bind(Bind { - r#type: Some(BindType::Jid(FullJidType(new_jid))), - })), - errors: _, - } if id == iq_id => { - *jid = new_jid; - return Ok(self); - } - Iq { - from: _, - id, - to: _, - r#type: IqType::Error, - lang: _, - query: None, - errors, - } if id == iq_id => { - return Err(Error::ClientError( - errors.first().ok_or(Error::MissingError)?.clone(), - )) - } - _ => return Err(Error::UnexpectedElement(result.into_element())), - } - } else { - let iq = Iq { - from: None, - id: iq_id.clone(), - to: None, - r#type: IqType::Set, - lang: None, - query: Some(Query::Bind(Bind { r#type: None })), - errors: Vec::new(), - }; - self.writer.write_full(&iq).await?; - let result: Iq = self.reader.read().await?; - match result { - Iq { - from: _, - id, - to: _, - r#type: IqType::Result, - lang: _, - query: - Some(Query::Bind(Bind { - r#type: Some(BindType::Jid(FullJidType(new_jid))), - })), - errors: _, - } if id == iq_id => { - *jid = new_jid; - return Ok(self); - } - Iq { - from: _, - id, - to: _, - r#type: IqType::Error, - lang: _, - query: None, - errors, - } if id == iq_id => { - return Err(Error::ClientError( - errors.first().ok_or(Error::MissingError)?.clone(), - )) - } - _ => return Err(Error::UnexpectedElement(result.into_element())), - } - } - } - - #[instrument] - pub async fn start_stream(connection: S, server: &mut String) -> Result<Self> { - // client to server - let (reader, writer) = tokio::io::split(connection); - let mut reader = JabberReader::new(Reader::new(reader)); - let mut writer = JabberWriter::new(Writer::new(writer)); - - // declaration - writer.write_declaration(XML_VERSION).await?; - - // opening stream element - let stream = Stream::new_client( - None, - JID::from_str(server.as_ref())?, - None, - "en".to_string(), - ); - writer.write_start(&stream).await?; - - // server to client - - // may or may not send a declaration - let _decl = reader.read_prolog().await?; - - // receive stream element and validate - let stream: Stream = reader.read_start().await?; - debug!("got stream: {:?}", stream); - if let Some(from) = stream.from { - *server = from.to_string(); - } - - Ok(Self { reader, writer }) - } - - #[instrument] - pub async fn get_features(mut self) -> Result<(Features, Self)> { - debug!("getting features"); - let features: Features = self.reader.read().await?; - debug!("got features: {:?}", features); - Ok((features, self)) - } - - pub fn into_inner(self) -> S { - self.reader - .into_inner() - .into_inner() - .unsplit(self.writer.into_inner().into_inner()) - } - - pub async fn send_stanza(&mut self, stanza: &Stanza) -> Result<()> { - self.writer.write(stanza).await?; - Ok(()) - } -} - -impl JabberStream<Unencrypted> { - #[instrument] - pub async fn starttls(mut self, domain: impl AsRef<str> + std::fmt::Debug) -> Result<Tls> { - self.writer - .write_full(&StartTls { required: false }) - .await?; - let proceed: Proceed = self.reader.read().await?; - debug!("got proceed: {:?}", proceed); - let connector = TlsConnector::new().unwrap(); - let stream = self - .reader - .into_inner() - .into_inner() - .unsplit(self.writer.into_inner().into_inner()); - if let Ok(tls_stream) = tokio_native_tls::TlsConnector::from(connector) - .connect(domain.as_ref(), stream) - .await - { - return Ok(tls_stream); - } else { - return Err(Error::Connection); - } - } -} - -impl std::fmt::Debug for JabberStream<Tls> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("Jabber") - .field("connection", &"tls") - .finish() - } -} - -impl std::fmt::Debug for JabberStream<Unencrypted> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("Jabber") - .field("connection", &"unencrypted") - .finish() - } -} - -#[cfg(test)] -mod tests { - use test_log::test; - - #[test(tokio::test)] - async fn start_stream() { - // let connection = Connection::connect("blos.sm", None, None).await.unwrap(); - // match connection { - // Connection::Encrypted(mut c) => c.start_stream().await.unwrap(), - // Connection::Unencrypted(mut c) => c.start_stream().await.unwrap(), - // } - } - - #[test(tokio::test)] - async fn sasl() { - // let mut jabber = Connection::connect_user("test@blos.sm", "slayed".to_string()) - // .await - // .unwrap() - // .ensure_tls() - // .await - // .unwrap(); - // let text = str::from_utf8(jabber.reader.buffer.data()).unwrap(); - // println!("data: {}", text); - // jabber.start_stream().await.unwrap(); - - // let text = str::from_utf8(jabber.reader.buffer.data()).unwrap(); - // println!("data: {}", text); - // jabber.reader.read_buf().await.unwrap(); - // let text = str::from_utf8(jabber.reader.buffer.data()).unwrap(); - // println!("data: {}", text); - - // let features = jabber.get_features().await.unwrap(); - // let (sasl_config, feature) = ( - // jabber.auth.clone().unwrap(), - // features - // .features - // .iter() - // .find(|feature| matches!(feature, Feature::Sasl(_))) - // .unwrap(), - // ); - // match feature { - // Feature::StartTls(_start_tls) => todo!(), - // Feature::Sasl(mechanisms) => { - // jabber.sasl(mechanisms.clone(), sasl_config).await.unwrap(); - // } - // Feature::Bind => todo!(), - // Feature::Unknown => todo!(), - // } - } - - #[tokio::test] - async fn sink() { - // let mut client = JabberClient::new("test@blos.sm", "slayed").unwrap(); - // client.connect().await.unwrap(); - // let stream = client.inner().unwrap(); - // let sink = sink::unfold(stream, |mut stream, stanza: Stanza| async move { - // stream.writer.write(&stanza).await?; - // Ok::<JabberStream<Tls>, Error>(stream) - // }); - // todo!() - // let _jabber = Connection::connect_user("test@blos.sm", "slayed".to_string()) - // .await - // .unwrap() - // .ensure_tls() - // .await - // .unwrap() - // .negotiate() - // .await - // .unwrap(); - // sleep(Duration::from_secs(5)).await - } -} |