diff options
Diffstat (limited to '')
| -rw-r--r-- | Cargo.toml | 4 | ||||
| -rw-r--r-- | src/client/encrypted.rs | 189 | ||||
| -rw-r--r-- | src/client/unencrypted.rs | 12 | ||||
| -rw-r--r-- | src/error.rs | 39 | ||||
| -rw-r--r-- | src/jabber.rs | 28 | ||||
| -rw-r--r-- | src/jid/mod.rs | 35 | ||||
| -rw-r--r-- | src/lib.rs | 26 | ||||
| -rw-r--r-- | src/stanza/mod.rs | 1 | ||||
| -rw-r--r-- | src/stanza/sasl.rs | 32 | ||||
| -rw-r--r-- | src/stanza/stream.rs | 7 | 
10 files changed, 329 insertions, 44 deletions
| @@ -8,7 +8,9 @@ edition = "2021"  [dependencies]  async-trait = "0.1.68" -quick-xml = { version = "0.29.0", features = ["async-tokio", "serialize"] } +quick-xml = { git = "https://github.com/tafia/quick-xml.git", features = ["async-tokio", "serialize"] } +# TODO: remove unneeded features +rsasl = { version = "2", default_features = false, features = ["provider_base64", "plain", "config_builder"] }  serde = { version = "1.0.164", features = ["derive"] }  tokio = { version = "1.28", features = ["full"] }  tokio-native-tls = "0.3.1" diff --git a/src/client/encrypted.rs b/src/client/encrypted.rs index 08439b2..a4bf0d1 100644 --- a/src/client/encrypted.rs +++ b/src/client/encrypted.rs @@ -1,24 +1,35 @@ +use std::str; +  use quick_xml::{ +    de::Deserializer,      events::{BytesDecl, BytesStart, Event}, +    name::QName, +    se::Serializer,      Reader, Writer,  }; -use tokio::io::{BufReader, ReadHalf, WriteHalf}; +use rsasl::prelude::{Mechname, SASLClient}; +use serde::{Deserialize, Serialize}; +use tokio::io::{AsyncWriteExt, BufReader, ReadHalf, WriteHalf};  use tokio::net::TcpStream;  use tokio_native_tls::TlsStream; +use crate::stanza::{ +    sasl::{Auth, Challenge, Mechanisms}, +    stream::{StreamFeature, StreamFeatures}, +};  use crate::Jabber;  use crate::Result;  pub struct JabberClient<'j> {      reader: Reader<BufReader<ReadHalf<TlsStream<TcpStream>>>>, -    writer: Writer<WriteHalf<TlsStream<TcpStream>>>, +    writer: WriteHalf<TlsStream<TcpStream>>,      jabber: &'j mut Jabber<'j>,  }  impl<'j> JabberClient<'j> {      pub fn new(          reader: Reader<BufReader<ReadHalf<TlsStream<TcpStream>>>>, -        writer: Writer<WriteHalf<TlsStream<TcpStream>>>, +        writer: WriteHalf<TlsStream<TcpStream>>,          jabber: &'j mut Jabber<'j>,      ) -> Self {          Self { @@ -37,13 +48,9 @@ impl<'j> JabberClient<'j> {          stream_element.push_attribute(("xml:lang", "en"));          stream_element.push_attribute(("xmlns", "jabber:client"));          stream_element.push_attribute(("xmlns:stream", "http://etherx.jabber.org/streams")); -        self.writer -            .write_event_async(Event::Decl(declaration)) -            .await; -        self.writer -            .write_event_async(Event::Start(stream_element)) -            .await -            .unwrap(); +        let mut writer = Writer::new(&mut self.writer); +        writer.write_event_async(Event::Decl(declaration)).await; +        writer.write_event_async(Event::Start(stream_element)).await;          let mut buf = Vec::new();          loop {              match self.reader.read_event_into_async(&mut buf).await.unwrap() { @@ -56,4 +63,166 @@ impl<'j> JabberClient<'j> {          }          Ok(())      } + +    pub async fn get_node<'a>(&mut self) -> Result<String> { +        let mut buf = Vec::new(); +        let mut txt = Vec::new(); +        let mut qname_set = false; +        let mut qname: Option<Vec<u8>> = None; +        loop { +            match self.reader.read_event_into_async(&mut buf).await? { +                Event::Start(e) => { +                    if !qname_set { +                        qname = Some(e.name().into_inner().to_owned()); +                        qname_set = true; +                    } +                    txt.push(b'<'); +                    txt = txt +                        .into_iter() +                        .chain(buf.to_owned()) +                        .chain(vec![b'>']) +                        .collect(); +                } +                Event::End(e) => { +                    let mut end = false; +                    if e.name() == QName(qname.as_deref().unwrap()) { +                        end = true; +                    } +                    txt.push(b'<'); +                    txt = txt +                        .into_iter() +                        .chain(buf.to_owned()) +                        .chain(vec![b'>']) +                        .collect(); +                    if end { +                        break; +                    } +                } +                Event::Text(_e) => { +                    txt = txt.into_iter().chain(buf.to_owned()).collect(); +                } +                _ => { +                    txt.push(b'<'); +                    txt = txt +                        .into_iter() +                        .chain(buf.to_owned()) +                        .chain(vec![b'>']) +                        .collect(); +                } +            } +            buf.clear(); +        } +        println!("{:?}", txt); +        let decoded = str::from_utf8(&txt)?.to_owned(); +        println!("{:?}", decoded); +        Ok(decoded) +    } + +    pub async fn get_features(&mut self) -> Result<Vec<StreamFeature>> { +        let node = self.get_node().await?; +        let mut deserializer = Deserializer::from_str(&node); +        let features = StreamFeatures::deserialize(&mut deserializer).unwrap(); +        println!("{:?}", features); +        Ok(features.features) +    } + +    pub async fn negotiate(&mut self) -> Result<()> { +        loop { +            println!("loop"); +            let features = &self.get_features().await?; +            println!("{:?}", features); +            match &features[0] { +                StreamFeature::Sasl(sasl) => { +                    println!("{:?}", sasl); +                    self.sasl(&sasl).await?; +                } +                StreamFeature::Bind => todo!(), +                x => println!("{:?}", x), +            } +        } +    } + +    pub async fn sasl(&mut self, mechanisms: &Mechanisms) -> Result<()> { +        println!("{:?}", mechanisms); +        let sasl = SASLClient::new(self.jabber.auth.clone()); +        let mut offered_mechs: Vec<&Mechname> = Vec::new(); +        for mechanism in &mechanisms.mechanisms { +            offered_mechs.push(Mechname::parse(&mechanism.mechanism.as_bytes())?) +        } +        println!("{:?}", offered_mechs); +        let mut session = sasl.start_suggested(&offered_mechs)?; +        let selected_mechanism = session.get_mechname().as_str().to_owned(); +        println!("selected mech: {:?}", selected_mechanism); +        let mut data: Option<Vec<u8>> = None; +        if !session.are_we_first() { +            // if not first mention the mechanism then get challenge data +            // mention mechanism +            let auth = Auth { +                ns: "urn:ietf:params:xml:ns:xmpp-sasl".to_owned(), +                mechanism: selected_mechanism.clone(), +                sasl_data: Some("=".to_owned()), +            }; +            let mut buffer = String::new(); +            let ser = Serializer::new(&mut buffer); +            auth.serialize(ser).unwrap(); +            self.writer.write_all(buffer.as_bytes()); +            // get challenge data +            let node = self.get_node().await?; +            let mut deserializer = Deserializer::from_str(&node); +            let challenge = Challenge::deserialize(&mut deserializer).unwrap(); +            println!("challenge: {:?}", challenge); +            data = Some(challenge.sasl_data.as_bytes().to_owned()); +            println!("we didn't go first"); +        } else { +            // if first, mention mechanism and send data +            let mut sasl_data = Vec::new(); +            session.step64(None, &mut sasl_data).unwrap(); +            let auth = Auth { +                ns: "urn:ietf:params:xml:ns:xmpp-sasl".to_owned(), +                mechanism: selected_mechanism.clone(), +                sasl_data: Some(str::from_utf8(&sasl_data).unwrap().to_owned()), +            }; +            let mut buffer = String::new(); +            let ser = Serializer::new(&mut buffer); +            auth.serialize(ser).unwrap(); +            println!("node: {:?}", buffer); +            self.writer.write_all(buffer.as_bytes()).await; +            println!("we went first"); +            // get challenge data +            // TODO: check if needed +            // let node = self.get_node().await?; +            // println!("node: {:?}", node); +            // let mut deserializer = Deserializer::from_str(&node); +            // let challenge = Challenge::deserialize(&mut deserializer).unwrap(); +            // println!("challenge: {:?}", challenge); +            // data = Some(challenge.sasl_data.as_bytes().to_owned()); +        } + +        // stepping the authentication exchange to completion +        let mut sasl_data = Vec::new(); +        while { +            // decide if need to send more data over +            let state = session +                .step64(data.as_deref(), &mut sasl_data) +                .expect("step errored!"); +            state.is_running() +        } { +            // While we aren't finished, receive more data from the other party +            let auth = Auth { +                ns: "urn:ietf:params:xml:ns:xmpp-sasl".to_owned(), +                mechanism: selected_mechanism.clone(), +                sasl_data: Some(str::from_utf8(&sasl_data).unwrap().to_owned()), +            }; +            let mut buffer = String::new(); +            let ser = Serializer::new(&mut buffer); +            auth.serialize(ser).unwrap(); +            self.writer.write_all(buffer.as_bytes()); +            let node = self.get_node().await?; +            let mut deserializer = Deserializer::from_str(&node); +            let challenge = Challenge::deserialize(&mut deserializer).unwrap(); +            data = Some(challenge.sasl_data.as_bytes().to_owned()); +        } +        self.start_stream().await?; +        Ok(()) +    }  } diff --git a/src/client/unencrypted.rs b/src/client/unencrypted.rs index 74b800c..d4225d3 100644 --- a/src/client/unencrypted.rs +++ b/src/client/unencrypted.rs @@ -115,14 +115,12 @@ impl<'j> JabberClient<'j> {                          .connect(&self.jabber.server, stream)                          .await                      { -                        let (read, write) = tokio::io::split(tlsstream); +                        let (read, writer) = tokio::io::split(tlsstream);                          let reader = Reader::from_reader(BufReader::new(read)); -                        let writer = Writer::new(write); -                        return Ok(super::encrypted::JabberClient::new( -                            reader, -                            writer, -                            self.jabber, -                        )); +                        let mut client = +                            super::encrypted::JabberClient::new(reader, writer, self.jabber); +                        client.start_stream().await?; +                        return Ok(client);                      }                  }                  QName(_) => return Err(JabberError::TlsNegotiation), diff --git a/src/error.rs b/src/error.rs index a632537..20ebc3e 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,7 +1,44 @@ +use std::str::Utf8Error; + +use rsasl::mechname::MechanismNameError; +  #[derive(Debug)]  pub enum JabberError { -    ConnectionError, +    Connection,      BadStream,      StartTlsUnavailable,      TlsNegotiation, +    Utf8Decode, +    XML(quick_xml::Error), +    SASL(SASLError), +} + +#[derive(Debug)] +pub enum SASLError { +    SASL(rsasl::prelude::SASLError), +    MechanismName(MechanismNameError), +} + +impl From<rsasl::prelude::SASLError> for JabberError { +    fn from(e: rsasl::prelude::SASLError) -> Self { +        Self::SASL(SASLError::SASL(e)) +    } +} + +impl From<MechanismNameError> for JabberError { +    fn from(value: MechanismNameError) -> Self { +        Self::SASL(SASLError::MechanismName(value)) +    } +} + +impl From<Utf8Error> for JabberError { +    fn from(e: Utf8Error) -> Self { +        Self::Utf8Decode +    } +} + +impl From<quick_xml::Error> for JabberError { +    fn from(e: quick_xml::Error) -> Self { +        Self::XML(e) +    }  } diff --git a/src/jabber.rs b/src/jabber.rs index a1f6272..a1b2a2f 100644 --- a/src/jabber.rs +++ b/src/jabber.rs @@ -1,33 +1,44 @@  use std::marker::PhantomData;  use std::net::{IpAddr, SocketAddr};  use std::str::FromStr; +use std::sync::Arc;  use quick_xml::{Reader, Writer}; +use rsasl::prelude::SASLConfig;  use tokio::io::BufReader;  use tokio::net::TcpStream;  use tokio_native_tls::native_tls::TlsConnector; -use crate::client;  use crate::client::JabberClientType;  use crate::jid::JID; +use crate::{client, JabberClient};  use crate::{JabberError, Result};  pub struct Jabber<'j> {      pub jid: JID, -    pub password: String, +    pub auth: Arc<SASLConfig>,      pub server: String,      _marker: PhantomData<&'j ()>,  }  impl<'j> Jabber<'j> { -    pub fn new(jid: JID, password: String) -> Self { +    pub fn new(jid: JID, password: String) -> Result<Self> {          let server = jid.domainpart.clone(); -        Self { +        let auth = SASLConfig::with_credentials(None, jid.as_bare().to_string(), password)?; +        println!("auth: {:?}", auth); +        Ok(Self {              jid, -            password, +            auth,              server,              _marker: PhantomData, -        } +        }) +    } + +    pub async fn login(&'j mut self) -> Result<JabberClient<'j>> { +        let mut client = self.connect().await?.ensure_tls().await?; +        println!("negotiation"); +        client.negotiate().await?; +        Ok(client)      }      async fn get_sockets(&self) -> Vec<(SocketAddr, bool)> { @@ -106,9 +117,8 @@ impl<'j> Jabber<'j> {                          .connect(&self.server, socket)                          .await                      { -                        let (read, write) = tokio::io::split(stream); +                        let (read, writer) = tokio::io::split(stream);                          let reader = Reader::from_reader(BufReader::new(read)); -                        let writer = Writer::new(write);                          return Ok(JabberClientType::Encrypted(                              client::encrypted::JabberClient::new(reader, writer, self),                          )); @@ -126,6 +136,6 @@ impl<'j> Jabber<'j> {                  }              }          } -        Err(JabberError::ConnectionError) +        Err(JabberError::Connection)      }  } diff --git a/src/jid/mod.rs b/src/jid/mod.rs index 4baa857..b2a03ea 100644 --- a/src/jid/mod.rs +++ b/src/jid/mod.rs @@ -8,8 +8,13 @@ pub struct JID {      pub resourcepart: Option<String>,  } +pub enum JIDError { +    NoResourcePart, +    ParseError(ParseError), +} +  #[derive(Debug)] -pub enum JIDParseError { +pub enum ParseError {      Empty,      Malformed,  } @@ -26,15 +31,31 @@ impl JID {              resourcepart,          }      } + +    pub fn as_bare(&self) -> Self { +        Self { +            localpart: self.localpart.clone(), +            domainpart: self.domainpart.clone(), +            resourcepart: None, +        } +    } + +    pub fn as_full(&self) -> Result<&Self, JIDError> { +        if let Some(_) = self.resourcepart { +            Ok(&self) +        } else { +            Err(JIDError::NoResourcePart) +        } +    }  }  impl FromStr for JID { -    type Err = JIDParseError; +    type Err = ParseError;      fn from_str(s: &str) -> Result<Self, Self::Err> {          let split: Vec<&str> = s.split('@').collect();          match split.len() { -            0 => Err(JIDParseError::Empty), +            0 => Err(ParseError::Empty),              1 => {                  let split: Vec<&str> = split[0].split('/').collect();                  match split.len() { @@ -44,7 +65,7 @@ impl FromStr for JID {                          split[0].to_string(),                          Some(split[1].to_string()),                      )), -                    _ => Err(JIDParseError::Malformed), +                    _ => Err(ParseError::Malformed),                  }              }              2 => { @@ -60,16 +81,16 @@ impl FromStr for JID {                          split2[0].to_string(),                          Some(split2[1].to_string()),                      )), -                    _ => Err(JIDParseError::Malformed), +                    _ => Err(ParseError::Malformed),                  }              } -            _ => Err(JIDParseError::Malformed), +            _ => Err(ParseError::Malformed),          }      }  }  impl TryFrom<String> for JID { -    type Error = JIDParseError; +    type Error = ParseError;      fn try_from(value: String) -> Result<Self, Self::Error> {          value.parse() @@ -27,16 +27,26 @@ mod tests {      //     println!("{:?}", jabber.get_sockets().await)      // } +    // #[tokio::test] +    // async fn connect() { +    //     Jabber::new(JID::from_str("cel@blos.sm").unwrap(), "password".to_owned()) +    //         .unwrap() +    //         .connect() +    //         .await +    //         .unwrap() +    //         .ensure_tls() +    //         .await +    //         .unwrap() +    //         .start_stream() +    //         .await +    //         .unwrap(); +    // } +      #[tokio::test] -    async fn connect() { -        Jabber::new(JID::from_str("cel@blos.sm").unwrap(), "password".to_owned()) -            .connect() -            .await -            .unwrap() -            .ensure_tls() -            .await +    async fn login() { +        Jabber::new(JID::from_str("test@blos.sm").unwrap(), "slayed".to_owned())              .unwrap() -            .start_stream() +            .login()              .await              .unwrap();      } diff --git a/src/stanza/mod.rs b/src/stanza/mod.rs index baf29e0..4eaa4c2 100644 --- a/src/stanza/mod.rs +++ b/src/stanza/mod.rs @@ -1 +1,2 @@ +pub mod sasl;  pub mod stream; diff --git a/src/stanza/sasl.rs b/src/stanza/sasl.rs new file mode 100644 index 0000000..c0e41ab --- /dev/null +++ b/src/stanza/sasl.rs @@ -0,0 +1,32 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Deserialize, PartialEq, Debug)] +pub struct Mechanisms { +    #[serde(rename = "$value")] +    pub mechanisms: Vec<Mechanism>, +} + +#[derive(Deserialize, PartialEq, Debug)] +pub struct Mechanism { +    #[serde(rename = "$text")] +    pub mechanism: String, +} + +#[derive(Serialize, Debug)] +#[serde(rename = "auth")] +pub struct Auth { +    #[serde(rename = "@xmlns")] +    pub ns: String, +    #[serde(rename = "@mechanism")] +    pub mechanism: String, +    #[serde(rename = "$text")] +    pub sasl_data: Option<String>, +} + +#[derive(Deserialize, Debug)] +pub struct Challenge { +    #[serde(rename = "@xmlns")] +    pub ns: String, +    #[serde(rename = "$text")] +    pub sasl_data: String, +} diff --git a/src/stanza/stream.rs b/src/stanza/stream.rs index dde741d..4c0addd 100644 --- a/src/stanza/stream.rs +++ b/src/stanza/stream.rs @@ -1,5 +1,7 @@  use serde::{Deserialize, Serialize}; +use super::sasl::Mechanisms; +  #[derive(Serialize, Deserialize)]  #[serde(rename = "stream:stream")]  struct Stream { @@ -31,6 +33,9 @@ pub enum StreamFeature {      #[serde(rename = "starttls")]      StartTls,      // TODO: other stream features -    Sasl, +    #[serde(rename = "mechanisms")] +    Sasl(Mechanisms),      Bind, +    #[serde(other)] +    Unknown,  } | 
