use std::str::{self, FromStr}; use std::sync::Arc; use jid::JID; use peanuts::element::IntoElement; use peanuts::{Reader, Writer}; use rsasl::prelude::{Mechname, SASLClient, SASLConfig}; use stanza::bind::{Bind, BindType, FullJidType, ResourceType}; use stanza::client::iq::{Iq, IqType, Query}; use stanza::client::Stanza; use stanza::sasl::{Auth, Challenge, Mechanisms, Response, ServerResponse}; use stanza::starttls::{Proceed, StartTls}; use stanza::stream::{Features, Stream}; use stanza::XML_VERSION; use tokio::io::{AsyncRead, AsyncWrite, ReadHalf, WriteHalf}; use tokio_native_tls::native_tls::TlsConnector; use tracing::{debug, instrument}; use crate::connection::{Tls, Unencrypted}; use crate::error::Error; use crate::Result; pub mod bound_stream; // open stream (streams started) pub struct JabberStream { reader: JabberReader, writer: JabberWriter, } impl JabberStream { fn split(self) -> (JabberReader, JabberWriter) { let reader = self.reader; let writer = self.writer; (reader, writer) } } pub struct JabberReader(Reader>); impl JabberReader { // TODO: consider taking a readhalf and creating peanuts::Reader here, only one inner fn new(reader: Reader>) -> Self { Self(reader) } fn unsplit(self, writer: JabberWriter) -> JabberStream { JabberStream { reader: self, writer, } } fn into_inner(self) -> Reader> { self.0 } } impl JabberReader where S: AsyncRead + Unpin, { pub async fn try_close(&mut self) -> Result<()> { self.read_end_tag().await?; Ok(()) } } impl std::ops::Deref for JabberReader { type Target = Reader>; fn deref(&self) -> &Self::Target { &self.0 } } impl std::ops::DerefMut for JabberReader { fn deref_mut(&mut self) -> &mut Self::Target { &mut self.0 } } pub struct JabberWriter(Writer>); impl JabberWriter { fn new(writer: Writer>) -> Self { Self(writer) } fn unsplit(self, reader: JabberReader) -> JabberStream { JabberStream { reader, writer: self, } } fn into_inner(self) -> Writer> { self.0 } } impl JabberWriter where S: AsyncWrite + Unpin + Send, { pub async fn try_close(&mut self) -> Result<()> { self.write_end().await?; Ok(()) } } impl std::ops::Deref for JabberWriter { type Target = Writer>; fn deref(&self) -> &Self::Target { &self.0 } } impl std::ops::DerefMut for JabberWriter { fn deref_mut(&mut self) -> &mut Self::Target { &mut self.0 } } impl JabberStream where S: AsyncRead + AsyncWrite + Unpin + Send + std::fmt::Debug, JabberStream: std::fmt::Debug, { #[instrument] pub async fn sasl(mut self, mechanisms: Mechanisms, sasl_config: Arc) -> Result { let sasl = SASLClient::new(sasl_config); let mut offered_mechs: Vec<&Mechname> = Vec::new(); for mechanism in &mechanisms.mechanisms { offered_mechs .push(Mechname::parse(mechanism.as_bytes()).map_err(|e| Error::SASL(e.into()))?) } debug!("{:?}", offered_mechs); let mut session = sasl .start_suggested(&offered_mechs) .map_err(|e| Error::SASL(e.into()))?; let selected_mechanism = session.get_mechname().as_str().to_owned(); debug!("selected mech: {:?}", selected_mechanism); let mut data: Option>; if !session.are_we_first() { // if not first mention the mechanism then get challenge data // mention mechanism let auth = Auth { mechanism: selected_mechanism, sasl_data: "=".to_string(), }; self.writer.write_full(&auth).await?; // get challenge data let challenge: Challenge = self.reader.read().await?; debug!("challenge: {:?}", challenge); data = Some((*challenge).as_bytes().to_vec()); debug!("we didn't go first"); } else { // if first, mention mechanism and send data let mut sasl_data = Vec::new(); session.step64(None, &mut sasl_data).unwrap(); let auth = Auth { mechanism: selected_mechanism, sasl_data: str::from_utf8(&sasl_data)?.to_string(), }; debug!("{:?}", auth); self.writer.write_full(&auth).await?; let server_response: ServerResponse = self.reader.read().await?; debug!("server_response: {:#?}", server_response); match server_response { ServerResponse::Challenge(challenge) => { data = Some((*challenge).as_bytes().to_vec()) } ServerResponse::Success(success) => { data = success.clone().map(|success| success.as_bytes().to_vec()) } ServerResponse::Failure(failure) => return Err(Error::SASL(failure.into())), } debug!("we went first"); } // stepping the authentication exchange to completion if data != None { debug!("data: {:?}", data); let mut sasl_data = Vec::new(); while { // decide if need to send more data over let state = session .step64(data.as_deref(), &mut sasl_data) .expect("step errored!"); state.is_running() } { // While we aren't finished, receive more data from the other party let response = Response::new(str::from_utf8(&sasl_data)?.to_string()); debug!("response: {:?}", response); self.writer.write_full(&response).await?; debug!("response written"); let server_response: ServerResponse = self.reader.read().await?; debug!("server_response: {:#?}", server_response); match server_response { ServerResponse::Challenge(challenge) => { data = Some((*challenge).as_bytes().to_vec()) } ServerResponse::Success(success) => { data = success.clone().map(|success| success.as_bytes().to_vec()) } ServerResponse::Failure(failure) => return Err(Error::SASL(failure.into())), } } } let writer = self.writer.into_inner().into_inner(); let reader = self.reader.into_inner().into_inner(); let stream = reader.unsplit(writer); Ok(stream) } #[instrument] pub async fn bind(mut self, jid: &mut JID) -> Result { let iq_id = nanoid::nanoid!(); if let Some(resource) = &jid.resourcepart { let iq = Iq { from: None, id: iq_id.clone(), to: None, r#type: IqType::Set, lang: None, query: Some(Query::Bind(Bind { r#type: Some(BindType::Resource(ResourceType(resource.to_string()))), })), errors: Vec::new(), }; self.writer.write_full(&iq).await?; let result: Iq = self.reader.read().await?; match result { Iq { from: _, id, to: _, r#type: IqType::Result, lang: _, query: Some(Query::Bind(Bind { r#type: Some(BindType::Jid(FullJidType(new_jid))), })), errors: _, } if id == iq_id => { *jid = new_jid; return Ok(self); } Iq { from: _, id, to: _, r#type: IqType::Error, lang: _, query: None, errors, } if id == iq_id => { return Err(Error::ClientError( errors.first().ok_or(Error::MissingError)?.clone(), )) } _ => return Err(Error::UnexpectedElement(result.into_element())), } } else { let iq = Iq { from: None, id: iq_id.clone(), to: None, r#type: IqType::Set, lang: None, query: Some(Query::Bind(Bind { r#type: None })), errors: Vec::new(), }; self.writer.write_full(&iq).await?; let result: Iq = self.reader.read().await?; match result { Iq { from: _, id, to: _, r#type: IqType::Result, lang: _, query: Some(Query::Bind(Bind { r#type: Some(BindType::Jid(FullJidType(new_jid))), })), errors: _, } if id == iq_id => { *jid = new_jid; return Ok(self); } Iq { from: _, id, to: _, r#type: IqType::Error, lang: _, query: None, errors, } if id == iq_id => { return Err(Error::ClientError( errors.first().ok_or(Error::MissingError)?.clone(), )) } _ => return Err(Error::UnexpectedElement(result.into_element())), } } } #[instrument] pub async fn start_stream(connection: S, server: &mut String) -> Result { // client to server let (reader, writer) = tokio::io::split(connection); let mut reader = JabberReader::new(Reader::new(reader)); let mut writer = JabberWriter::new(Writer::new(writer)); // declaration writer.write_declaration(XML_VERSION).await?; // opening stream element let stream = Stream::new_client( None, JID::from_str(server.as_ref())?, None, "en".to_string(), ); writer.write_start(&stream).await?; // server to client // may or may not send a declaration let _decl = reader.read_prolog().await?; // receive stream element and validate let stream: Stream = reader.read_start().await?; debug!("got stream: {:?}", stream); if let Some(from) = stream.from { *server = from.to_string(); } Ok(Self { reader, writer }) } #[instrument] pub async fn get_features(mut self) -> Result<(Features, Self)> { debug!("getting features"); let features: Features = self.reader.read().await?; debug!("got features: {:?}", features); Ok((features, self)) } pub fn into_inner(self) -> S { self.reader .into_inner() .into_inner() .unsplit(self.writer.into_inner().into_inner()) } pub async fn send_stanza(&mut self, stanza: &Stanza) -> Result<()> { self.writer.write(stanza).await?; Ok(()) } } impl JabberStream { #[instrument] pub async fn starttls(mut self, domain: impl AsRef + std::fmt::Debug) -> Result { self.writer .write_full(&StartTls { required: false }) .await?; let proceed: Proceed = self.reader.read().await?; debug!("got proceed: {:?}", proceed); let connector = TlsConnector::new().unwrap(); let stream = self .reader .into_inner() .into_inner() .unsplit(self.writer.into_inner().into_inner()); if let Ok(tls_stream) = tokio_native_tls::TlsConnector::from(connector) .connect(domain.as_ref(), stream) .await { return Ok(tls_stream); } else { return Err(Error::Connection); } } } impl std::fmt::Debug for JabberStream { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Jabber") .field("connection", &"tls") .finish() } } impl std::fmt::Debug for JabberStream { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Jabber") .field("connection", &"unencrypted") .finish() } } #[cfg(test)] mod tests { use test_log::test; #[test(tokio::test)] async fn start_stream() { // let connection = Connection::connect("blos.sm", None, None).await.unwrap(); // match connection { // Connection::Encrypted(mut c) => c.start_stream().await.unwrap(), // Connection::Unencrypted(mut c) => c.start_stream().await.unwrap(), // } } #[test(tokio::test)] async fn sasl() { // let mut jabber = Connection::connect_user("test@blos.sm", "slayed".to_string()) // .await // .unwrap() // .ensure_tls() // .await // .unwrap(); // let text = str::from_utf8(jabber.reader.buffer.data()).unwrap(); // println!("data: {}", text); // jabber.start_stream().await.unwrap(); // let text = str::from_utf8(jabber.reader.buffer.data()).unwrap(); // println!("data: {}", text); // jabber.reader.read_buf().await.unwrap(); // let text = str::from_utf8(jabber.reader.buffer.data()).unwrap(); // println!("data: {}", text); // let features = jabber.get_features().await.unwrap(); // let (sasl_config, feature) = ( // jabber.auth.clone().unwrap(), // features // .features // .iter() // .find(|feature| matches!(feature, Feature::Sasl(_))) // .unwrap(), // ); // match feature { // Feature::StartTls(_start_tls) => todo!(), // Feature::Sasl(mechanisms) => { // jabber.sasl(mechanisms.clone(), sasl_config).await.unwrap(); // } // Feature::Bind => todo!(), // Feature::Unknown => todo!(), // } } #[tokio::test] async fn sink() { // let mut client = JabberClient::new("test@blos.sm", "slayed").unwrap(); // client.connect().await.unwrap(); // let stream = client.inner().unwrap(); // let sink = sink::unfold(stream, |mut stream, stanza: Stanza| async move { // stream.writer.write(&stanza).await?; // Ok::, Error>(stream) // }); // todo!() // let _jabber = Connection::connect_user("test@blos.sm", "slayed".to_string()) // .await // .unwrap() // .ensure_tls() // .await // .unwrap() // .negotiate() // .await // .unwrap(); // sleep(Duration::from_secs(5)).await } }