diff options
Diffstat (limited to '')
| -rw-r--r-- | jabber/Cargo.toml | 1 | ||||
| -rw-r--r-- | jabber/src/client.rs | 10 | ||||
| -rw-r--r-- | jabber/src/error.rs | 97 | ||||
| -rw-r--r-- | jabber/src/jabber_stream.rs | 11 | 
4 files changed, 45 insertions, 74 deletions
| diff --git a/jabber/Cargo.toml b/jabber/Cargo.toml index d070838..b6093cf 100644 --- a/jabber/Cargo.toml +++ b/jabber/Cargo.toml @@ -30,6 +30,7 @@ futures = "0.3.31"  take_mut = "0.2.2"  pin-project-lite = "0.2.15"  pin-project = "1.1.7" +thiserror = "2.0.11"  [dev-dependencies]  test-log = { version = "0.2", features = ["trace"] } diff --git a/jabber/src/client.rs b/jabber/src/client.rs index 1662483..de2be08 100644 --- a/jabber/src/client.rs +++ b/jabber/src/client.rs @@ -19,7 +19,8 @@ pub async fn connect_and_login(          None,          jid.localpart.clone().ok_or(Error::NoLocalpart)?,          password.as_ref().to_string(), -    )?; +    ) +    .map_err(|e| Error::SASL(e.into()))?;      let mut conn_state = Connecting::start(&server).await?;      loop {          match conn_state { @@ -108,9 +109,8 @@ pub enum InsecureConnecting {  #[cfg(test)]  mod tests { -    use std::{sync::Arc, time::Duration}; +    use std::time::Duration; -    use futures::{SinkExt, StreamExt};      use jid::JID;      use stanza::{          client::{ @@ -120,7 +120,7 @@ mod tests {          xep_0199::Ping,      };      use test_log::test; -    use tokio::{sync::Mutex, time::sleep}; +    use tokio::time::sleep;      use tracing::info;      use super::connect_and_login; @@ -128,7 +128,7 @@ mod tests {      #[test(tokio::test)]      async fn login() {          let mut jid: JID = "test@blos.sm".try_into().unwrap(); -        let client = connect_and_login(&mut jid, "slayed", &mut "blos.sm".to_string()) +        let _client = connect_and_login(&mut jid, "slayed", &mut "blos.sm".to_string())              .await              .unwrap();          sleep(Duration::from_secs(5)).await diff --git a/jabber/src/error.rs b/jabber/src/error.rs index 902061e..8c27cc9 100644 --- a/jabber/src/error.rs +++ b/jabber/src/error.rs @@ -5,83 +5,50 @@ use rsasl::mechname::MechanismNameError;  use stanza::client::error::Error as ClientError;  use stanza::sasl::Failure;  use stanza::stream::Error as StreamError; +use thiserror::Error;  use tokio::task::JoinError; -#[derive(Debug)] +#[derive(Error, Debug)]  pub enum Error { +    #[error("connection")]      Connection, -    Utf8Decode, +    #[error("utf8 decode: {0}")] +    Utf8Decode(#[from] Utf8Error), +    #[error("negotiation")]      Negotiation, +    #[error("tls required")]      TlsRequired, +    #[error("already connected with tls")]      AlreadyTls, +    // TODO: specify unsupported feature +    #[error("unsupported feature")]      Unsupported, +    #[error("jid missing localpart")]      NoLocalpart, -    AlreadyConnecting, -    StreamClosed, +    #[error("received unexpected element: {0:?}")]      UnexpectedElement(peanuts::Element), -    XML(peanuts::Error), -    Deserialization(peanuts::DeserializeError), -    SASL(SASLError), -    JID(ParseError), -    Authentication(Failure), -    ClientError(ClientError), -    StreamError(StreamError), +    #[error("xml error: {0}")] +    XML(#[from] peanuts::Error), +    #[error("sasl error: {0}")] +    SASL(#[from] SASLError), +    #[error("jid error: {0}")] +    JID(#[from] ParseError), +    #[error("client stanza error: {0}")] +    ClientError(#[from] ClientError), +    #[error("stream error: {0}")] +    StreamError(#[from] StreamError), +    #[error("error missing")]      MissingError, -    Disconnected, -    Connecting, -    JoinError(JoinError), +    #[error("task join error")] +    JoinError(#[from] JoinError),  } -#[derive(Debug)] +#[derive(Error, Debug)]  pub enum SASLError { -    SASL(rsasl::prelude::SASLError), -    MechanismName(MechanismNameError), -} - -impl From<rsasl::prelude::SASLError> for Error { -    fn from(e: rsasl::prelude::SASLError) -> Self { -        Self::SASL(SASLError::SASL(e)) -    } -} - -impl From<JoinError> for Error { -    fn from(e: JoinError) -> Self { -        Self::JoinError(e) -    } -} - -impl From<peanuts::DeserializeError> for Error { -    fn from(e: peanuts::DeserializeError) -> Self { -        Error::Deserialization(e) -    } -} - -impl From<MechanismNameError> for Error { -    fn from(e: MechanismNameError) -> Self { -        Self::SASL(SASLError::MechanismName(e)) -    } -} - -impl From<SASLError> for Error { -    fn from(e: SASLError) -> Self { -        Self::SASL(e) -    } -} - -impl From<Utf8Error> for Error { -    fn from(_e: Utf8Error) -> Self { -        Self::Utf8Decode -    } -} - -impl From<peanuts::Error> for Error { -    fn from(e: peanuts::Error) -> Self { -        Self::XML(e) -    } -} - -impl From<ParseError> for Error { -    fn from(e: ParseError) -> Self { -        Self::JID(e) -    } +    #[error("sasl error: {0}")] +    SASL(#[from] rsasl::prelude::SASLError), +    #[error("mechanism error: {0}")] +    MechanismName(#[from] MechanismNameError), +    #[error("authentication failure: {0}")] +    Authentication(#[from] Failure),  } diff --git a/jabber/src/jabber_stream.rs b/jabber/src/jabber_stream.rs index 6fa92b5..302350d 100644 --- a/jabber/src/jabber_stream.rs +++ b/jabber/src/jabber_stream.rs @@ -133,10 +133,13 @@ where          let sasl = SASLClient::new(sasl_config);          let mut offered_mechs: Vec<&Mechname> = Vec::new();          for mechanism in &mechanisms.mechanisms { -            offered_mechs.push(Mechname::parse(mechanism.as_bytes())?) +            offered_mechs +                .push(Mechname::parse(mechanism.as_bytes()).map_err(|e| Error::SASL(e.into()))?)          }          debug!("{:?}", offered_mechs); -        let mut session = sasl.start_suggested(&offered_mechs)?; +        let mut session = sasl +            .start_suggested(&offered_mechs) +            .map_err(|e| Error::SASL(e.into()))?;          let selected_mechanism = session.get_mechname().as_str().to_owned();          debug!("selected mech: {:?}", selected_mechanism);          let mut data: Option<Vec<u8>>; @@ -174,7 +177,7 @@ where                  ServerResponse::Success(success) => {                      data = success.clone().map(|success| success.as_bytes().to_vec())                  } -                ServerResponse::Failure(failure) => return Err(Error::Authentication(failure)), +                ServerResponse::Failure(failure) => return Err(Error::SASL(failure.into())),              }              debug!("we went first");          } @@ -205,7 +208,7 @@ where                      ServerResponse::Success(success) => {                          data = success.clone().map(|success| success.as_bytes().to_vec())                      } -                    ServerResponse::Failure(failure) => return Err(Error::Authentication(failure)), +                    ServerResponse::Failure(failure) => return Err(Error::SASL(failure.into())),                  }              }          } | 
