diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/connection.rs | 35 | ||||
| -rw-r--r-- | src/error.rs | 2 | ||||
| -rw-r--r-- | src/jabber.rs | 220 | ||||
| -rw-r--r-- | src/stanza/sasl.rs | 169 | ||||
| -rw-r--r-- | src/stanza/stream.rs | 42 | 
5 files changed, 421 insertions, 47 deletions
| diff --git a/src/connection.rs b/src/connection.rs index 65e9383..9e485d3 100644 --- a/src/connection.rs +++ b/src/connection.rs @@ -1,16 +1,18 @@  use std::net::{IpAddr, SocketAddr};  use std::str;  use std::str::FromStr; +use std::sync::Arc; +use rsasl::config::SASLConfig;  use tokio::net::TcpStream;  use tokio_native_tls::native_tls::TlsConnector;  // TODO: use rustls  use tokio_native_tls::TlsStream;  use tracing::{debug, info, instrument, trace}; -use crate::Error;  use crate::Jabber;  use crate::Result; +use crate::{Error, JID};  pub type Tls = TlsStream<TcpStream>;  pub type Unencrypted = TcpStream; @@ -37,15 +39,20 @@ impl Connection {          }      } -    // pub async fn connect_user<J: TryInto<JID>>(jid: J, password: String) -> Result<Self> { -    //     let server = jid.domainpart.clone(); -    //     let auth = SASLConfig::with_credentials(None, jid.localpart.clone().unwrap(), password)?; -    //     println!("auth: {:?}", auth); -    //     Self::connect(&server, jid.try_into()?, Some(auth)).await -    // } +    pub async fn connect_user(jid: impl AsRef<str>, password: String) -> Result<Self> { +        let jid: JID = JID::from_str(jid.as_ref())?; +        let server = jid.domainpart.clone(); +        let auth = SASLConfig::with_credentials(None, jid.localpart.clone().unwrap(), password)?; +        println!("auth: {:?}", auth); +        Self::connect(&server, Some(jid), Some(auth)).await +    }      #[instrument] -    pub async fn connect(server: &str) -> Result<Self> { +    pub async fn connect( +        server: &str, +        jid: Option<JID>, +        auth: Option<Arc<SASLConfig>>, +    ) -> Result<Self> {          info!("connecting to {}", server);          let sockets = Self::get_sockets(&server).await;          debug!("discovered sockets: {:?}", sockets); @@ -58,8 +65,8 @@ impl Connection {                          return Ok(Self::Encrypted(Jabber::new(                              readhalf,                              writehalf, -                            None, -                            None, +                            jid, +                            auth,                              server.to_owned(),                          )));                      } @@ -71,8 +78,8 @@ impl Connection {                          return Ok(Self::Unencrypted(Jabber::new(                              readhalf,                              writehalf, -                            None, -                            None, +                            jid, +                            auth,                              server.to_owned(),                          )));                      } @@ -181,12 +188,12 @@ mod tests {      #[test(tokio::test)]      async fn connect() { -        Connection::connect("blos.sm").await.unwrap(); +        Connection::connect("blos.sm", None, None).await.unwrap();      }      #[test(tokio::test)]      async fn test_tls() { -        Connection::connect("blos.sm") +        Connection::connect("blos.sm", None, None)              .await              .unwrap()              .ensure_tls() diff --git a/src/error.rs b/src/error.rs index c7c867c..8ee9077 100644 --- a/src/error.rs +++ b/src/error.rs @@ -19,6 +19,8 @@ pub enum Error {      IDMismatch,      BindError,      ParseError, +    Negotiation, +    TlsRequired,      UnexpectedEnd,      UnexpectedElement,      UnexpectedText, diff --git a/src/jabber.rs b/src/jabber.rs index a56c65c..9e7f9d8 100644 --- a/src/jabber.rs +++ b/src/jabber.rs @@ -1,26 +1,26 @@  use std::str;  use std::sync::Arc; +use async_recursion::async_recursion;  use peanuts::element::{FromElement, IntoElement};  use peanuts::{Reader, Writer}; -use rsasl::prelude::SASLConfig; +use rsasl::prelude::{Mechname, SASLClient, SASLConfig};  use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufReader, ReadHalf, WriteHalf}; +use tokio::time::timeout;  use tokio_native_tls::native_tls::TlsConnector;  use tracing::{debug, info, instrument, trace};  use trust_dns_resolver::proto::rr::domain::IntoLabel;  use crate::connection::{Tls, Unencrypted};  use crate::error::Error; +use crate::stanza::sasl::{Auth, Challenge, Mechanisms, Response, ServerResponse};  use crate::stanza::starttls::{Proceed, StartTls}; -use crate::stanza::stream::{Features, Stream}; +use crate::stanza::stream::{Feature, Features, Stream};  use crate::stanza::XML_VERSION; -use crate::Result;  use crate::JID; +use crate::{Connection, Result}; -pub struct Jabber<S> -where -    S: AsyncRead + AsyncWrite + Unpin, -{ +pub struct Jabber<S> {      reader: Reader<ReadHalf<S>>,      writer: Writer<WriteHalf<S>>,      jid: Option<JID>, @@ -56,7 +56,89 @@ where      S: AsyncRead + AsyncWrite + Unpin + Send,      Jabber<S>: std::fmt::Debug,  { -    // pub async fn negotiate(self) -> Result<Jabber<S>> {} +    pub async fn sasl( +        &mut self, +        mechanisms: Mechanisms, +        sasl_config: Arc<SASLConfig>, +    ) -> Result<()> { +        let sasl = SASLClient::new(sasl_config); +        let mut offered_mechs: Vec<&Mechname> = Vec::new(); +        for mechanism in &mechanisms.mechanisms { +            offered_mechs.push(Mechname::parse(mechanism.as_bytes())?) +        } +        debug!("{:?}", offered_mechs); +        let mut session = sasl.start_suggested(&offered_mechs)?; +        let selected_mechanism = session.get_mechname().as_str().to_owned(); +        debug!("selected mech: {:?}", selected_mechanism); +        let mut data: Option<Vec<u8>> = None; + +        if !session.are_we_first() { +            // if not first mention the mechanism then get challenge data +            // mention mechanism +            let auth = Auth { +                mechanism: selected_mechanism, +                sasl_data: "=".to_string(), +            }; +            self.writer.write_full(&auth).await?; +            // get challenge data +            let challenge: Challenge = self.reader.read().await?; +            debug!("challenge: {:?}", challenge); +            data = Some((*challenge).as_bytes().to_vec()); +            debug!("we didn't go first"); +        } else { +            // if first, mention mechanism and send data +            let mut sasl_data = Vec::new(); +            session.step64(None, &mut sasl_data).unwrap(); +            let auth = Auth { +                mechanism: selected_mechanism, +                sasl_data: str::from_utf8(&sasl_data)?.to_string(), +            }; +            debug!("{:?}", auth); +            self.writer.write_full(&auth).await?; + +            let server_response: ServerResponse = self.reader.read().await?; +            debug!("server_response: {:#?}", server_response); +            match server_response { +                ServerResponse::Challenge(challenge) => { +                    data = Some((*challenge).as_bytes().to_vec()) +                } +                ServerResponse::Success(success) => data = Some((*success).as_bytes().to_vec()), +            } +            debug!("we went first"); +        } + +        // stepping the authentication exchange to completion +        if data != None { +            debug!("data: {:?}", data); +            let mut sasl_data = Vec::new(); +            while { +                // decide if need to send more data over +                let state = session +                    .step64(data.as_deref(), &mut sasl_data) +                    .expect("step errored!"); +                state.is_running() +            } { +                // While we aren't finished, receive more data from the other party +                let response = Response::new(str::from_utf8(&sasl_data)?.to_string()); +                debug!("response: {:?}", response); +                self.writer.write_full(&response).await?; + +                let server_response: ServerResponse = self.reader.read().await?; +                debug!("server_response: {:#?}", server_response); +                match server_response { +                    ServerResponse::Challenge(challenge) => { +                        data = Some((*challenge).as_bytes().to_vec()) +                    } +                    ServerResponse::Success(success) => data = Some((*success).as_bytes().to_vec()), +                } +            } +        } +        Ok(()) +    } + +    pub async fn bind(&mut self) -> Result<()> { +        todo!() +    }      #[instrument]      pub async fn start_stream(&mut self) -> Result<()> { @@ -76,6 +158,8 @@ where          let decl = self.reader.read_prolog().await?;          // receive stream element and validate +        let text = str::from_utf8(self.reader.buffer.data()).unwrap(); +        debug!("data: {}", text);          let stream: Stream = self.reader.read_start().await?;          debug!("got stream: {:?}", stream);          if let Some(from) = stream.from { @@ -98,6 +182,87 @@ where  }  impl Jabber<Unencrypted> { +    pub async fn negotiate<S: AsyncRead + AsyncWrite + Unpin>(mut self) -> Result<Jabber<Tls>> { +        self.start_stream().await?; +        // TODO: timeout +        let features = self.get_features().await?.features; +        if let Some(Feature::StartTls(_)) = features +            .iter() +            .find(|feature| matches!(feature, Feature::StartTls(_s))) +        { +            let jabber = self.starttls().await?; +            let jabber = jabber.negotiate().await?; +            return Ok(jabber); +        } else { +            // TODO: better error +            return Err(Error::TlsRequired); +        } +    } + +    #[async_recursion] +    pub async fn negotiate_tls_optional(mut self) -> Result<Connection> { +        self.start_stream().await?; +        // TODO: timeout +        let features = self.get_features().await?.features; +        if let Some(Feature::StartTls(_)) = features +            .iter() +            .find(|feature| matches!(feature, Feature::StartTls(_s))) +        { +            let jabber = self.starttls().await?; +            let jabber = jabber.negotiate().await?; +            return Ok(Connection::Encrypted(jabber)); +        } else if let (Some(sasl_config), Some(Feature::Sasl(mechanisms))) = ( +            self.auth.clone(), +            features +                .iter() +                .find(|feature| matches!(feature, Feature::Sasl(_))), +        ) { +            self.sasl(mechanisms.clone(), sasl_config).await?; +            let jabber = self.negotiate_tls_optional().await?; +            Ok(jabber) +        } else if let Some(Feature::Bind) = features +            .iter() +            .find(|feature| matches!(feature, Feature::Bind)) +        { +            self.bind().await?; +            Ok(Connection::Unencrypted(self)) +        } else { +            // TODO: better error +            return Err(Error::Negotiation); +        } +    } +} + +impl Jabber<Tls> { +    #[async_recursion] +    pub async fn negotiate(mut self) -> Result<Jabber<Tls>> { +        self.start_stream().await?; +        let features = self.get_features().await?.features; + +        if let (Some(sasl_config), Some(Feature::Sasl(mechanisms))) = ( +            self.auth.clone(), +            features +                .iter() +                .find(|feature| matches!(feature, Feature::Sasl(_))), +        ) { +            // TODO: avoid clone +            self.sasl(mechanisms.clone(), sasl_config).await?; +            let jabber = self.negotiate().await?; +            Ok(jabber) +        } else if let Some(Feature::Bind) = features +            .iter() +            .find(|feature| matches!(feature, Feature::Bind)) +        { +            self.bind().await?; +            Ok(self) +        } else { +            // TODO: better error +            return Err(Error::Negotiation); +        } +    } +} + +impl Jabber<Unencrypted> {      pub async fn starttls(mut self) -> Result<Jabber<Tls>> {          self.writer              .write_full(&StartTls { required: false }) @@ -155,10 +320,47 @@ mod tests {      #[test(tokio::test)]      async fn start_stream() { -        let connection = Connection::connect("blos.sm").await.unwrap(); +        let connection = Connection::connect("blos.sm", None, None).await.unwrap();          match connection {              Connection::Encrypted(mut c) => c.start_stream().await.unwrap(),              Connection::Unencrypted(mut c) => c.start_stream().await.unwrap(),          }      } + +    #[test(tokio::test)] +    async fn sasl() { +        let mut jabber = Connection::connect_user("test@blos.sm", "slayed".to_string()) +            .await +            .unwrap() +            .ensure_tls() +            .await +            .unwrap(); +        let text = str::from_utf8(jabber.reader.buffer.data()).unwrap(); +        println!("data: {}", text); +        jabber.start_stream().await.unwrap(); + +        let text = str::from_utf8(jabber.reader.buffer.data()).unwrap(); +        println!("data: {}", text); +        jabber.reader.read_buf().await.unwrap(); +        let text = str::from_utf8(jabber.reader.buffer.data()).unwrap(); +        println!("data: {}", text); + +        let features = jabber.get_features().await.unwrap(); +        let (sasl_config, feature) = ( +            jabber.auth.clone().unwrap(), +            features +                .features +                .iter() +                .find(|feature| matches!(feature, Feature::Sasl(_))) +                .unwrap(), +        ); +        match feature { +            Feature::StartTls(_start_tls) => todo!(), +            Feature::Sasl(mechanisms) => { +                jabber.sasl(mechanisms.clone(), sasl_config).await.unwrap(); +            } +            Feature::Bind => todo!(), +            Feature::Unknown => todo!(), +        } +    }  } diff --git a/src/stanza/sasl.rs b/src/stanza/sasl.rs index 8b13789..6ac4fc9 100644 --- a/src/stanza/sasl.rs +++ b/src/stanza/sasl.rs @@ -1 +1,170 @@ +use std::ops::Deref; +use peanuts::{ +    element::{FromElement, IntoElement}, +    DeserializeError, Element, +}; +use tracing::debug; + +pub const XMLNS: &str = "urn:ietf:params:xml:ns:xmpp-sasl"; + +#[derive(Debug, Clone)] +pub struct Mechanisms { +    pub mechanisms: Vec<String>, +} + +impl FromElement for Mechanisms { +    fn from_element(mut element: Element) -> peanuts::element::DeserializeResult<Self> { +        element.check_name("mechanisms")?; +        element.check_namespace(XMLNS)?; +        debug!("getting mechanisms"); +        let mechanisms: Vec<Mechanism> = element.pop_children()?; +        debug!("gottting mechanisms"); +        let mechanisms = mechanisms +            .into_iter() +            .map(|Mechanism(mechanism)| mechanism) +            .collect(); +        debug!("gottting mechanisms"); + +        Ok(Mechanisms { mechanisms }) +    } +} + +impl IntoElement for Mechanisms { +    fn builder(&self) -> peanuts::element::ElementBuilder { +        Element::builder("mechanisms", Some(XMLNS)).push_children( +            self.mechanisms +                .iter() +                .map(|mechanism| Mechanism(mechanism.to_string())) +                .collect(), +        ) +    } +} + +pub struct Mechanism(String); + +impl FromElement for Mechanism { +    fn from_element(mut element: peanuts::Element) -> peanuts::element::DeserializeResult<Self> { +        element.check_name("mechanism")?; +        element.check_namespace(XMLNS)?; + +        let mechanism = element.pop_value()?; + +        Ok(Mechanism(mechanism)) +    } +} + +impl IntoElement for Mechanism { +    fn builder(&self) -> peanuts::element::ElementBuilder { +        Element::builder("mechanism", Some(XMLNS)).push_text(self.0.clone()) +    } +} + +impl Deref for Mechanism { +    type Target = str; + +    fn deref(&self) -> &Self::Target { +        &self.0 +    } +} + +#[derive(Debug)] +pub struct Auth { +    pub mechanism: String, +    pub sasl_data: String, +} + +impl IntoElement for Auth { +    fn builder(&self) -> peanuts::element::ElementBuilder { +        Element::builder("auth", Some(XMLNS)) +            .push_attribute("mechanism", self.mechanism.clone()) +            .push_text(self.sasl_data.clone()) +    } +} + +#[derive(Debug)] +pub struct Challenge(String); + +impl Deref for Challenge { +    type Target = str; + +    fn deref(&self) -> &Self::Target { +        &self.0 +    } +} + +impl FromElement for Challenge { +    fn from_element(mut element: Element) -> peanuts::element::DeserializeResult<Self> { +        element.check_name("challenge")?; +        element.check_namespace(XMLNS)?; + +        let sasl_data = element.value()?; + +        Ok(Challenge(sasl_data)) +    } +} + +#[derive(Debug)] +pub struct Success(String); + +impl Deref for Success { +    type Target = str; + +    fn deref(&self) -> &Self::Target { +        &self.0 +    } +} + +impl FromElement for Success { +    fn from_element(mut element: Element) -> peanuts::element::DeserializeResult<Self> { +        element.check_name("success")?; +        element.check_namespace(XMLNS)?; + +        let sasl_data = element.value()?; + +        Ok(Success(sasl_data)) +    } +} + +#[derive(Debug)] +pub enum ServerResponse { +    Challenge(Challenge), +    Success(Success), +} + +impl FromElement for ServerResponse { +    fn from_element(element: Element) -> peanuts::element::DeserializeResult<Self> { +        match element.identify() { +            (Some(XMLNS), "challenge") => { +                Ok(ServerResponse::Challenge(Challenge::from_element(element)?)) +            } +            (Some(XMLNS), "success") => { +                Ok(ServerResponse::Success(Success::from_element(element)?)) +            } +            _ => Err(DeserializeError::UnexpectedElement(element)), +        } +    } +} + +#[derive(Debug)] +pub struct Response(String); + +impl Response { +    pub fn new(response: String) -> Self { +        Self(response) +    } +} + +impl Deref for Response { +    type Target = str; + +    fn deref(&self) -> &Self::Target { +        &self.0 +    } +} + +impl IntoElement for Response { +    fn builder(&self) -> peanuts::element::ElementBuilder { +        Element::builder("reponse", Some(XMLNS)).push_text(self.0.clone()) +    } +} diff --git a/src/stanza/stream.rs b/src/stanza/stream.rs index 40f6ba0..fecace5 100644 --- a/src/stanza/stream.rs +++ b/src/stanza/stream.rs @@ -3,9 +3,11 @@ use std::collections::{HashMap, HashSet};  use peanuts::element::{Content, ElementBuilder, FromElement, IntoElement, NamespaceDeclaration};  use peanuts::XML_NS;  use peanuts::{element::Name, Element}; +use tracing::debug;  use crate::{Error, JID}; +use super::sasl::{self, Mechanisms};  use super::starttls::{self, StartTls};  pub const XMLNS: &str = "http://etherx.jabber.org/streams"; @@ -92,32 +94,12 @@ impl<'s> Stream {  #[derive(Debug)]  pub struct Features { -    features: Vec<Feature>, +    pub features: Vec<Feature>,  }  impl IntoElement for Features {      fn builder(&self) -> ElementBuilder {          Element::builder("features", Some(XMLNS)).push_children(self.features.clone()) -        // let mut content = Vec::new(); -        // for feature in &self.features { -        //     match feature { -        //         Feature::StartTls(start_tls) => { -        //             content.push(Content::Element(start_tls.into_element())) -        //         } -        //         Feature::Sasl => {} -        //         Feature::Bind => {} -        //         Feature::Unknown => {} -        //     } -        // } -        // Element { -        //     name: Name { -        //         namespace: Some(XMLNS.to_string()), -        //         local_name: "features".to_string(), -        //     }, -        //     namespace_declaration_overrides: HashSet::new(), -        //     attributes: HashMap::new(), -        //     content, -        // }      }  } @@ -128,7 +110,9 @@ impl FromElement for Features {          element.check_namespace(XMLNS)?;          element.check_name("features")?; +        debug!("got features stanza");          let features = element.children()?; +        debug!("got features period");          Ok(Features { features })      } @@ -137,7 +121,7 @@ impl FromElement for Features {  #[derive(Debug, Clone)]  pub enum Feature {      StartTls(StartTls), -    Sasl, +    Sasl(Mechanisms),      Bind,      Unknown,  } @@ -146,7 +130,7 @@ impl IntoElement for Feature {      fn builder(&self) -> ElementBuilder {          match self {              Feature::StartTls(start_tls) => start_tls.builder(), -            Feature::Sasl => todo!(), +            Feature::Sasl(mechanisms) => mechanisms.builder(),              Feature::Bind => todo!(),              Feature::Unknown => todo!(),          } @@ -155,11 +139,21 @@ impl IntoElement for Feature {  impl FromElement for Feature {      fn from_element(element: Element) -> peanuts::element::DeserializeResult<Self> { +        let identity = element.identify(); +        debug!("identity: {:?}", identity);          match element.identify() {              (Some(starttls::XMLNS), "starttls") => { +                debug!("identified starttls");                  Ok(Feature::StartTls(StartTls::from_element(element)?))              } -            _ => Ok(Feature::Unknown), +            (Some(sasl::XMLNS), "mechanisms") => { +                debug!("identified mechanisms"); +                Ok(Feature::Sasl(Mechanisms::from_element(element)?)) +            } +            _ => { +                debug!("identified unknown feature"); +                Ok(Feature::Unknown) +            }          }      }  } | 
