diff options
Diffstat (limited to '')
| -rw-r--r-- | src/error.rs | 3 | ||||
| -rw-r--r-- | src/jabber.rs | 16 | ||||
| -rw-r--r-- | src/stanza/sasl.rs | 84 | 
3 files changed, 95 insertions, 8 deletions
| diff --git a/src/error.rs b/src/error.rs index 8ee9077..a1f853b 100644 --- a/src/error.rs +++ b/src/error.rs @@ -2,7 +2,7 @@ use std::str::Utf8Error;  use rsasl::mechname::MechanismNameError; -use crate::jid::ParseError; +use crate::{jid::ParseError, stanza::sasl::Failure};  #[derive(Debug)]  pub enum Error { @@ -27,6 +27,7 @@ pub enum Error {      XML(peanuts::Error),      SASL(SASLError),      JID(ParseError), +    Authentication(Failure),  }  #[derive(Debug)] diff --git a/src/jabber.rs b/src/jabber.rs index 9e7f9d8..599879d 100644 --- a/src/jabber.rs +++ b/src/jabber.rs @@ -5,7 +5,7 @@ use async_recursion::async_recursion;  use peanuts::element::{FromElement, IntoElement};  use peanuts::{Reader, Writer};  use rsasl::prelude::{Mechname, SASLClient, SASLConfig}; -use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufReader, ReadHalf, WriteHalf}; +use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufReader, BufWriter, ReadHalf, WriteHalf};  use tokio::time::timeout;  use tokio_native_tls::native_tls::TlsConnector;  use tracing::{debug, info, instrument, trace}; @@ -102,7 +102,10 @@ where                  ServerResponse::Challenge(challenge) => {                      data = Some((*challenge).as_bytes().to_vec())                  } -                ServerResponse::Success(success) => data = Some((*success).as_bytes().to_vec()), +                ServerResponse::Success(success) => { +                    data = success.clone().map(|success| success.as_bytes().to_vec()) +                } +                ServerResponse::Failure(failure) => return Err(Error::Authentication(failure)),              }              debug!("we went first");          } @@ -121,7 +124,11 @@ where                  // While we aren't finished, receive more data from the other party                  let response = Response::new(str::from_utf8(&sasl_data)?.to_string());                  debug!("response: {:?}", response); +                let stdout = tokio::io::stdout(); +                let mut writer = Writer::new(stdout); +                writer.write_full(&response).await?;                  self.writer.write_full(&response).await?; +                debug!("response written");                  let server_response: ServerResponse = self.reader.read().await?;                  debug!("server_response: {:#?}", server_response); @@ -129,7 +136,10 @@ where                      ServerResponse::Challenge(challenge) => {                          data = Some((*challenge).as_bytes().to_vec())                      } -                    ServerResponse::Success(success) => data = Some((*success).as_bytes().to_vec()), +                    ServerResponse::Success(success) => { +                        data = success.clone().map(|success| success.as_bytes().to_vec()) +                    } +                    ServerResponse::Failure(failure) => return Err(Error::Authentication(failure)),                  }              }          } diff --git a/src/stanza/sasl.rs b/src/stanza/sasl.rs index 6ac4fc9..ec6f63c 100644 --- a/src/stanza/sasl.rs +++ b/src/stanza/sasl.rs @@ -105,10 +105,10 @@ impl FromElement for Challenge {  }  #[derive(Debug)] -pub struct Success(String); +pub struct Success(Option<String>);  impl Deref for Success { -    type Target = str; +    type Target = Option<String>;      fn deref(&self) -> &Self::Target {          &self.0 @@ -120,7 +120,7 @@ impl FromElement for Success {          element.check_name("success")?;          element.check_namespace(XMLNS)?; -        let sasl_data = element.value()?; +        let sasl_data = element.value_opt()?;          Ok(Success(sasl_data))      } @@ -130,10 +130,12 @@ impl FromElement for Success {  pub enum ServerResponse {      Challenge(Challenge),      Success(Success), +    Failure(Failure),  }  impl FromElement for ServerResponse {      fn from_element(element: Element) -> peanuts::element::DeserializeResult<Self> { +        debug!("identification: {:?}", element.identify());          match element.identify() {              (Some(XMLNS), "challenge") => {                  Ok(ServerResponse::Challenge(Challenge::from_element(element)?)) @@ -141,6 +143,9 @@ impl FromElement for ServerResponse {              (Some(XMLNS), "success") => {                  Ok(ServerResponse::Success(Success::from_element(element)?))              } +            (Some(XMLNS), "failure") => { +                Ok(ServerResponse::Failure(Failure::from_element(element)?)) +            }              _ => Err(DeserializeError::UnexpectedElement(element)),          }      } @@ -165,6 +170,77 @@ impl Deref for Response {  impl IntoElement for Response {      fn builder(&self) -> peanuts::element::ElementBuilder { -        Element::builder("reponse", Some(XMLNS)).push_text(self.0.clone()) +        Element::builder("response", Some(XMLNS)).push_text(self.0.clone()) +    } +} + +#[derive(Debug)] +pub struct Failure { +    r#type: Option<FailureType>, +    text: Option<Text>, +} + +impl FromElement for Failure { +    fn from_element(mut element: Element) -> peanuts::element::DeserializeResult<Self> { +        element.check_name("failure")?; +        element.check_namespace(XMLNS)?; + +        let r#type = element.pop_child_opt()?; +        let text = element.pop_child_opt()?; + +        Ok(Failure { r#type, text }) +    } +} + +#[derive(Debug)] +pub enum FailureType { +    Aborted, +    AccountDisabled, +    CredentialsExpired, +    EncryptionRequired, +    IncorrectEncoding, +    InvalidAuthzid, +    InvalidMechanism, +    MalformedRequest, +    MechanismTooWeak, +    NotAuthorized, +    TemporaryAuthFailure, +} + +impl FromElement for FailureType { +    fn from_element(element: Element) -> peanuts::element::DeserializeResult<Self> { +        match element.identify() { +            (Some(XMLNS), "aborted") => Ok(FailureType::Aborted), +            (Some(XMLNS), "account-disabled") => Ok(FailureType::AccountDisabled), +            (Some(XMLNS), "credentials-expired") => Ok(FailureType::CredentialsExpired), +            (Some(XMLNS), "encryption-required") => Ok(FailureType::EncryptionRequired), +            (Some(XMLNS), "incorrect-encoding") => Ok(FailureType::IncorrectEncoding), +            (Some(XMLNS), "invalid-authzid") => Ok(FailureType::InvalidAuthzid), +            (Some(XMLNS), "invalid-mechanism") => Ok(FailureType::InvalidMechanism), +            (Some(XMLNS), "malformed-request") => Ok(FailureType::MalformedRequest), +            (Some(XMLNS), "mechanism-too-weak") => Ok(FailureType::MechanismTooWeak), +            (Some(XMLNS), "not-authorized") => Ok(FailureType::NotAuthorized), +            (Some(XMLNS), "temporary-auth-failure") => Ok(FailureType::TemporaryAuthFailure), +            _ => Err(DeserializeError::UnexpectedElement(element)), +        } +    } +} + +#[derive(Debug)] +pub struct Text { +    lang: Option<String>, +    text: Option<String>, +} + +impl FromElement for Text { +    fn from_element(mut element: Element) -> peanuts::element::DeserializeResult<Self> { +        element.check_name("text")?; +        element.check_namespace(XMLNS)?; + +        let lang = element.attribute_opt_namespaced("lang", peanuts::XML_NS)?; + +        let text = element.pop_value_opt()?; + +        Ok(Text { lang, text })      }  } | 
