diff options
author | 2024-12-03 23:57:04 +0000 | |
---|---|---|
committer | 2024-12-03 23:57:04 +0000 | |
commit | e0373c0520e7fae792bc907e9c500ab846d34e31 (patch) | |
tree | fcec4d201c85ac951500f6678824024be87a1b5e /src/jabber.rs | |
parent | 7c2577d196c059ab6e2d5b0efe5e036bdad75be7 (diff) | |
download | luz-e0373c0520e7fae792bc907e9c500ab846d34e31.tar.gz luz-e0373c0520e7fae792bc907e9c500ab846d34e31.tar.bz2 luz-e0373c0520e7fae792bc907e9c500ab846d34e31.zip |
WIP: connecting fsm
Diffstat (limited to 'src/jabber.rs')
-rw-r--r-- | src/jabber.rs | 377 |
1 files changed, 172 insertions, 205 deletions
diff --git a/src/jabber.rs b/src/jabber.rs index d5cfe13..cf90f73 100644 --- a/src/jabber.rs +++ b/src/jabber.rs @@ -1,4 +1,4 @@ -use std::str; +use std::str::{self, FromStr}; use std::sync::Arc; use async_recursion::async_recursion; @@ -20,47 +20,18 @@ use crate::stanza::XML_VERSION; use crate::JID; use crate::{Connection, Result}; -pub struct Jabber<S> { +// open stream (streams started) +pub struct JabberStream<S> { reader: Reader<ReadHalf<S>>, writer: Writer<WriteHalf<S>>, - jid: Option<JID>, - auth: Option<Arc<SASLConfig>>, - server: String, } -impl<S> Jabber<S> +impl<S> JabberStream<S> where - S: AsyncRead + AsyncWrite + Unpin, + S: AsyncRead + AsyncWrite + Unpin + Send + std::fmt::Debug, + JabberStream<S>: std::fmt::Debug, { - pub fn new( - reader: ReadHalf<S>, - writer: WriteHalf<S>, - jid: Option<JID>, - auth: Option<Arc<SASLConfig>>, - server: String, - ) -> Self { - let reader = Reader::new(reader); - let writer = Writer::new(writer); - Self { - reader, - writer, - jid, - auth, - server, - } - } -} - -impl<S> Jabber<S> -where - S: AsyncRead + AsyncWrite + Unpin + Send, - Jabber<S>: std::fmt::Debug, -{ - pub async fn sasl( - &mut self, - mechanisms: Mechanisms, - sasl_config: Arc<SASLConfig>, - ) -> Result<()> { + pub async fn sasl(mut self, mechanisms: Mechanisms, sasl_config: Arc<SASLConfig>) -> Result<S> { let sasl = SASLClient::new(sasl_config); let mut offered_mechs: Vec<&Mechname> = Vec::new(); for mechanism in &mechanisms.mechanisms { @@ -143,12 +114,15 @@ where } } } - Ok(()) + let writer = self.writer.into_inner(); + let reader = self.reader.into_inner(); + let stream = reader.unsplit(writer); + Ok(stream) } - pub async fn bind(&mut self) -> Result<()> { + pub async fn bind(mut self, jid: &mut JID) -> Result<Self> { let iq_id = nanoid::nanoid!(); - if let Some(resource) = self.jid.clone().unwrap().resourcepart { + if let Some(resource) = &jid.resourcepart { let iq = Iq { from: None, id: iq_id.clone(), @@ -156,7 +130,7 @@ where r#type: IqType::Set, lang: None, query: Some(Query::Bind(Bind { - r#type: Some(BindType::Resource(ResourceType(resource))), + r#type: Some(BindType::Resource(ResourceType(resource.to_string()))), })), errors: Vec::new(), }; @@ -171,12 +145,12 @@ where lang: _, query: Some(Query::Bind(Bind { - r#type: Some(BindType::Jid(FullJidType(jid))), + r#type: Some(BindType::Jid(FullJidType(new_jid))), })), errors: _, } if id == iq_id => { - self.jid = Some(jid); - return Ok(()); + *jid = new_jid; + return Ok(self); } Iq { from: _, @@ -214,12 +188,12 @@ where lang: _, query: Some(Query::Bind(Bind { - r#type: Some(BindType::Jid(FullJidType(jid))), + r#type: Some(BindType::Jid(FullJidType(new_jid))), })), errors: _, } if id == iq_id => { - self.jid = Some(jid); - return Ok(()); + *jid = new_jid; + return Ok(self); } Iq { from: _, @@ -240,39 +214,44 @@ where } #[instrument] - pub async fn start_stream(&mut self) -> Result<()> { + pub async fn start_stream(connection: S, server: &mut String) -> Result<Self> { // client to server + let (reader, writer) = tokio::io::split(connection); + let mut reader = Reader::new(reader); + let mut writer = Writer::new(writer); // declaration - self.writer.write_declaration(XML_VERSION).await?; + writer.write_declaration(XML_VERSION).await?; // opening stream element - let server = self.server.clone().try_into()?; - let stream = Stream::new_client(None, server, None, "en".to_string()); - self.writer.write_start(&stream).await?; + let stream = Stream::new_client( + None, + JID::from_str(server.as_ref())?, + None, + "en".to_string(), + ); + writer.write_start(&stream).await?; // server to client // may or may not send a declaration - let _decl = self.reader.read_prolog().await?; + let _decl = reader.read_prolog().await?; // receive stream element and validate - let text = str::from_utf8(self.reader.buffer.data()).unwrap(); - debug!("data: {}", text); - let stream: Stream = self.reader.read_start().await?; + let stream: Stream = reader.read_start().await?; debug!("got stream: {:?}", stream); if let Some(from) = stream.from { - self.server = from.to_string() + *server = from.to_string(); } - Ok(()) + Ok(Self { reader, writer }) } - pub async fn get_features(&mut self) -> Result<Features> { + pub async fn get_features(mut self) -> Result<(Features, Self)> { debug!("getting features"); let features: Features = self.reader.read().await?; debug!("got features: {:?}", features); - Ok(features) + Ok((features, self)) } pub fn into_inner(self) -> S { @@ -280,89 +259,89 @@ where } } -impl Jabber<Unencrypted> { - pub async fn negotiate<S: AsyncRead + AsyncWrite + Unpin>(mut self) -> Result<Jabber<Tls>> { - self.start_stream().await?; - // TODO: timeout - let features = self.get_features().await?.features; - if let Some(Feature::StartTls(_)) = features - .iter() - .find(|feature| matches!(feature, Feature::StartTls(_s))) - { - let jabber = self.starttls().await?; - let jabber = jabber.negotiate().await?; - return Ok(jabber); - } else { - // TODO: better error - return Err(Error::TlsRequired); - } - } - - #[async_recursion] - pub async fn negotiate_tls_optional(mut self) -> Result<Connection> { - self.start_stream().await?; - // TODO: timeout - let features = self.get_features().await?.features; - if let Some(Feature::StartTls(_)) = features - .iter() - .find(|feature| matches!(feature, Feature::StartTls(_s))) - { - let jabber = self.starttls().await?; - let jabber = jabber.negotiate().await?; - return Ok(Connection::Encrypted(jabber)); - } else if let (Some(sasl_config), Some(Feature::Sasl(mechanisms))) = ( - self.auth.clone(), - features - .iter() - .find(|feature| matches!(feature, Feature::Sasl(_))), - ) { - self.sasl(mechanisms.clone(), sasl_config).await?; - let jabber = self.negotiate_tls_optional().await?; - Ok(jabber) - } else if let Some(Feature::Bind) = features - .iter() - .find(|feature| matches!(feature, Feature::Bind)) - { - self.bind().await?; - Ok(Connection::Unencrypted(self)) - } else { - // TODO: better error - return Err(Error::Negotiation); - } - } +impl JabberStream<Unencrypted> { + // pub async fn negotiate<S: AsyncRead + AsyncWrite + Unpin>( + // mut self, + // features: Features, + // ) -> Result<Feature> { + // // TODO: timeout + // if let Some(Feature::StartTls(_)) = features + // .features + // .iter() + // .find(|feature| matches!(feature, Feature::StartTls(_s))) + // { + // return Ok(self); + // } else { + // // TODO: better error + // return Err(Error::TlsRequired); + // } + // } + + // #[async_recursion] + // pub async fn negotiate_tls_optional(mut self) -> Result<Connection> { + // self.start_stream().await?; + // // TODO: timeout + // let features = self.get_features().await?.features; + // if let Some(Feature::StartTls(_)) = features + // .iter() + // .find(|feature| matches!(feature, Feature::StartTls(_s))) + // { + // let jabber = self.starttls().await?; + // let jabber = jabber.negotiate().await?; + // return Ok(Connection::Encrypted(jabber)); + // } else if let (Some(sasl_config), Some(Feature::Sasl(mechanisms))) = ( + // self.auth.clone(), + // features + // .iter() + // .find(|feature| matches!(feature, Feature::Sasl(_))), + // ) { + // self.sasl(mechanisms.clone(), sasl_config).await?; + // let jabber = self.negotiate_tls_optional().await?; + // Ok(jabber) + // } else if let Some(Feature::Bind) = features + // .iter() + // .find(|feature| matches!(feature, Feature::Bind)) + // { + // self.bind().await?; + // Ok(Connection::Unencrypted(self)) + // } else { + // // TODO: better error + // return Err(Error::Negotiation); + // } + // } } -impl Jabber<Tls> { - #[async_recursion] - pub async fn negotiate(mut self) -> Result<Jabber<Tls>> { - self.start_stream().await?; - let features = self.get_features().await?.features; - - if let (Some(sasl_config), Some(Feature::Sasl(mechanisms))) = ( - self.auth.clone(), - features - .iter() - .find(|feature| matches!(feature, Feature::Sasl(_))), - ) { - // TODO: avoid clone - self.sasl(mechanisms.clone(), sasl_config).await?; - let jabber = self.negotiate().await?; - Ok(jabber) - } else if let Some(Feature::Bind) = features - .iter() - .find(|feature| matches!(feature, Feature::Bind)) - { - self.bind().await?; - Ok(self) - } else { - // TODO: better error - return Err(Error::Negotiation); - } - } +impl JabberStream<Tls> { + // #[async_recursion] + // pub async fn negotiate(mut self) -> Result<JabberStream<Tls>> { + // self.start_stream().await?; + // let features = self.get_features().await?.features; + + // if let (Some(sasl_config), Some(Feature::Sasl(mechanisms))) = ( + // self.auth.clone(), + // features + // .iter() + // .find(|feature| matches!(feature, Feature::Sasl(_))), + // ) { + // // TODO: avoid clone + // self.sasl(mechanisms.clone(), sasl_config).await?; + // let jabber = self.negotiate().await?; + // Ok(jabber) + // } else if let Some(Feature::Bind) = features + // .iter() + // .find(|feature| matches!(feature, Feature::Bind)) + // { + // self.bind().await?; + // Ok(self) + // } else { + // // TODO: better error + // return Err(Error::Negotiation); + // } + // } } -impl Jabber<Unencrypted> { - pub async fn starttls(mut self) -> Result<Jabber<Tls>> { +impl JabberStream<Unencrypted> { + pub async fn starttls(mut self, domain: impl AsRef<str>) -> Result<Tls> { self.writer .write_full(&StartTls { required: false }) .await?; @@ -370,43 +349,31 @@ impl Jabber<Unencrypted> { debug!("got proceed: {:?}", proceed); let connector = TlsConnector::new().unwrap(); let stream = self.reader.into_inner().unsplit(self.writer.into_inner()); - if let Ok(tlsstream) = tokio_native_tls::TlsConnector::from(connector) - .connect(&self.server, stream) + if let Ok(tls_stream) = tokio_native_tls::TlsConnector::from(connector) + .connect(domain.as_ref(), stream) .await { - let (read, write) = tokio::io::split(tlsstream); - let client = Jabber::new( - read, - write, - self.jid.to_owned(), - self.auth.to_owned(), - self.server.to_owned(), - ); - return Ok(client); + // let (read, write) = tokio::io::split(tlsstream); + // let client = JabberStream::new(read, write); + return Ok(tls_stream); } else { return Err(Error::Connection); } } } -impl std::fmt::Debug for Jabber<Tls> { +impl std::fmt::Debug for JabberStream<Tls> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Jabber") .field("connection", &"tls") - .field("jid", &self.jid) - .field("auth", &self.auth) - .field("server", &self.server) .finish() } } -impl std::fmt::Debug for Jabber<Unencrypted> { +impl std::fmt::Debug for JabberStream<Unencrypted> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Jabber") .field("connection", &"unencrypted") - .field("jid", &self.jid) - .field("auth", &self.auth) - .field("server", &self.server) .finish() } } @@ -422,61 +389,61 @@ mod tests { #[test(tokio::test)] async fn start_stream() { - let connection = Connection::connect("blos.sm", None, None).await.unwrap(); - match connection { - Connection::Encrypted(mut c) => c.start_stream().await.unwrap(), - Connection::Unencrypted(mut c) => c.start_stream().await.unwrap(), - } + // let connection = Connection::connect("blos.sm", None, None).await.unwrap(); + // match connection { + // Connection::Encrypted(mut c) => c.start_stream().await.unwrap(), + // Connection::Unencrypted(mut c) => c.start_stream().await.unwrap(), + // } } #[test(tokio::test)] async fn sasl() { - let mut jabber = Connection::connect_user("test@blos.sm", "slayed".to_string()) - .await - .unwrap() - .ensure_tls() - .await - .unwrap(); - let text = str::from_utf8(jabber.reader.buffer.data()).unwrap(); - println!("data: {}", text); - jabber.start_stream().await.unwrap(); - - let text = str::from_utf8(jabber.reader.buffer.data()).unwrap(); - println!("data: {}", text); - jabber.reader.read_buf().await.unwrap(); - let text = str::from_utf8(jabber.reader.buffer.data()).unwrap(); - println!("data: {}", text); - - let features = jabber.get_features().await.unwrap(); - let (sasl_config, feature) = ( - jabber.auth.clone().unwrap(), - features - .features - .iter() - .find(|feature| matches!(feature, Feature::Sasl(_))) - .unwrap(), - ); - match feature { - Feature::StartTls(_start_tls) => todo!(), - Feature::Sasl(mechanisms) => { - jabber.sasl(mechanisms.clone(), sasl_config).await.unwrap(); - } - Feature::Bind => todo!(), - Feature::Unknown => todo!(), - } + // let mut jabber = Connection::connect_user("test@blos.sm", "slayed".to_string()) + // .await + // .unwrap() + // .ensure_tls() + // .await + // .unwrap(); + // let text = str::from_utf8(jabber.reader.buffer.data()).unwrap(); + // println!("data: {}", text); + // jabber.start_stream().await.unwrap(); + + // let text = str::from_utf8(jabber.reader.buffer.data()).unwrap(); + // println!("data: {}", text); + // jabber.reader.read_buf().await.unwrap(); + // let text = str::from_utf8(jabber.reader.buffer.data()).unwrap(); + // println!("data: {}", text); + + // let features = jabber.get_features().await.unwrap(); + // let (sasl_config, feature) = ( + // jabber.auth.clone().unwrap(), + // features + // .features + // .iter() + // .find(|feature| matches!(feature, Feature::Sasl(_))) + // .unwrap(), + // ); + // match feature { + // Feature::StartTls(_start_tls) => todo!(), + // Feature::Sasl(mechanisms) => { + // jabber.sasl(mechanisms.clone(), sasl_config).await.unwrap(); + // } + // Feature::Bind => todo!(), + // Feature::Unknown => todo!(), + // } } #[tokio::test] async fn negotiate() { - let _jabber = Connection::connect_user("test@blos.sm", "slayed".to_string()) - .await - .unwrap() - .ensure_tls() - .await - .unwrap() - .negotiate() - .await - .unwrap(); - sleep(Duration::from_secs(5)).await + // let _jabber = Connection::connect_user("test@blos.sm", "slayed".to_string()) + // .await + // .unwrap() + // .ensure_tls() + // .await + // .unwrap() + // .negotiate() + // .await + // .unwrap(); + // sleep(Duration::from_secs(5)).await } } |