use std::str; use std::sync::Arc; use async_recursion::async_recursion; use peanuts::element::{FromElement, IntoElement}; use peanuts::{Reader, Writer}; use rsasl::prelude::{Mechname, SASLClient, SASLConfig}; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufReader, BufWriter, ReadHalf, WriteHalf}; use tokio::time::timeout; use tokio_native_tls::native_tls::TlsConnector; use tracing::{debug, info, instrument, trace}; use trust_dns_resolver::proto::rr::domain::IntoLabel; use crate::connection::{Tls, Unencrypted}; use crate::error::Error; use crate::stanza::sasl::{Auth, Challenge, Mechanisms, Response, ServerResponse}; use crate::stanza::starttls::{Proceed, StartTls}; use crate::stanza::stream::{Feature, Features, Stream}; use crate::stanza::XML_VERSION; use crate::JID; use crate::{Connection, Result}; pub struct Jabber { reader: Reader>, writer: Writer>, jid: Option, auth: Option>, server: String, } impl Jabber where S: AsyncRead + AsyncWrite + Unpin, { pub fn new( reader: ReadHalf, writer: WriteHalf, jid: Option, auth: Option>, server: String, ) -> Self { let reader = Reader::new(reader); let writer = Writer::new(writer); Self { reader, writer, jid, auth, server, } } } impl Jabber where S: AsyncRead + AsyncWrite + Unpin + Send, Jabber: std::fmt::Debug, { pub async fn sasl( &mut self, mechanisms: Mechanisms, sasl_config: Arc, ) -> Result<()> { let sasl = SASLClient::new(sasl_config); let mut offered_mechs: Vec<&Mechname> = Vec::new(); for mechanism in &mechanisms.mechanisms { offered_mechs.push(Mechname::parse(mechanism.as_bytes())?) } debug!("{:?}", offered_mechs); let mut session = sasl.start_suggested(&offered_mechs)?; let selected_mechanism = session.get_mechname().as_str().to_owned(); debug!("selected mech: {:?}", selected_mechanism); let mut data: Option> = None; if !session.are_we_first() { // if not first mention the mechanism then get challenge data // mention mechanism let auth = Auth { mechanism: selected_mechanism, sasl_data: "=".to_string(), }; self.writer.write_full(&auth).await?; // get challenge data let challenge: Challenge = self.reader.read().await?; debug!("challenge: {:?}", challenge); data = Some((*challenge).as_bytes().to_vec()); debug!("we didn't go first"); } else { // if first, mention mechanism and send data let mut sasl_data = Vec::new(); session.step64(None, &mut sasl_data).unwrap(); let auth = Auth { mechanism: selected_mechanism, sasl_data: str::from_utf8(&sasl_data)?.to_string(), }; debug!("{:?}", auth); self.writer.write_full(&auth).await?; let server_response: ServerResponse = self.reader.read().await?; debug!("server_response: {:#?}", server_response); match server_response { ServerResponse::Challenge(challenge) => { data = Some((*challenge).as_bytes().to_vec()) } ServerResponse::Success(success) => { data = success.clone().map(|success| success.as_bytes().to_vec()) } ServerResponse::Failure(failure) => return Err(Error::Authentication(failure)), } debug!("we went first"); } // stepping the authentication exchange to completion if data != None { debug!("data: {:?}", data); let mut sasl_data = Vec::new(); while { // decide if need to send more data over let state = session .step64(data.as_deref(), &mut sasl_data) .expect("step errored!"); state.is_running() } { // While we aren't finished, receive more data from the other party let response = Response::new(str::from_utf8(&sasl_data)?.to_string()); debug!("response: {:?}", response); let stdout = tokio::io::stdout(); let mut writer = Writer::new(stdout); writer.write_full(&response).await?; self.writer.write_full(&response).await?; debug!("response written"); let server_response: ServerResponse = self.reader.read().await?; debug!("server_response: {:#?}", server_response); match server_response { ServerResponse::Challenge(challenge) => { data = Some((*challenge).as_bytes().to_vec()) } ServerResponse::Success(success) => { data = success.clone().map(|success| success.as_bytes().to_vec()) } ServerResponse::Failure(failure) => return Err(Error::Authentication(failure)), } } } Ok(()) } pub async fn bind(&mut self) -> Result<()> { todo!() } #[instrument] pub async fn start_stream(&mut self) -> Result<()> { // client to server // declaration self.writer.write_declaration(XML_VERSION).await?; // opening stream element let server = self.server.clone().try_into()?; let stream = Stream::new_client(None, server, None, "en".to_string()); self.writer.write_start(&stream).await?; // server to client // may or may not send a declaration let decl = self.reader.read_prolog().await?; // receive stream element and validate let text = str::from_utf8(self.reader.buffer.data()).unwrap(); debug!("data: {}", text); let stream: Stream = self.reader.read_start().await?; debug!("got stream: {:?}", stream); if let Some(from) = stream.from { self.server = from.to_string() } Ok(()) } pub async fn get_features(&mut self) -> Result { debug!("getting features"); let features: Features = self.reader.read().await?; debug!("got features: {:?}", features); Ok(features) } pub fn into_inner(self) -> S { self.reader.into_inner().unsplit(self.writer.into_inner()) } } impl Jabber { pub async fn negotiate(mut self) -> Result> { self.start_stream().await?; // TODO: timeout let features = self.get_features().await?.features; if let Some(Feature::StartTls(_)) = features .iter() .find(|feature| matches!(feature, Feature::StartTls(_s))) { let jabber = self.starttls().await?; let jabber = jabber.negotiate().await?; return Ok(jabber); } else { // TODO: better error return Err(Error::TlsRequired); } } #[async_recursion] pub async fn negotiate_tls_optional(mut self) -> Result { self.start_stream().await?; // TODO: timeout let features = self.get_features().await?.features; if let Some(Feature::StartTls(_)) = features .iter() .find(|feature| matches!(feature, Feature::StartTls(_s))) { let jabber = self.starttls().await?; let jabber = jabber.negotiate().await?; return Ok(Connection::Encrypted(jabber)); } else if let (Some(sasl_config), Some(Feature::Sasl(mechanisms))) = ( self.auth.clone(), features .iter() .find(|feature| matches!(feature, Feature::Sasl(_))), ) { self.sasl(mechanisms.clone(), sasl_config).await?; let jabber = self.negotiate_tls_optional().await?; Ok(jabber) } else if let Some(Feature::Bind) = features .iter() .find(|feature| matches!(feature, Feature::Bind)) { self.bind().await?; Ok(Connection::Unencrypted(self)) } else { // TODO: better error return Err(Error::Negotiation); } } } impl Jabber { #[async_recursion] pub async fn negotiate(mut self) -> Result> { self.start_stream().await?; let features = self.get_features().await?.features; if let (Some(sasl_config), Some(Feature::Sasl(mechanisms))) = ( self.auth.clone(), features .iter() .find(|feature| matches!(feature, Feature::Sasl(_))), ) { // TODO: avoid clone self.sasl(mechanisms.clone(), sasl_config).await?; let jabber = self.negotiate().await?; Ok(jabber) } else if let Some(Feature::Bind) = features .iter() .find(|feature| matches!(feature, Feature::Bind)) { self.bind().await?; Ok(self) } else { // TODO: better error return Err(Error::Negotiation); } } } impl Jabber { pub async fn starttls(mut self) -> Result> { self.writer .write_full(&StartTls { required: false }) .await?; let proceed: Proceed = self.reader.read().await?; debug!("got proceed: {:?}", proceed); let connector = TlsConnector::new().unwrap(); let stream = self.reader.into_inner().unsplit(self.writer.into_inner()); if let Ok(tlsstream) = tokio_native_tls::TlsConnector::from(connector) .connect(&self.server, stream) .await { let (read, write) = tokio::io::split(tlsstream); let client = Jabber::new( read, write, self.jid.to_owned(), self.auth.to_owned(), self.server.to_owned(), ); return Ok(client); } else { return Err(Error::Connection); } } } impl std::fmt::Debug for Jabber { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Jabber") .field("connection", &"tls") .field("jid", &self.jid) .field("auth", &self.auth) .field("server", &self.server) .finish() } } impl std::fmt::Debug for Jabber { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Jabber") .field("connection", &"unencrypted") .field("jid", &self.jid) .field("auth", &self.auth) .field("server", &self.server) .finish() } } #[cfg(test)] mod tests { use super::*; use crate::connection::Connection; use test_log::test; #[test(tokio::test)] async fn start_stream() { let connection = Connection::connect("blos.sm", None, None).await.unwrap(); match connection { Connection::Encrypted(mut c) => c.start_stream().await.unwrap(), Connection::Unencrypted(mut c) => c.start_stream().await.unwrap(), } } #[test(tokio::test)] async fn sasl() { let mut jabber = Connection::connect_user("test@blos.sm", "slayed".to_string()) .await .unwrap() .ensure_tls() .await .unwrap(); let text = str::from_utf8(jabber.reader.buffer.data()).unwrap(); println!("data: {}", text); jabber.start_stream().await.unwrap(); let text = str::from_utf8(jabber.reader.buffer.data()).unwrap(); println!("data: {}", text); jabber.reader.read_buf().await.unwrap(); let text = str::from_utf8(jabber.reader.buffer.data()).unwrap(); println!("data: {}", text); let features = jabber.get_features().await.unwrap(); let (sasl_config, feature) = ( jabber.auth.clone().unwrap(), features .features .iter() .find(|feature| matches!(feature, Feature::Sasl(_))) .unwrap(), ); match feature { Feature::StartTls(_start_tls) => todo!(), Feature::Sasl(mechanisms) => { jabber.sasl(mechanisms.clone(), sasl_config).await.unwrap(); } Feature::Bind => todo!(), Feature::Unknown => todo!(), } } }