summaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/client/encrypted.rs189
-rw-r--r--src/client/unencrypted.rs12
-rw-r--r--src/error.rs39
-rw-r--r--src/jabber.rs28
-rw-r--r--src/jid/mod.rs35
-rw-r--r--src/lib.rs26
-rw-r--r--src/stanza/mod.rs1
-rw-r--r--src/stanza/sasl.rs32
-rw-r--r--src/stanza/stream.rs7
9 files changed, 326 insertions, 43 deletions
diff --git a/src/client/encrypted.rs b/src/client/encrypted.rs
index 08439b2..a4bf0d1 100644
--- a/src/client/encrypted.rs
+++ b/src/client/encrypted.rs
@@ -1,24 +1,35 @@
+use std::str;
+
use quick_xml::{
+ de::Deserializer,
events::{BytesDecl, BytesStart, Event},
+ name::QName,
+ se::Serializer,
Reader, Writer,
};
-use tokio::io::{BufReader, ReadHalf, WriteHalf};
+use rsasl::prelude::{Mechname, SASLClient};
+use serde::{Deserialize, Serialize};
+use tokio::io::{AsyncWriteExt, BufReader, ReadHalf, WriteHalf};
use tokio::net::TcpStream;
use tokio_native_tls::TlsStream;
+use crate::stanza::{
+ sasl::{Auth, Challenge, Mechanisms},
+ stream::{StreamFeature, StreamFeatures},
+};
use crate::Jabber;
use crate::Result;
pub struct JabberClient<'j> {
reader: Reader<BufReader<ReadHalf<TlsStream<TcpStream>>>>,
- writer: Writer<WriteHalf<TlsStream<TcpStream>>>,
+ writer: WriteHalf<TlsStream<TcpStream>>,
jabber: &'j mut Jabber<'j>,
}
impl<'j> JabberClient<'j> {
pub fn new(
reader: Reader<BufReader<ReadHalf<TlsStream<TcpStream>>>>,
- writer: Writer<WriteHalf<TlsStream<TcpStream>>>,
+ writer: WriteHalf<TlsStream<TcpStream>>,
jabber: &'j mut Jabber<'j>,
) -> Self {
Self {
@@ -37,13 +48,9 @@ impl<'j> JabberClient<'j> {
stream_element.push_attribute(("xml:lang", "en"));
stream_element.push_attribute(("xmlns", "jabber:client"));
stream_element.push_attribute(("xmlns:stream", "http://etherx.jabber.org/streams"));
- self.writer
- .write_event_async(Event::Decl(declaration))
- .await;
- self.writer
- .write_event_async(Event::Start(stream_element))
- .await
- .unwrap();
+ let mut writer = Writer::new(&mut self.writer);
+ writer.write_event_async(Event::Decl(declaration)).await;
+ writer.write_event_async(Event::Start(stream_element)).await;
let mut buf = Vec::new();
loop {
match self.reader.read_event_into_async(&mut buf).await.unwrap() {
@@ -56,4 +63,166 @@ impl<'j> JabberClient<'j> {
}
Ok(())
}
+
+ pub async fn get_node<'a>(&mut self) -> Result<String> {
+ let mut buf = Vec::new();
+ let mut txt = Vec::new();
+ let mut qname_set = false;
+ let mut qname: Option<Vec<u8>> = None;
+ loop {
+ match self.reader.read_event_into_async(&mut buf).await? {
+ Event::Start(e) => {
+ if !qname_set {
+ qname = Some(e.name().into_inner().to_owned());
+ qname_set = true;
+ }
+ txt.push(b'<');
+ txt = txt
+ .into_iter()
+ .chain(buf.to_owned())
+ .chain(vec![b'>'])
+ .collect();
+ }
+ Event::End(e) => {
+ let mut end = false;
+ if e.name() == QName(qname.as_deref().unwrap()) {
+ end = true;
+ }
+ txt.push(b'<');
+ txt = txt
+ .into_iter()
+ .chain(buf.to_owned())
+ .chain(vec![b'>'])
+ .collect();
+ if end {
+ break;
+ }
+ }
+ Event::Text(_e) => {
+ txt = txt.into_iter().chain(buf.to_owned()).collect();
+ }
+ _ => {
+ txt.push(b'<');
+ txt = txt
+ .into_iter()
+ .chain(buf.to_owned())
+ .chain(vec![b'>'])
+ .collect();
+ }
+ }
+ buf.clear();
+ }
+ println!("{:?}", txt);
+ let decoded = str::from_utf8(&txt)?.to_owned();
+ println!("{:?}", decoded);
+ Ok(decoded)
+ }
+
+ pub async fn get_features(&mut self) -> Result<Vec<StreamFeature>> {
+ let node = self.get_node().await?;
+ let mut deserializer = Deserializer::from_str(&node);
+ let features = StreamFeatures::deserialize(&mut deserializer).unwrap();
+ println!("{:?}", features);
+ Ok(features.features)
+ }
+
+ pub async fn negotiate(&mut self) -> Result<()> {
+ loop {
+ println!("loop");
+ let features = &self.get_features().await?;
+ println!("{:?}", features);
+ match &features[0] {
+ StreamFeature::Sasl(sasl) => {
+ println!("{:?}", sasl);
+ self.sasl(&sasl).await?;
+ }
+ StreamFeature::Bind => todo!(),
+ x => println!("{:?}", x),
+ }
+ }
+ }
+
+ pub async fn sasl(&mut self, mechanisms: &Mechanisms) -> Result<()> {
+ println!("{:?}", mechanisms);
+ let sasl = SASLClient::new(self.jabber.auth.clone());
+ let mut offered_mechs: Vec<&Mechname> = Vec::new();
+ for mechanism in &mechanisms.mechanisms {
+ offered_mechs.push(Mechname::parse(&mechanism.mechanism.as_bytes())?)
+ }
+ println!("{:?}", offered_mechs);
+ let mut session = sasl.start_suggested(&offered_mechs)?;
+ let selected_mechanism = session.get_mechname().as_str().to_owned();
+ println!("selected mech: {:?}", selected_mechanism);
+ let mut data: Option<Vec<u8>> = None;
+ if !session.are_we_first() {
+ // if not first mention the mechanism then get challenge data
+ // mention mechanism
+ let auth = Auth {
+ ns: "urn:ietf:params:xml:ns:xmpp-sasl".to_owned(),
+ mechanism: selected_mechanism.clone(),
+ sasl_data: Some("=".to_owned()),
+ };
+ let mut buffer = String::new();
+ let ser = Serializer::new(&mut buffer);
+ auth.serialize(ser).unwrap();
+ self.writer.write_all(buffer.as_bytes());
+ // get challenge data
+ let node = self.get_node().await?;
+ let mut deserializer = Deserializer::from_str(&node);
+ let challenge = Challenge::deserialize(&mut deserializer).unwrap();
+ println!("challenge: {:?}", challenge);
+ data = Some(challenge.sasl_data.as_bytes().to_owned());
+ println!("we didn't go first");
+ } else {
+ // if first, mention mechanism and send data
+ let mut sasl_data = Vec::new();
+ session.step64(None, &mut sasl_data).unwrap();
+ let auth = Auth {
+ ns: "urn:ietf:params:xml:ns:xmpp-sasl".to_owned(),
+ mechanism: selected_mechanism.clone(),
+ sasl_data: Some(str::from_utf8(&sasl_data).unwrap().to_owned()),
+ };
+ let mut buffer = String::new();
+ let ser = Serializer::new(&mut buffer);
+ auth.serialize(ser).unwrap();
+ println!("node: {:?}", buffer);
+ self.writer.write_all(buffer.as_bytes()).await;
+ println!("we went first");
+ // get challenge data
+ // TODO: check if needed
+ // let node = self.get_node().await?;
+ // println!("node: {:?}", node);
+ // let mut deserializer = Deserializer::from_str(&node);
+ // let challenge = Challenge::deserialize(&mut deserializer).unwrap();
+ // println!("challenge: {:?}", challenge);
+ // data = Some(challenge.sasl_data.as_bytes().to_owned());
+ }
+
+ // stepping the authentication exchange to completion
+ let mut sasl_data = Vec::new();
+ while {
+ // decide if need to send more data over
+ let state = session
+ .step64(data.as_deref(), &mut sasl_data)
+ .expect("step errored!");
+ state.is_running()
+ } {
+ // While we aren't finished, receive more data from the other party
+ let auth = Auth {
+ ns: "urn:ietf:params:xml:ns:xmpp-sasl".to_owned(),
+ mechanism: selected_mechanism.clone(),
+ sasl_data: Some(str::from_utf8(&sasl_data).unwrap().to_owned()),
+ };
+ let mut buffer = String::new();
+ let ser = Serializer::new(&mut buffer);
+ auth.serialize(ser).unwrap();
+ self.writer.write_all(buffer.as_bytes());
+ let node = self.get_node().await?;
+ let mut deserializer = Deserializer::from_str(&node);
+ let challenge = Challenge::deserialize(&mut deserializer).unwrap();
+ data = Some(challenge.sasl_data.as_bytes().to_owned());
+ }
+ self.start_stream().await?;
+ Ok(())
+ }
}
diff --git a/src/client/unencrypted.rs b/src/client/unencrypted.rs
index 74b800c..d4225d3 100644
--- a/src/client/unencrypted.rs
+++ b/src/client/unencrypted.rs
@@ -115,14 +115,12 @@ impl<'j> JabberClient<'j> {
.connect(&self.jabber.server, stream)
.await
{
- let (read, write) = tokio::io::split(tlsstream);
+ let (read, writer) = tokio::io::split(tlsstream);
let reader = Reader::from_reader(BufReader::new(read));
- let writer = Writer::new(write);
- return Ok(super::encrypted::JabberClient::new(
- reader,
- writer,
- self.jabber,
- ));
+ let mut client =
+ super::encrypted::JabberClient::new(reader, writer, self.jabber);
+ client.start_stream().await?;
+ return Ok(client);
}
}
QName(_) => return Err(JabberError::TlsNegotiation),
diff --git a/src/error.rs b/src/error.rs
index a632537..20ebc3e 100644
--- a/src/error.rs
+++ b/src/error.rs
@@ -1,7 +1,44 @@
+use std::str::Utf8Error;
+
+use rsasl::mechname::MechanismNameError;
+
#[derive(Debug)]
pub enum JabberError {
- ConnectionError,
+ Connection,
BadStream,
StartTlsUnavailable,
TlsNegotiation,
+ Utf8Decode,
+ XML(quick_xml::Error),
+ SASL(SASLError),
+}
+
+#[derive(Debug)]
+pub enum SASLError {
+ SASL(rsasl::prelude::SASLError),
+ MechanismName(MechanismNameError),
+}
+
+impl From<rsasl::prelude::SASLError> for JabberError {
+ fn from(e: rsasl::prelude::SASLError) -> Self {
+ Self::SASL(SASLError::SASL(e))
+ }
+}
+
+impl From<MechanismNameError> for JabberError {
+ fn from(value: MechanismNameError) -> Self {
+ Self::SASL(SASLError::MechanismName(value))
+ }
+}
+
+impl From<Utf8Error> for JabberError {
+ fn from(e: Utf8Error) -> Self {
+ Self::Utf8Decode
+ }
+}
+
+impl From<quick_xml::Error> for JabberError {
+ fn from(e: quick_xml::Error) -> Self {
+ Self::XML(e)
+ }
}
diff --git a/src/jabber.rs b/src/jabber.rs
index a1f6272..a1b2a2f 100644
--- a/src/jabber.rs
+++ b/src/jabber.rs
@@ -1,33 +1,44 @@
use std::marker::PhantomData;
use std::net::{IpAddr, SocketAddr};
use std::str::FromStr;
+use std::sync::Arc;
use quick_xml::{Reader, Writer};
+use rsasl::prelude::SASLConfig;
use tokio::io::BufReader;
use tokio::net::TcpStream;
use tokio_native_tls::native_tls::TlsConnector;
-use crate::client;
use crate::client::JabberClientType;
use crate::jid::JID;
+use crate::{client, JabberClient};
use crate::{JabberError, Result};
pub struct Jabber<'j> {
pub jid: JID,
- pub password: String,
+ pub auth: Arc<SASLConfig>,
pub server: String,
_marker: PhantomData<&'j ()>,
}
impl<'j> Jabber<'j> {
- pub fn new(jid: JID, password: String) -> Self {
+ pub fn new(jid: JID, password: String) -> Result<Self> {
let server = jid.domainpart.clone();
- Self {
+ let auth = SASLConfig::with_credentials(None, jid.as_bare().to_string(), password)?;
+ println!("auth: {:?}", auth);
+ Ok(Self {
jid,
- password,
+ auth,
server,
_marker: PhantomData,
- }
+ })
+ }
+
+ pub async fn login(&'j mut self) -> Result<JabberClient<'j>> {
+ let mut client = self.connect().await?.ensure_tls().await?;
+ println!("negotiation");
+ client.negotiate().await?;
+ Ok(client)
}
async fn get_sockets(&self) -> Vec<(SocketAddr, bool)> {
@@ -106,9 +117,8 @@ impl<'j> Jabber<'j> {
.connect(&self.server, socket)
.await
{
- let (read, write) = tokio::io::split(stream);
+ let (read, writer) = tokio::io::split(stream);
let reader = Reader::from_reader(BufReader::new(read));
- let writer = Writer::new(write);
return Ok(JabberClientType::Encrypted(
client::encrypted::JabberClient::new(reader, writer, self),
));
@@ -126,6 +136,6 @@ impl<'j> Jabber<'j> {
}
}
}
- Err(JabberError::ConnectionError)
+ Err(JabberError::Connection)
}
}
diff --git a/src/jid/mod.rs b/src/jid/mod.rs
index 4baa857..b2a03ea 100644
--- a/src/jid/mod.rs
+++ b/src/jid/mod.rs
@@ -8,8 +8,13 @@ pub struct JID {
pub resourcepart: Option<String>,
}
+pub enum JIDError {
+ NoResourcePart,
+ ParseError(ParseError),
+}
+
#[derive(Debug)]
-pub enum JIDParseError {
+pub enum ParseError {
Empty,
Malformed,
}
@@ -26,15 +31,31 @@ impl JID {
resourcepart,
}
}
+
+ pub fn as_bare(&self) -> Self {
+ Self {
+ localpart: self.localpart.clone(),
+ domainpart: self.domainpart.clone(),
+ resourcepart: None,
+ }
+ }
+
+ pub fn as_full(&self) -> Result<&Self, JIDError> {
+ if let Some(_) = self.resourcepart {
+ Ok(&self)
+ } else {
+ Err(JIDError::NoResourcePart)
+ }
+ }
}
impl FromStr for JID {
- type Err = JIDParseError;
+ type Err = ParseError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
let split: Vec<&str> = s.split('@').collect();
match split.len() {
- 0 => Err(JIDParseError::Empty),
+ 0 => Err(ParseError::Empty),
1 => {
let split: Vec<&str> = split[0].split('/').collect();
match split.len() {
@@ -44,7 +65,7 @@ impl FromStr for JID {
split[0].to_string(),
Some(split[1].to_string()),
)),
- _ => Err(JIDParseError::Malformed),
+ _ => Err(ParseError::Malformed),
}
}
2 => {
@@ -60,16 +81,16 @@ impl FromStr for JID {
split2[0].to_string(),
Some(split2[1].to_string()),
)),
- _ => Err(JIDParseError::Malformed),
+ _ => Err(ParseError::Malformed),
}
}
- _ => Err(JIDParseError::Malformed),
+ _ => Err(ParseError::Malformed),
}
}
}
impl TryFrom<String> for JID {
- type Error = JIDParseError;
+ type Error = ParseError;
fn try_from(value: String) -> Result<Self, Self::Error> {
value.parse()
diff --git a/src/lib.rs b/src/lib.rs
index 7f1433d..d27f0ba 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -27,16 +27,26 @@ mod tests {
// println!("{:?}", jabber.get_sockets().await)
// }
+ // #[tokio::test]
+ // async fn connect() {
+ // Jabber::new(JID::from_str("cel@blos.sm").unwrap(), "password".to_owned())
+ // .unwrap()
+ // .connect()
+ // .await
+ // .unwrap()
+ // .ensure_tls()
+ // .await
+ // .unwrap()
+ // .start_stream()
+ // .await
+ // .unwrap();
+ // }
+
#[tokio::test]
- async fn connect() {
- Jabber::new(JID::from_str("cel@blos.sm").unwrap(), "password".to_owned())
- .connect()
- .await
- .unwrap()
- .ensure_tls()
- .await
+ async fn login() {
+ Jabber::new(JID::from_str("test@blos.sm").unwrap(), "slayed".to_owned())
.unwrap()
- .start_stream()
+ .login()
.await
.unwrap();
}
diff --git a/src/stanza/mod.rs b/src/stanza/mod.rs
index baf29e0..4eaa4c2 100644
--- a/src/stanza/mod.rs
+++ b/src/stanza/mod.rs
@@ -1 +1,2 @@
+pub mod sasl;
pub mod stream;
diff --git a/src/stanza/sasl.rs b/src/stanza/sasl.rs
new file mode 100644
index 0000000..c0e41ab
--- /dev/null
+++ b/src/stanza/sasl.rs
@@ -0,0 +1,32 @@
+use serde::{Deserialize, Serialize};
+
+#[derive(Deserialize, PartialEq, Debug)]
+pub struct Mechanisms {
+ #[serde(rename = "$value")]
+ pub mechanisms: Vec<Mechanism>,
+}
+
+#[derive(Deserialize, PartialEq, Debug)]
+pub struct Mechanism {
+ #[serde(rename = "$text")]
+ pub mechanism: String,
+}
+
+#[derive(Serialize, Debug)]
+#[serde(rename = "auth")]
+pub struct Auth {
+ #[serde(rename = "@xmlns")]
+ pub ns: String,
+ #[serde(rename = "@mechanism")]
+ pub mechanism: String,
+ #[serde(rename = "$text")]
+ pub sasl_data: Option<String>,
+}
+
+#[derive(Deserialize, Debug)]
+pub struct Challenge {
+ #[serde(rename = "@xmlns")]
+ pub ns: String,
+ #[serde(rename = "$text")]
+ pub sasl_data: String,
+}
diff --git a/src/stanza/stream.rs b/src/stanza/stream.rs
index dde741d..4c0addd 100644
--- a/src/stanza/stream.rs
+++ b/src/stanza/stream.rs
@@ -1,5 +1,7 @@
use serde::{Deserialize, Serialize};
+use super::sasl::Mechanisms;
+
#[derive(Serialize, Deserialize)]
#[serde(rename = "stream:stream")]
struct Stream {
@@ -31,6 +33,9 @@ pub enum StreamFeature {
#[serde(rename = "starttls")]
StartTls,
// TODO: other stream features
- Sasl,
+ #[serde(rename = "mechanisms")]
+ Sasl(Mechanisms),
Bind,
+ #[serde(other)]
+ Unknown,
}