aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorLibravatar cel 🌸 <cel@bunny.garden>2024-11-29 02:11:02 +0000
committerLibravatar cel 🌸 <cel@bunny.garden>2024-11-29 02:11:02 +0000
commitb6593389069903cc4c85e40611296d8a240f718d (patch)
treeae4df92ea45cce5e5b904041a925263e8d629274
parent2dcbc9e1f4339993dd47b2659770a9cf4855b02d (diff)
downloadluz-b6593389069903cc4c85e40611296d8a240f718d.tar.gz
luz-b6593389069903cc4c85e40611296d8a240f718d.tar.bz2
luz-b6593389069903cc4c85e40611296d8a240f718d.zip
implement sasl kinda
-rw-r--r--Cargo.toml2
-rw-r--r--src/connection.rs35
-rw-r--r--src/error.rs2
-rw-r--r--src/jabber.rs220
-rw-r--r--src/stanza/sasl.rs169
-rw-r--r--src/stanza/stream.rs42
6 files changed, 422 insertions, 48 deletions
diff --git a/Cargo.toml b/Cargo.toml
index f136e90..326e45e 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -12,7 +12,7 @@ async-trait = "0.1.68"
lazy_static = "1.4.0"
nanoid = "0.4.0"
# TODO: remove unneeded features
-rsasl = { version = "2", default_features = true, features = ["provider_base64", "plain", "config_builder"] }
+rsasl = { version = "2.0.1", default_features = false, features = ["provider_base64", "plain", "config_builder", "scram-sha-1"] }
tokio = { version = "1.28", features = ["full"] }
tokio-native-tls = "0.3.1"
tracing = "0.1.40"
diff --git a/src/connection.rs b/src/connection.rs
index 65e9383..9e485d3 100644
--- a/src/connection.rs
+++ b/src/connection.rs
@@ -1,16 +1,18 @@
use std::net::{IpAddr, SocketAddr};
use std::str;
use std::str::FromStr;
+use std::sync::Arc;
+use rsasl::config::SASLConfig;
use tokio::net::TcpStream;
use tokio_native_tls::native_tls::TlsConnector;
// TODO: use rustls
use tokio_native_tls::TlsStream;
use tracing::{debug, info, instrument, trace};
-use crate::Error;
use crate::Jabber;
use crate::Result;
+use crate::{Error, JID};
pub type Tls = TlsStream<TcpStream>;
pub type Unencrypted = TcpStream;
@@ -37,15 +39,20 @@ impl Connection {
}
}
- // pub async fn connect_user<J: TryInto<JID>>(jid: J, password: String) -> Result<Self> {
- // let server = jid.domainpart.clone();
- // let auth = SASLConfig::with_credentials(None, jid.localpart.clone().unwrap(), password)?;
- // println!("auth: {:?}", auth);
- // Self::connect(&server, jid.try_into()?, Some(auth)).await
- // }
+ pub async fn connect_user(jid: impl AsRef<str>, password: String) -> Result<Self> {
+ let jid: JID = JID::from_str(jid.as_ref())?;
+ let server = jid.domainpart.clone();
+ let auth = SASLConfig::with_credentials(None, jid.localpart.clone().unwrap(), password)?;
+ println!("auth: {:?}", auth);
+ Self::connect(&server, Some(jid), Some(auth)).await
+ }
#[instrument]
- pub async fn connect(server: &str) -> Result<Self> {
+ pub async fn connect(
+ server: &str,
+ jid: Option<JID>,
+ auth: Option<Arc<SASLConfig>>,
+ ) -> Result<Self> {
info!("connecting to {}", server);
let sockets = Self::get_sockets(&server).await;
debug!("discovered sockets: {:?}", sockets);
@@ -58,8 +65,8 @@ impl Connection {
return Ok(Self::Encrypted(Jabber::new(
readhalf,
writehalf,
- None,
- None,
+ jid,
+ auth,
server.to_owned(),
)));
}
@@ -71,8 +78,8 @@ impl Connection {
return Ok(Self::Unencrypted(Jabber::new(
readhalf,
writehalf,
- None,
- None,
+ jid,
+ auth,
server.to_owned(),
)));
}
@@ -181,12 +188,12 @@ mod tests {
#[test(tokio::test)]
async fn connect() {
- Connection::connect("blos.sm").await.unwrap();
+ Connection::connect("blos.sm", None, None).await.unwrap();
}
#[test(tokio::test)]
async fn test_tls() {
- Connection::connect("blos.sm")
+ Connection::connect("blos.sm", None, None)
.await
.unwrap()
.ensure_tls()
diff --git a/src/error.rs b/src/error.rs
index c7c867c..8ee9077 100644
--- a/src/error.rs
+++ b/src/error.rs
@@ -19,6 +19,8 @@ pub enum Error {
IDMismatch,
BindError,
ParseError,
+ Negotiation,
+ TlsRequired,
UnexpectedEnd,
UnexpectedElement,
UnexpectedText,
diff --git a/src/jabber.rs b/src/jabber.rs
index a56c65c..9e7f9d8 100644
--- a/src/jabber.rs
+++ b/src/jabber.rs
@@ -1,26 +1,26 @@
use std::str;
use std::sync::Arc;
+use async_recursion::async_recursion;
use peanuts::element::{FromElement, IntoElement};
use peanuts::{Reader, Writer};
-use rsasl::prelude::SASLConfig;
+use rsasl::prelude::{Mechname, SASLClient, SASLConfig};
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufReader, ReadHalf, WriteHalf};
+use tokio::time::timeout;
use tokio_native_tls::native_tls::TlsConnector;
use tracing::{debug, info, instrument, trace};
use trust_dns_resolver::proto::rr::domain::IntoLabel;
use crate::connection::{Tls, Unencrypted};
use crate::error::Error;
+use crate::stanza::sasl::{Auth, Challenge, Mechanisms, Response, ServerResponse};
use crate::stanza::starttls::{Proceed, StartTls};
-use crate::stanza::stream::{Features, Stream};
+use crate::stanza::stream::{Feature, Features, Stream};
use crate::stanza::XML_VERSION;
-use crate::Result;
use crate::JID;
+use crate::{Connection, Result};
-pub struct Jabber<S>
-where
- S: AsyncRead + AsyncWrite + Unpin,
-{
+pub struct Jabber<S> {
reader: Reader<ReadHalf<S>>,
writer: Writer<WriteHalf<S>>,
jid: Option<JID>,
@@ -56,7 +56,89 @@ where
S: AsyncRead + AsyncWrite + Unpin + Send,
Jabber<S>: std::fmt::Debug,
{
- // pub async fn negotiate(self) -> Result<Jabber<S>> {}
+ pub async fn sasl(
+ &mut self,
+ mechanisms: Mechanisms,
+ sasl_config: Arc<SASLConfig>,
+ ) -> Result<()> {
+ let sasl = SASLClient::new(sasl_config);
+ let mut offered_mechs: Vec<&Mechname> = Vec::new();
+ for mechanism in &mechanisms.mechanisms {
+ offered_mechs.push(Mechname::parse(mechanism.as_bytes())?)
+ }
+ debug!("{:?}", offered_mechs);
+ let mut session = sasl.start_suggested(&offered_mechs)?;
+ let selected_mechanism = session.get_mechname().as_str().to_owned();
+ debug!("selected mech: {:?}", selected_mechanism);
+ let mut data: Option<Vec<u8>> = None;
+
+ if !session.are_we_first() {
+ // if not first mention the mechanism then get challenge data
+ // mention mechanism
+ let auth = Auth {
+ mechanism: selected_mechanism,
+ sasl_data: "=".to_string(),
+ };
+ self.writer.write_full(&auth).await?;
+ // get challenge data
+ let challenge: Challenge = self.reader.read().await?;
+ debug!("challenge: {:?}", challenge);
+ data = Some((*challenge).as_bytes().to_vec());
+ debug!("we didn't go first");
+ } else {
+ // if first, mention mechanism and send data
+ let mut sasl_data = Vec::new();
+ session.step64(None, &mut sasl_data).unwrap();
+ let auth = Auth {
+ mechanism: selected_mechanism,
+ sasl_data: str::from_utf8(&sasl_data)?.to_string(),
+ };
+ debug!("{:?}", auth);
+ self.writer.write_full(&auth).await?;
+
+ let server_response: ServerResponse = self.reader.read().await?;
+ debug!("server_response: {:#?}", server_response);
+ match server_response {
+ ServerResponse::Challenge(challenge) => {
+ data = Some((*challenge).as_bytes().to_vec())
+ }
+ ServerResponse::Success(success) => data = Some((*success).as_bytes().to_vec()),
+ }
+ debug!("we went first");
+ }
+
+ // stepping the authentication exchange to completion
+ if data != None {
+ debug!("data: {:?}", data);
+ let mut sasl_data = Vec::new();
+ while {
+ // decide if need to send more data over
+ let state = session
+ .step64(data.as_deref(), &mut sasl_data)
+ .expect("step errored!");
+ state.is_running()
+ } {
+ // While we aren't finished, receive more data from the other party
+ let response = Response::new(str::from_utf8(&sasl_data)?.to_string());
+ debug!("response: {:?}", response);
+ self.writer.write_full(&response).await?;
+
+ let server_response: ServerResponse = self.reader.read().await?;
+ debug!("server_response: {:#?}", server_response);
+ match server_response {
+ ServerResponse::Challenge(challenge) => {
+ data = Some((*challenge).as_bytes().to_vec())
+ }
+ ServerResponse::Success(success) => data = Some((*success).as_bytes().to_vec()),
+ }
+ }
+ }
+ Ok(())
+ }
+
+ pub async fn bind(&mut self) -> Result<()> {
+ todo!()
+ }
#[instrument]
pub async fn start_stream(&mut self) -> Result<()> {
@@ -76,6 +158,8 @@ where
let decl = self.reader.read_prolog().await?;
// receive stream element and validate
+ let text = str::from_utf8(self.reader.buffer.data()).unwrap();
+ debug!("data: {}", text);
let stream: Stream = self.reader.read_start().await?;
debug!("got stream: {:?}", stream);
if let Some(from) = stream.from {
@@ -98,6 +182,87 @@ where
}
impl Jabber<Unencrypted> {
+ pub async fn negotiate<S: AsyncRead + AsyncWrite + Unpin>(mut self) -> Result<Jabber<Tls>> {
+ self.start_stream().await?;
+ // TODO: timeout
+ let features = self.get_features().await?.features;
+ if let Some(Feature::StartTls(_)) = features
+ .iter()
+ .find(|feature| matches!(feature, Feature::StartTls(_s)))
+ {
+ let jabber = self.starttls().await?;
+ let jabber = jabber.negotiate().await?;
+ return Ok(jabber);
+ } else {
+ // TODO: better error
+ return Err(Error::TlsRequired);
+ }
+ }
+
+ #[async_recursion]
+ pub async fn negotiate_tls_optional(mut self) -> Result<Connection> {
+ self.start_stream().await?;
+ // TODO: timeout
+ let features = self.get_features().await?.features;
+ if let Some(Feature::StartTls(_)) = features
+ .iter()
+ .find(|feature| matches!(feature, Feature::StartTls(_s)))
+ {
+ let jabber = self.starttls().await?;
+ let jabber = jabber.negotiate().await?;
+ return Ok(Connection::Encrypted(jabber));
+ } else if let (Some(sasl_config), Some(Feature::Sasl(mechanisms))) = (
+ self.auth.clone(),
+ features
+ .iter()
+ .find(|feature| matches!(feature, Feature::Sasl(_))),
+ ) {
+ self.sasl(mechanisms.clone(), sasl_config).await?;
+ let jabber = self.negotiate_tls_optional().await?;
+ Ok(jabber)
+ } else if let Some(Feature::Bind) = features
+ .iter()
+ .find(|feature| matches!(feature, Feature::Bind))
+ {
+ self.bind().await?;
+ Ok(Connection::Unencrypted(self))
+ } else {
+ // TODO: better error
+ return Err(Error::Negotiation);
+ }
+ }
+}
+
+impl Jabber<Tls> {
+ #[async_recursion]
+ pub async fn negotiate(mut self) -> Result<Jabber<Tls>> {
+ self.start_stream().await?;
+ let features = self.get_features().await?.features;
+
+ if let (Some(sasl_config), Some(Feature::Sasl(mechanisms))) = (
+ self.auth.clone(),
+ features
+ .iter()
+ .find(|feature| matches!(feature, Feature::Sasl(_))),
+ ) {
+ // TODO: avoid clone
+ self.sasl(mechanisms.clone(), sasl_config).await?;
+ let jabber = self.negotiate().await?;
+ Ok(jabber)
+ } else if let Some(Feature::Bind) = features
+ .iter()
+ .find(|feature| matches!(feature, Feature::Bind))
+ {
+ self.bind().await?;
+ Ok(self)
+ } else {
+ // TODO: better error
+ return Err(Error::Negotiation);
+ }
+ }
+}
+
+impl Jabber<Unencrypted> {
pub async fn starttls(mut self) -> Result<Jabber<Tls>> {
self.writer
.write_full(&StartTls { required: false })
@@ -155,10 +320,47 @@ mod tests {
#[test(tokio::test)]
async fn start_stream() {
- let connection = Connection::connect("blos.sm").await.unwrap();
+ let connection = Connection::connect("blos.sm", None, None).await.unwrap();
match connection {
Connection::Encrypted(mut c) => c.start_stream().await.unwrap(),
Connection::Unencrypted(mut c) => c.start_stream().await.unwrap(),
}
}
+
+ #[test(tokio::test)]
+ async fn sasl() {
+ let mut jabber = Connection::connect_user("test@blos.sm", "slayed".to_string())
+ .await
+ .unwrap()
+ .ensure_tls()
+ .await
+ .unwrap();
+ let text = str::from_utf8(jabber.reader.buffer.data()).unwrap();
+ println!("data: {}", text);
+ jabber.start_stream().await.unwrap();
+
+ let text = str::from_utf8(jabber.reader.buffer.data()).unwrap();
+ println!("data: {}", text);
+ jabber.reader.read_buf().await.unwrap();
+ let text = str::from_utf8(jabber.reader.buffer.data()).unwrap();
+ println!("data: {}", text);
+
+ let features = jabber.get_features().await.unwrap();
+ let (sasl_config, feature) = (
+ jabber.auth.clone().unwrap(),
+ features
+ .features
+ .iter()
+ .find(|feature| matches!(feature, Feature::Sasl(_)))
+ .unwrap(),
+ );
+ match feature {
+ Feature::StartTls(_start_tls) => todo!(),
+ Feature::Sasl(mechanisms) => {
+ jabber.sasl(mechanisms.clone(), sasl_config).await.unwrap();
+ }
+ Feature::Bind => todo!(),
+ Feature::Unknown => todo!(),
+ }
+ }
}
diff --git a/src/stanza/sasl.rs b/src/stanza/sasl.rs
index 8b13789..6ac4fc9 100644
--- a/src/stanza/sasl.rs
+++ b/src/stanza/sasl.rs
@@ -1 +1,170 @@
+use std::ops::Deref;
+use peanuts::{
+ element::{FromElement, IntoElement},
+ DeserializeError, Element,
+};
+use tracing::debug;
+
+pub const XMLNS: &str = "urn:ietf:params:xml:ns:xmpp-sasl";
+
+#[derive(Debug, Clone)]
+pub struct Mechanisms {
+ pub mechanisms: Vec<String>,
+}
+
+impl FromElement for Mechanisms {
+ fn from_element(mut element: Element) -> peanuts::element::DeserializeResult<Self> {
+ element.check_name("mechanisms")?;
+ element.check_namespace(XMLNS)?;
+ debug!("getting mechanisms");
+ let mechanisms: Vec<Mechanism> = element.pop_children()?;
+ debug!("gottting mechanisms");
+ let mechanisms = mechanisms
+ .into_iter()
+ .map(|Mechanism(mechanism)| mechanism)
+ .collect();
+ debug!("gottting mechanisms");
+
+ Ok(Mechanisms { mechanisms })
+ }
+}
+
+impl IntoElement for Mechanisms {
+ fn builder(&self) -> peanuts::element::ElementBuilder {
+ Element::builder("mechanisms", Some(XMLNS)).push_children(
+ self.mechanisms
+ .iter()
+ .map(|mechanism| Mechanism(mechanism.to_string()))
+ .collect(),
+ )
+ }
+}
+
+pub struct Mechanism(String);
+
+impl FromElement for Mechanism {
+ fn from_element(mut element: peanuts::Element) -> peanuts::element::DeserializeResult<Self> {
+ element.check_name("mechanism")?;
+ element.check_namespace(XMLNS)?;
+
+ let mechanism = element.pop_value()?;
+
+ Ok(Mechanism(mechanism))
+ }
+}
+
+impl IntoElement for Mechanism {
+ fn builder(&self) -> peanuts::element::ElementBuilder {
+ Element::builder("mechanism", Some(XMLNS)).push_text(self.0.clone())
+ }
+}
+
+impl Deref for Mechanism {
+ type Target = str;
+
+ fn deref(&self) -> &Self::Target {
+ &self.0
+ }
+}
+
+#[derive(Debug)]
+pub struct Auth {
+ pub mechanism: String,
+ pub sasl_data: String,
+}
+
+impl IntoElement for Auth {
+ fn builder(&self) -> peanuts::element::ElementBuilder {
+ Element::builder("auth", Some(XMLNS))
+ .push_attribute("mechanism", self.mechanism.clone())
+ .push_text(self.sasl_data.clone())
+ }
+}
+
+#[derive(Debug)]
+pub struct Challenge(String);
+
+impl Deref for Challenge {
+ type Target = str;
+
+ fn deref(&self) -> &Self::Target {
+ &self.0
+ }
+}
+
+impl FromElement for Challenge {
+ fn from_element(mut element: Element) -> peanuts::element::DeserializeResult<Self> {
+ element.check_name("challenge")?;
+ element.check_namespace(XMLNS)?;
+
+ let sasl_data = element.value()?;
+
+ Ok(Challenge(sasl_data))
+ }
+}
+
+#[derive(Debug)]
+pub struct Success(String);
+
+impl Deref for Success {
+ type Target = str;
+
+ fn deref(&self) -> &Self::Target {
+ &self.0
+ }
+}
+
+impl FromElement for Success {
+ fn from_element(mut element: Element) -> peanuts::element::DeserializeResult<Self> {
+ element.check_name("success")?;
+ element.check_namespace(XMLNS)?;
+
+ let sasl_data = element.value()?;
+
+ Ok(Success(sasl_data))
+ }
+}
+
+#[derive(Debug)]
+pub enum ServerResponse {
+ Challenge(Challenge),
+ Success(Success),
+}
+
+impl FromElement for ServerResponse {
+ fn from_element(element: Element) -> peanuts::element::DeserializeResult<Self> {
+ match element.identify() {
+ (Some(XMLNS), "challenge") => {
+ Ok(ServerResponse::Challenge(Challenge::from_element(element)?))
+ }
+ (Some(XMLNS), "success") => {
+ Ok(ServerResponse::Success(Success::from_element(element)?))
+ }
+ _ => Err(DeserializeError::UnexpectedElement(element)),
+ }
+ }
+}
+
+#[derive(Debug)]
+pub struct Response(String);
+
+impl Response {
+ pub fn new(response: String) -> Self {
+ Self(response)
+ }
+}
+
+impl Deref for Response {
+ type Target = str;
+
+ fn deref(&self) -> &Self::Target {
+ &self.0
+ }
+}
+
+impl IntoElement for Response {
+ fn builder(&self) -> peanuts::element::ElementBuilder {
+ Element::builder("reponse", Some(XMLNS)).push_text(self.0.clone())
+ }
+}
diff --git a/src/stanza/stream.rs b/src/stanza/stream.rs
index 40f6ba0..fecace5 100644
--- a/src/stanza/stream.rs
+++ b/src/stanza/stream.rs
@@ -3,9 +3,11 @@ use std::collections::{HashMap, HashSet};
use peanuts::element::{Content, ElementBuilder, FromElement, IntoElement, NamespaceDeclaration};
use peanuts::XML_NS;
use peanuts::{element::Name, Element};
+use tracing::debug;
use crate::{Error, JID};
+use super::sasl::{self, Mechanisms};
use super::starttls::{self, StartTls};
pub const XMLNS: &str = "http://etherx.jabber.org/streams";
@@ -92,32 +94,12 @@ impl<'s> Stream {
#[derive(Debug)]
pub struct Features {
- features: Vec<Feature>,
+ pub features: Vec<Feature>,
}
impl IntoElement for Features {
fn builder(&self) -> ElementBuilder {
Element::builder("features", Some(XMLNS)).push_children(self.features.clone())
- // let mut content = Vec::new();
- // for feature in &self.features {
- // match feature {
- // Feature::StartTls(start_tls) => {
- // content.push(Content::Element(start_tls.into_element()))
- // }
- // Feature::Sasl => {}
- // Feature::Bind => {}
- // Feature::Unknown => {}
- // }
- // }
- // Element {
- // name: Name {
- // namespace: Some(XMLNS.to_string()),
- // local_name: "features".to_string(),
- // },
- // namespace_declaration_overrides: HashSet::new(),
- // attributes: HashMap::new(),
- // content,
- // }
}
}
@@ -128,7 +110,9 @@ impl FromElement for Features {
element.check_namespace(XMLNS)?;
element.check_name("features")?;
+ debug!("got features stanza");
let features = element.children()?;
+ debug!("got features period");
Ok(Features { features })
}
@@ -137,7 +121,7 @@ impl FromElement for Features {
#[derive(Debug, Clone)]
pub enum Feature {
StartTls(StartTls),
- Sasl,
+ Sasl(Mechanisms),
Bind,
Unknown,
}
@@ -146,7 +130,7 @@ impl IntoElement for Feature {
fn builder(&self) -> ElementBuilder {
match self {
Feature::StartTls(start_tls) => start_tls.builder(),
- Feature::Sasl => todo!(),
+ Feature::Sasl(mechanisms) => mechanisms.builder(),
Feature::Bind => todo!(),
Feature::Unknown => todo!(),
}
@@ -155,11 +139,21 @@ impl IntoElement for Feature {
impl FromElement for Feature {
fn from_element(element: Element) -> peanuts::element::DeserializeResult<Self> {
+ let identity = element.identify();
+ debug!("identity: {:?}", identity);
match element.identify() {
(Some(starttls::XMLNS), "starttls") => {
+ debug!("identified starttls");
Ok(Feature::StartTls(StartTls::from_element(element)?))
}
- _ => Ok(Feature::Unknown),
+ (Some(sasl::XMLNS), "mechanisms") => {
+ debug!("identified mechanisms");
+ Ok(Feature::Sasl(Mechanisms::from_element(element)?))
+ }
+ _ => {
+ debug!("identified unknown feature");
+ Ok(Feature::Unknown)
+ }
}
}
}