diff options
-rw-r--r-- | Cargo.toml | 4 | ||||
-rw-r--r-- | src/client/encrypted.rs | 189 | ||||
-rw-r--r-- | src/client/unencrypted.rs | 12 | ||||
-rw-r--r-- | src/error.rs | 39 | ||||
-rw-r--r-- | src/jabber.rs | 28 | ||||
-rw-r--r-- | src/jid/mod.rs | 35 | ||||
-rw-r--r-- | src/lib.rs | 26 | ||||
-rw-r--r-- | src/stanza/mod.rs | 1 | ||||
-rw-r--r-- | src/stanza/sasl.rs | 32 | ||||
-rw-r--r-- | src/stanza/stream.rs | 7 |
10 files changed, 329 insertions, 44 deletions
@@ -8,7 +8,9 @@ edition = "2021" [dependencies] async-trait = "0.1.68" -quick-xml = { version = "0.29.0", features = ["async-tokio", "serialize"] } +quick-xml = { git = "https://github.com/tafia/quick-xml.git", features = ["async-tokio", "serialize"] } +# TODO: remove unneeded features +rsasl = { version = "2", default_features = false, features = ["provider_base64", "plain", "config_builder"] } serde = { version = "1.0.164", features = ["derive"] } tokio = { version = "1.28", features = ["full"] } tokio-native-tls = "0.3.1" diff --git a/src/client/encrypted.rs b/src/client/encrypted.rs index 08439b2..a4bf0d1 100644 --- a/src/client/encrypted.rs +++ b/src/client/encrypted.rs @@ -1,24 +1,35 @@ +use std::str; + use quick_xml::{ + de::Deserializer, events::{BytesDecl, BytesStart, Event}, + name::QName, + se::Serializer, Reader, Writer, }; -use tokio::io::{BufReader, ReadHalf, WriteHalf}; +use rsasl::prelude::{Mechname, SASLClient}; +use serde::{Deserialize, Serialize}; +use tokio::io::{AsyncWriteExt, BufReader, ReadHalf, WriteHalf}; use tokio::net::TcpStream; use tokio_native_tls::TlsStream; +use crate::stanza::{ + sasl::{Auth, Challenge, Mechanisms}, + stream::{StreamFeature, StreamFeatures}, +}; use crate::Jabber; use crate::Result; pub struct JabberClient<'j> { reader: Reader<BufReader<ReadHalf<TlsStream<TcpStream>>>>, - writer: Writer<WriteHalf<TlsStream<TcpStream>>>, + writer: WriteHalf<TlsStream<TcpStream>>, jabber: &'j mut Jabber<'j>, } impl<'j> JabberClient<'j> { pub fn new( reader: Reader<BufReader<ReadHalf<TlsStream<TcpStream>>>>, - writer: Writer<WriteHalf<TlsStream<TcpStream>>>, + writer: WriteHalf<TlsStream<TcpStream>>, jabber: &'j mut Jabber<'j>, ) -> Self { Self { @@ -37,13 +48,9 @@ impl<'j> JabberClient<'j> { stream_element.push_attribute(("xml:lang", "en")); stream_element.push_attribute(("xmlns", "jabber:client")); stream_element.push_attribute(("xmlns:stream", "http://etherx.jabber.org/streams")); - self.writer - .write_event_async(Event::Decl(declaration)) - .await; - self.writer - .write_event_async(Event::Start(stream_element)) - .await - .unwrap(); + let mut writer = Writer::new(&mut self.writer); + writer.write_event_async(Event::Decl(declaration)).await; + writer.write_event_async(Event::Start(stream_element)).await; let mut buf = Vec::new(); loop { match self.reader.read_event_into_async(&mut buf).await.unwrap() { @@ -56,4 +63,166 @@ impl<'j> JabberClient<'j> { } Ok(()) } + + pub async fn get_node<'a>(&mut self) -> Result<String> { + let mut buf = Vec::new(); + let mut txt = Vec::new(); + let mut qname_set = false; + let mut qname: Option<Vec<u8>> = None; + loop { + match self.reader.read_event_into_async(&mut buf).await? { + Event::Start(e) => { + if !qname_set { + qname = Some(e.name().into_inner().to_owned()); + qname_set = true; + } + txt.push(b'<'); + txt = txt + .into_iter() + .chain(buf.to_owned()) + .chain(vec![b'>']) + .collect(); + } + Event::End(e) => { + let mut end = false; + if e.name() == QName(qname.as_deref().unwrap()) { + end = true; + } + txt.push(b'<'); + txt = txt + .into_iter() + .chain(buf.to_owned()) + .chain(vec![b'>']) + .collect(); + if end { + break; + } + } + Event::Text(_e) => { + txt = txt.into_iter().chain(buf.to_owned()).collect(); + } + _ => { + txt.push(b'<'); + txt = txt + .into_iter() + .chain(buf.to_owned()) + .chain(vec![b'>']) + .collect(); + } + } + buf.clear(); + } + println!("{:?}", txt); + let decoded = str::from_utf8(&txt)?.to_owned(); + println!("{:?}", decoded); + Ok(decoded) + } + + pub async fn get_features(&mut self) -> Result<Vec<StreamFeature>> { + let node = self.get_node().await?; + let mut deserializer = Deserializer::from_str(&node); + let features = StreamFeatures::deserialize(&mut deserializer).unwrap(); + println!("{:?}", features); + Ok(features.features) + } + + pub async fn negotiate(&mut self) -> Result<()> { + loop { + println!("loop"); + let features = &self.get_features().await?; + println!("{:?}", features); + match &features[0] { + StreamFeature::Sasl(sasl) => { + println!("{:?}", sasl); + self.sasl(&sasl).await?; + } + StreamFeature::Bind => todo!(), + x => println!("{:?}", x), + } + } + } + + pub async fn sasl(&mut self, mechanisms: &Mechanisms) -> Result<()> { + println!("{:?}", mechanisms); + let sasl = SASLClient::new(self.jabber.auth.clone()); + let mut offered_mechs: Vec<&Mechname> = Vec::new(); + for mechanism in &mechanisms.mechanisms { + offered_mechs.push(Mechname::parse(&mechanism.mechanism.as_bytes())?) + } + println!("{:?}", offered_mechs); + let mut session = sasl.start_suggested(&offered_mechs)?; + let selected_mechanism = session.get_mechname().as_str().to_owned(); + println!("selected mech: {:?}", selected_mechanism); + let mut data: Option<Vec<u8>> = None; + if !session.are_we_first() { + // if not first mention the mechanism then get challenge data + // mention mechanism + let auth = Auth { + ns: "urn:ietf:params:xml:ns:xmpp-sasl".to_owned(), + mechanism: selected_mechanism.clone(), + sasl_data: Some("=".to_owned()), + }; + let mut buffer = String::new(); + let ser = Serializer::new(&mut buffer); + auth.serialize(ser).unwrap(); + self.writer.write_all(buffer.as_bytes()); + // get challenge data + let node = self.get_node().await?; + let mut deserializer = Deserializer::from_str(&node); + let challenge = Challenge::deserialize(&mut deserializer).unwrap(); + println!("challenge: {:?}", challenge); + data = Some(challenge.sasl_data.as_bytes().to_owned()); + println!("we didn't go first"); + } else { + // if first, mention mechanism and send data + let mut sasl_data = Vec::new(); + session.step64(None, &mut sasl_data).unwrap(); + let auth = Auth { + ns: "urn:ietf:params:xml:ns:xmpp-sasl".to_owned(), + mechanism: selected_mechanism.clone(), + sasl_data: Some(str::from_utf8(&sasl_data).unwrap().to_owned()), + }; + let mut buffer = String::new(); + let ser = Serializer::new(&mut buffer); + auth.serialize(ser).unwrap(); + println!("node: {:?}", buffer); + self.writer.write_all(buffer.as_bytes()).await; + println!("we went first"); + // get challenge data + // TODO: check if needed + // let node = self.get_node().await?; + // println!("node: {:?}", node); + // let mut deserializer = Deserializer::from_str(&node); + // let challenge = Challenge::deserialize(&mut deserializer).unwrap(); + // println!("challenge: {:?}", challenge); + // data = Some(challenge.sasl_data.as_bytes().to_owned()); + } + + // stepping the authentication exchange to completion + let mut sasl_data = Vec::new(); + while { + // decide if need to send more data over + let state = session + .step64(data.as_deref(), &mut sasl_data) + .expect("step errored!"); + state.is_running() + } { + // While we aren't finished, receive more data from the other party + let auth = Auth { + ns: "urn:ietf:params:xml:ns:xmpp-sasl".to_owned(), + mechanism: selected_mechanism.clone(), + sasl_data: Some(str::from_utf8(&sasl_data).unwrap().to_owned()), + }; + let mut buffer = String::new(); + let ser = Serializer::new(&mut buffer); + auth.serialize(ser).unwrap(); + self.writer.write_all(buffer.as_bytes()); + let node = self.get_node().await?; + let mut deserializer = Deserializer::from_str(&node); + let challenge = Challenge::deserialize(&mut deserializer).unwrap(); + data = Some(challenge.sasl_data.as_bytes().to_owned()); + } + self.start_stream().await?; + Ok(()) + } } diff --git a/src/client/unencrypted.rs b/src/client/unencrypted.rs index 74b800c..d4225d3 100644 --- a/src/client/unencrypted.rs +++ b/src/client/unencrypted.rs @@ -115,14 +115,12 @@ impl<'j> JabberClient<'j> { .connect(&self.jabber.server, stream) .await { - let (read, write) = tokio::io::split(tlsstream); + let (read, writer) = tokio::io::split(tlsstream); let reader = Reader::from_reader(BufReader::new(read)); - let writer = Writer::new(write); - return Ok(super::encrypted::JabberClient::new( - reader, - writer, - self.jabber, - )); + let mut client = + super::encrypted::JabberClient::new(reader, writer, self.jabber); + client.start_stream().await?; + return Ok(client); } } QName(_) => return Err(JabberError::TlsNegotiation), diff --git a/src/error.rs b/src/error.rs index a632537..20ebc3e 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,7 +1,44 @@ +use std::str::Utf8Error; + +use rsasl::mechname::MechanismNameError; + #[derive(Debug)] pub enum JabberError { - ConnectionError, + Connection, BadStream, StartTlsUnavailable, TlsNegotiation, + Utf8Decode, + XML(quick_xml::Error), + SASL(SASLError), +} + +#[derive(Debug)] +pub enum SASLError { + SASL(rsasl::prelude::SASLError), + MechanismName(MechanismNameError), +} + +impl From<rsasl::prelude::SASLError> for JabberError { + fn from(e: rsasl::prelude::SASLError) -> Self { + Self::SASL(SASLError::SASL(e)) + } +} + +impl From<MechanismNameError> for JabberError { + fn from(value: MechanismNameError) -> Self { + Self::SASL(SASLError::MechanismName(value)) + } +} + +impl From<Utf8Error> for JabberError { + fn from(e: Utf8Error) -> Self { + Self::Utf8Decode + } +} + +impl From<quick_xml::Error> for JabberError { + fn from(e: quick_xml::Error) -> Self { + Self::XML(e) + } } diff --git a/src/jabber.rs b/src/jabber.rs index a1f6272..a1b2a2f 100644 --- a/src/jabber.rs +++ b/src/jabber.rs @@ -1,33 +1,44 @@ use std::marker::PhantomData; use std::net::{IpAddr, SocketAddr}; use std::str::FromStr; +use std::sync::Arc; use quick_xml::{Reader, Writer}; +use rsasl::prelude::SASLConfig; use tokio::io::BufReader; use tokio::net::TcpStream; use tokio_native_tls::native_tls::TlsConnector; -use crate::client; use crate::client::JabberClientType; use crate::jid::JID; +use crate::{client, JabberClient}; use crate::{JabberError, Result}; pub struct Jabber<'j> { pub jid: JID, - pub password: String, + pub auth: Arc<SASLConfig>, pub server: String, _marker: PhantomData<&'j ()>, } impl<'j> Jabber<'j> { - pub fn new(jid: JID, password: String) -> Self { + pub fn new(jid: JID, password: String) -> Result<Self> { let server = jid.domainpart.clone(); - Self { + let auth = SASLConfig::with_credentials(None, jid.as_bare().to_string(), password)?; + println!("auth: {:?}", auth); + Ok(Self { jid, - password, + auth, server, _marker: PhantomData, - } + }) + } + + pub async fn login(&'j mut self) -> Result<JabberClient<'j>> { + let mut client = self.connect().await?.ensure_tls().await?; + println!("negotiation"); + client.negotiate().await?; + Ok(client) } async fn get_sockets(&self) -> Vec<(SocketAddr, bool)> { @@ -106,9 +117,8 @@ impl<'j> Jabber<'j> { .connect(&self.server, socket) .await { - let (read, write) = tokio::io::split(stream); + let (read, writer) = tokio::io::split(stream); let reader = Reader::from_reader(BufReader::new(read)); - let writer = Writer::new(write); return Ok(JabberClientType::Encrypted( client::encrypted::JabberClient::new(reader, writer, self), )); @@ -126,6 +136,6 @@ impl<'j> Jabber<'j> { } } } - Err(JabberError::ConnectionError) + Err(JabberError::Connection) } } diff --git a/src/jid/mod.rs b/src/jid/mod.rs index 4baa857..b2a03ea 100644 --- a/src/jid/mod.rs +++ b/src/jid/mod.rs @@ -8,8 +8,13 @@ pub struct JID { pub resourcepart: Option<String>, } +pub enum JIDError { + NoResourcePart, + ParseError(ParseError), +} + #[derive(Debug)] -pub enum JIDParseError { +pub enum ParseError { Empty, Malformed, } @@ -26,15 +31,31 @@ impl JID { resourcepart, } } + + pub fn as_bare(&self) -> Self { + Self { + localpart: self.localpart.clone(), + domainpart: self.domainpart.clone(), + resourcepart: None, + } + } + + pub fn as_full(&self) -> Result<&Self, JIDError> { + if let Some(_) = self.resourcepart { + Ok(&self) + } else { + Err(JIDError::NoResourcePart) + } + } } impl FromStr for JID { - type Err = JIDParseError; + type Err = ParseError; fn from_str(s: &str) -> Result<Self, Self::Err> { let split: Vec<&str> = s.split('@').collect(); match split.len() { - 0 => Err(JIDParseError::Empty), + 0 => Err(ParseError::Empty), 1 => { let split: Vec<&str> = split[0].split('/').collect(); match split.len() { @@ -44,7 +65,7 @@ impl FromStr for JID { split[0].to_string(), Some(split[1].to_string()), )), - _ => Err(JIDParseError::Malformed), + _ => Err(ParseError::Malformed), } } 2 => { @@ -60,16 +81,16 @@ impl FromStr for JID { split2[0].to_string(), Some(split2[1].to_string()), )), - _ => Err(JIDParseError::Malformed), + _ => Err(ParseError::Malformed), } } - _ => Err(JIDParseError::Malformed), + _ => Err(ParseError::Malformed), } } } impl TryFrom<String> for JID { - type Error = JIDParseError; + type Error = ParseError; fn try_from(value: String) -> Result<Self, Self::Error> { value.parse() @@ -27,16 +27,26 @@ mod tests { // println!("{:?}", jabber.get_sockets().await) // } + // #[tokio::test] + // async fn connect() { + // Jabber::new(JID::from_str("cel@blos.sm").unwrap(), "password".to_owned()) + // .unwrap() + // .connect() + // .await + // .unwrap() + // .ensure_tls() + // .await + // .unwrap() + // .start_stream() + // .await + // .unwrap(); + // } + #[tokio::test] - async fn connect() { - Jabber::new(JID::from_str("cel@blos.sm").unwrap(), "password".to_owned()) - .connect() - .await - .unwrap() - .ensure_tls() - .await + async fn login() { + Jabber::new(JID::from_str("test@blos.sm").unwrap(), "slayed".to_owned()) .unwrap() - .start_stream() + .login() .await .unwrap(); } diff --git a/src/stanza/mod.rs b/src/stanza/mod.rs index baf29e0..4eaa4c2 100644 --- a/src/stanza/mod.rs +++ b/src/stanza/mod.rs @@ -1 +1,2 @@ +pub mod sasl; pub mod stream; diff --git a/src/stanza/sasl.rs b/src/stanza/sasl.rs new file mode 100644 index 0000000..c0e41ab --- /dev/null +++ b/src/stanza/sasl.rs @@ -0,0 +1,32 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Deserialize, PartialEq, Debug)] +pub struct Mechanisms { + #[serde(rename = "$value")] + pub mechanisms: Vec<Mechanism>, +} + +#[derive(Deserialize, PartialEq, Debug)] +pub struct Mechanism { + #[serde(rename = "$text")] + pub mechanism: String, +} + +#[derive(Serialize, Debug)] +#[serde(rename = "auth")] +pub struct Auth { + #[serde(rename = "@xmlns")] + pub ns: String, + #[serde(rename = "@mechanism")] + pub mechanism: String, + #[serde(rename = "$text")] + pub sasl_data: Option<String>, +} + +#[derive(Deserialize, Debug)] +pub struct Challenge { + #[serde(rename = "@xmlns")] + pub ns: String, + #[serde(rename = "$text")] + pub sasl_data: String, +} diff --git a/src/stanza/stream.rs b/src/stanza/stream.rs index dde741d..4c0addd 100644 --- a/src/stanza/stream.rs +++ b/src/stanza/stream.rs @@ -1,5 +1,7 @@ use serde::{Deserialize, Serialize}; +use super::sasl::Mechanisms; + #[derive(Serialize, Deserialize)] #[serde(rename = "stream:stream")] struct Stream { @@ -31,6 +33,9 @@ pub enum StreamFeature { #[serde(rename = "starttls")] StartTls, // TODO: other stream features - Sasl, + #[serde(rename = "mechanisms")] + Sasl(Mechanisms), Bind, + #[serde(other)] + Unknown, } |