aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorLibravatar cel 🌸 <cel@bunny.garden>2024-11-29 17:07:16 +0000
committerLibravatar cel 🌸 <cel@bunny.garden>2024-11-29 17:07:16 +0000
commit859a19820d69eca5fca87fc01acad72a6355f97e (patch)
treecf3736c44d93377a16a09d9eaa95851c23aaff80
parentb6593389069903cc4c85e40611296d8a240f718d (diff)
downloadluz-859a19820d69eca5fca87fc01acad72a6355f97e.tar.gz
luz-859a19820d69eca5fca87fc01acad72a6355f97e.tar.bz2
luz-859a19820d69eca5fca87fc01acad72a6355f97e.zip
add sasl failure type
-rw-r--r--Cargo.toml2
-rw-r--r--src/error.rs3
-rw-r--r--src/jabber.rs16
-rw-r--r--src/stanza/sasl.rs84
4 files changed, 96 insertions, 9 deletions
diff --git a/Cargo.toml b/Cargo.toml
index 326e45e..e9c12b5 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -12,7 +12,7 @@ async-trait = "0.1.68"
lazy_static = "1.4.0"
nanoid = "0.4.0"
# TODO: remove unneeded features
-rsasl = { version = "2.0.1", default_features = false, features = ["provider_base64", "plain", "config_builder", "scram-sha-1"] }
+rsasl = { version = "2.0.1", path = "../rsasl", default_features = false, features = ["provider_base64", "plain", "config_builder", "scram-sha-1"] }
tokio = { version = "1.28", features = ["full"] }
tokio-native-tls = "0.3.1"
tracing = "0.1.40"
diff --git a/src/error.rs b/src/error.rs
index 8ee9077..a1f853b 100644
--- a/src/error.rs
+++ b/src/error.rs
@@ -2,7 +2,7 @@ use std::str::Utf8Error;
use rsasl::mechname::MechanismNameError;
-use crate::jid::ParseError;
+use crate::{jid::ParseError, stanza::sasl::Failure};
#[derive(Debug)]
pub enum Error {
@@ -27,6 +27,7 @@ pub enum Error {
XML(peanuts::Error),
SASL(SASLError),
JID(ParseError),
+ Authentication(Failure),
}
#[derive(Debug)]
diff --git a/src/jabber.rs b/src/jabber.rs
index 9e7f9d8..599879d 100644
--- a/src/jabber.rs
+++ b/src/jabber.rs
@@ -5,7 +5,7 @@ use async_recursion::async_recursion;
use peanuts::element::{FromElement, IntoElement};
use peanuts::{Reader, Writer};
use rsasl::prelude::{Mechname, SASLClient, SASLConfig};
-use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufReader, ReadHalf, WriteHalf};
+use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufReader, BufWriter, ReadHalf, WriteHalf};
use tokio::time::timeout;
use tokio_native_tls::native_tls::TlsConnector;
use tracing::{debug, info, instrument, trace};
@@ -102,7 +102,10 @@ where
ServerResponse::Challenge(challenge) => {
data = Some((*challenge).as_bytes().to_vec())
}
- ServerResponse::Success(success) => data = Some((*success).as_bytes().to_vec()),
+ ServerResponse::Success(success) => {
+ data = success.clone().map(|success| success.as_bytes().to_vec())
+ }
+ ServerResponse::Failure(failure) => return Err(Error::Authentication(failure)),
}
debug!("we went first");
}
@@ -121,7 +124,11 @@ where
// While we aren't finished, receive more data from the other party
let response = Response::new(str::from_utf8(&sasl_data)?.to_string());
debug!("response: {:?}", response);
+ let stdout = tokio::io::stdout();
+ let mut writer = Writer::new(stdout);
+ writer.write_full(&response).await?;
self.writer.write_full(&response).await?;
+ debug!("response written");
let server_response: ServerResponse = self.reader.read().await?;
debug!("server_response: {:#?}", server_response);
@@ -129,7 +136,10 @@ where
ServerResponse::Challenge(challenge) => {
data = Some((*challenge).as_bytes().to_vec())
}
- ServerResponse::Success(success) => data = Some((*success).as_bytes().to_vec()),
+ ServerResponse::Success(success) => {
+ data = success.clone().map(|success| success.as_bytes().to_vec())
+ }
+ ServerResponse::Failure(failure) => return Err(Error::Authentication(failure)),
}
}
}
diff --git a/src/stanza/sasl.rs b/src/stanza/sasl.rs
index 6ac4fc9..ec6f63c 100644
--- a/src/stanza/sasl.rs
+++ b/src/stanza/sasl.rs
@@ -105,10 +105,10 @@ impl FromElement for Challenge {
}
#[derive(Debug)]
-pub struct Success(String);
+pub struct Success(Option<String>);
impl Deref for Success {
- type Target = str;
+ type Target = Option<String>;
fn deref(&self) -> &Self::Target {
&self.0
@@ -120,7 +120,7 @@ impl FromElement for Success {
element.check_name("success")?;
element.check_namespace(XMLNS)?;
- let sasl_data = element.value()?;
+ let sasl_data = element.value_opt()?;
Ok(Success(sasl_data))
}
@@ -130,10 +130,12 @@ impl FromElement for Success {
pub enum ServerResponse {
Challenge(Challenge),
Success(Success),
+ Failure(Failure),
}
impl FromElement for ServerResponse {
fn from_element(element: Element) -> peanuts::element::DeserializeResult<Self> {
+ debug!("identification: {:?}", element.identify());
match element.identify() {
(Some(XMLNS), "challenge") => {
Ok(ServerResponse::Challenge(Challenge::from_element(element)?))
@@ -141,6 +143,9 @@ impl FromElement for ServerResponse {
(Some(XMLNS), "success") => {
Ok(ServerResponse::Success(Success::from_element(element)?))
}
+ (Some(XMLNS), "failure") => {
+ Ok(ServerResponse::Failure(Failure::from_element(element)?))
+ }
_ => Err(DeserializeError::UnexpectedElement(element)),
}
}
@@ -165,6 +170,77 @@ impl Deref for Response {
impl IntoElement for Response {
fn builder(&self) -> peanuts::element::ElementBuilder {
- Element::builder("reponse", Some(XMLNS)).push_text(self.0.clone())
+ Element::builder("response", Some(XMLNS)).push_text(self.0.clone())
+ }
+}
+
+#[derive(Debug)]
+pub struct Failure {
+ r#type: Option<FailureType>,
+ text: Option<Text>,
+}
+
+impl FromElement for Failure {
+ fn from_element(mut element: Element) -> peanuts::element::DeserializeResult<Self> {
+ element.check_name("failure")?;
+ element.check_namespace(XMLNS)?;
+
+ let r#type = element.pop_child_opt()?;
+ let text = element.pop_child_opt()?;
+
+ Ok(Failure { r#type, text })
+ }
+}
+
+#[derive(Debug)]
+pub enum FailureType {
+ Aborted,
+ AccountDisabled,
+ CredentialsExpired,
+ EncryptionRequired,
+ IncorrectEncoding,
+ InvalidAuthzid,
+ InvalidMechanism,
+ MalformedRequest,
+ MechanismTooWeak,
+ NotAuthorized,
+ TemporaryAuthFailure,
+}
+
+impl FromElement for FailureType {
+ fn from_element(element: Element) -> peanuts::element::DeserializeResult<Self> {
+ match element.identify() {
+ (Some(XMLNS), "aborted") => Ok(FailureType::Aborted),
+ (Some(XMLNS), "account-disabled") => Ok(FailureType::AccountDisabled),
+ (Some(XMLNS), "credentials-expired") => Ok(FailureType::CredentialsExpired),
+ (Some(XMLNS), "encryption-required") => Ok(FailureType::EncryptionRequired),
+ (Some(XMLNS), "incorrect-encoding") => Ok(FailureType::IncorrectEncoding),
+ (Some(XMLNS), "invalid-authzid") => Ok(FailureType::InvalidAuthzid),
+ (Some(XMLNS), "invalid-mechanism") => Ok(FailureType::InvalidMechanism),
+ (Some(XMLNS), "malformed-request") => Ok(FailureType::MalformedRequest),
+ (Some(XMLNS), "mechanism-too-weak") => Ok(FailureType::MechanismTooWeak),
+ (Some(XMLNS), "not-authorized") => Ok(FailureType::NotAuthorized),
+ (Some(XMLNS), "temporary-auth-failure") => Ok(FailureType::TemporaryAuthFailure),
+ _ => Err(DeserializeError::UnexpectedElement(element)),
+ }
+ }
+}
+
+#[derive(Debug)]
+pub struct Text {
+ lang: Option<String>,
+ text: Option<String>,
+}
+
+impl FromElement for Text {
+ fn from_element(mut element: Element) -> peanuts::element::DeserializeResult<Self> {
+ element.check_name("text")?;
+ element.check_namespace(XMLNS)?;
+
+ let lang = element.attribute_opt_namespaced("lang", peanuts::XML_NS)?;
+
+ let text = element.pop_value_opt()?;
+
+ Ok(Text { lang, text })
}
}