aboutsummaryrefslogblamecommitdiffstats
path: root/stanza/src/sasl.rs
blob: 598a91b2b6f6a96d2c3ecd6f1e5d868cf7290c6a (plain) (tree)
1
2
3
4
5
6
7
                                    
 



                                        
                     











                                                                                        
                                                                 



                                                  














































































                                                                                                 
                                   

                        
                                 










                                                                                        
                                             








                              
                     










                                                                                    


                                                                            























                                                                   



                                                                           
                              




                                



































                                                                                           











                                                                                        
                              
                      
                       
            
                                
                    
                                   
                       
                                   
                       
                                  
                      
                               
                   
                                 
                     
                                 
                     
                                  
                     
                              
                  
                                      





















                                                                                             
                       
                 
                       













                                                                                        

     
use std::{fmt::Display, ops::Deref};

use peanuts::{
    element::{FromElement, IntoElement},
    DeserializeError, Element,
};
use thiserror::Error;

pub const XMLNS: &str = "urn:ietf:params:xml:ns:xmpp-sasl";

/// The `<mechanisms/>` element in the SASL namespace ([`XMLNS`]).
///
/// Holds the values of its `<mechanism/>` children as plain strings.
#[derive(Debug, Clone)]
pub struct Mechanisms {
    /// Value of each `<mechanism/>` child element.
    pub mechanisms: Vec<String>,
}

impl FromElement for Mechanisms {
    /// Parses a `<mechanisms/>` element in the SASL namespace.
    ///
    /// Every child is popped as a [`Mechanism`] and unwrapped into its inner
    /// string. Returns an error if the element name or namespace does not
    /// match, or if popping the children fails.
    fn from_element(mut element: Element) -> peanuts::element::DeserializeResult<Self> {
        element.check_name("mechanisms")?;
        element.check_namespace(XMLNS)?;

        let children: Vec<Mechanism> = element.pop_children()?;

        // Unwrap each newtype into the bare mechanism name.
        let mut mechanisms = Vec::with_capacity(children.len());
        for Mechanism(name) in children {
            mechanisms.push(name);
        }

        Ok(Self { mechanisms })
    }
}

impl IntoElement for Mechanisms {
    fn builder(&self) -> peanuts::element::ElementBuilder {
        Element::builder("mechanisms", Some(XMLNS)).push_children(
            self.mechanisms
                .iter()
                .map(|mechanism| Mechanism(mechanism.to_string()))
                .collect(),
        )
    }
}

/// A single SASL mechanism name: the value of one `<mechanism/>` element.
///
/// Derives `Debug` and `Clone` so that, like the other public types in this
/// module, values can be logged and duplicated.
#[derive(Debug, Clone)]
pub struct Mechanism(String);

impl FromElement for Mechanism {
    /// Parses a `<mechanism/>` element in the SASL namespace, taking its value
    /// as the mechanism name.
    ///
    /// Returns an error if the name or namespace does not match, or if the
    /// value cannot be popped.
    fn from_element(mut element: peanuts::Element) -> peanuts::element::DeserializeResult<Self> {
        element.check_name("mechanism")?;
        element.check_namespace(XMLNS)?;

        let name: String = element.pop_value()?;
        Ok(Self(name))
    }
}

impl IntoElement for Mechanism {
    /// Builds a `<mechanism/>` element in the SASL namespace whose text is the
    /// wrapped mechanism name.
    fn builder(&self) -> peanuts::element::ElementBuilder {
        let Self(name) = self;
        Element::builder("mechanism", Some(XMLNS)).push_text(name.clone())
    }
}

/// Lets a [`Mechanism`] be used wherever a `&str` is expected.
impl Deref for Mechanism {
    type Target = str;

    /// Borrows the wrapped mechanism name.
    fn deref(&self) -> &Self::Target {
        self.0.as_str()
    }
}

/// The `<auth/>` element in the SASL namespace ([`XMLNS`]).
#[derive(Debug)]
pub struct Auth {
    /// Value written to the element's `mechanism` attribute.
    pub mechanism: String,
    /// Value pushed as the element's text content.
    pub sasl_data: String,
}

impl IntoElement for Auth {
    /// Builds an `<auth/>` element in the SASL namespace carrying the
    /// `mechanism` attribute and the SASL data as text.
    fn builder(&self) -> peanuts::element::ElementBuilder {
        let mechanism = self.mechanism.clone();
        let sasl_data = self.sasl_data.clone();

        let with_mechanism =
            Element::builder("auth", Some(XMLNS)).push_attribute("mechanism", mechanism);
        with_mechanism.push_text(sasl_data)
    }
}

/// The value of a `<challenge/>` element in the SASL namespace.
#[derive(Debug)]
pub struct Challenge(String);

/// Lets a [`Challenge`] be read directly as a `&str`.
impl Deref for Challenge {
    type Target = str;

    /// Borrows the wrapped challenge data.
    fn deref(&self) -> &Self::Target {
        self.0.as_str()
    }
}

impl FromElement for Challenge {
    /// Parses a `<challenge/>` element in the SASL namespace, reading its
    /// value as the challenge data.
    ///
    /// Returns an error if the name or namespace does not match, or if the
    /// value cannot be read.
    fn from_element(mut element: Element) -> peanuts::element::DeserializeResult<Self> {
        element.check_name("challenge")?;
        element.check_namespace(XMLNS)?;

        let payload: String = element.value()?;
        Ok(Self(payload))
    }
}

/// The optional value of a `<success/>` element in the SASL namespace.
#[derive(Debug)]
pub struct Success(Option<String>);

/// Exposes the optional success data directly as an `Option<String>`.
impl Deref for Success {
    type Target = Option<String>;

    /// Borrows the wrapped optional data.
    fn deref(&self) -> &Self::Target {
        let Self(sasl_data) = self;
        sasl_data
    }
}

impl FromElement for Success {
    /// Parses a `<success/>` element in the SASL namespace, reading its value
    /// if it has one.
    ///
    /// Returns an error if the name or namespace does not match, or if
    /// reading the optional value fails.
    fn from_element(mut element: Element) -> peanuts::element::DeserializeResult<Self> {
        element.check_name("success")?;
        element.check_namespace(XMLNS)?;

        let payload: Option<String> = element.value_opt()?;
        Ok(Self(payload))
    }
}

/// One of the SASL-namespace elements this module can parse as a reply:
/// `<challenge/>`, `<success/>` or `<failure/>`.
#[derive(Debug)]
pub enum ServerResponse {
    /// A parsed `<challenge/>` element.
    Challenge(Challenge),
    /// A parsed `<success/>` element.
    Success(Success),
    /// A parsed `<failure/>` element.
    Failure(Failure),
}

impl FromElement for ServerResponse {
    /// Dispatches on the element's namespace and name and parses it as the
    /// matching variant.
    ///
    /// Any element other than `challenge`, `success` or `failure` in the SASL
    /// namespace is rejected with [`DeserializeError::UnexpectedElement`].
    fn from_element(element: Element) -> peanuts::element::DeserializeResult<Self> {
        match element.identify() {
            (Some(XMLNS), "challenge") => {
                Challenge::from_element(element).map(ServerResponse::Challenge)
            }
            (Some(XMLNS), "success") => Success::from_element(element).map(ServerResponse::Success),
            (Some(XMLNS), "failure") => Failure::from_element(element).map(ServerResponse::Failure),
            _ => Err(DeserializeError::UnexpectedElement(element)),
        }
    }
}

/// The text carried by a `<response/>` element in the SASL namespace.
#[derive(Debug)]
pub struct Response(String);

impl Response {
    /// Wraps `response` as the data of a new [`Response`].
    pub fn new(response: String) -> Self {
        Response(response)
    }
}

/// Lets a [`Response`] be read directly as a `&str`.
impl Deref for Response {
    type Target = str;

    /// Borrows the wrapped response data.
    fn deref(&self) -> &Self::Target {
        self.0.as_str()
    }
}

impl IntoElement for Response {
    /// Builds a `<response/>` element in the SASL namespace whose text is the
    /// wrapped response data.
    fn builder(&self) -> peanuts::element::ElementBuilder {
        let Self(data) = self;
        Element::builder("response", Some(XMLNS)).push_text(data.clone())
    }
}

/// A `<failure/>` element in the SASL namespace: an optional failure
/// condition plus an optional `<text/>` child.
///
/// `Display` is implemented by hand in this file rather than through an
/// `#[error]` attribute.
#[derive(Error, Debug, Clone)]
pub struct Failure {
    /// The failure condition child, if one was present.
    r#type: Option<FailureType>,
    /// The `<text/>` child, if one was present.
    text: Option<Text>,
}

impl Display for Failure {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut had_type = false;
        let mut had_text = false;
        if let Some(r#type) = &self.r#type {
            had_type = true;
            match r#type {
                FailureType::Aborted => f.write_str("aborted"),
                FailureType::AccountDisabled => f.write_str("account disabled"),
                FailureType::CredentialsExpired => f.write_str("credentials expired"),
                FailureType::EncryptionRequired => f.write_str("encryption required"),
                FailureType::IncorrectEncoding => f.write_str("incorrect encoding"),
                FailureType::InvalidAuthzid => f.write_str("invalid authzid"),
                FailureType::InvalidMechanism => f.write_str("invalid mechanism"),
                FailureType::MalformedRequest => f.write_str("malformed request"),
                FailureType::MechanismTooWeak => f.write_str("mechanism too weak"),
                FailureType::NotAuthorized => f.write_str("not authorized"),
                FailureType::TemporaryAuthFailure => f.write_str("temporary auth failure"),
            }?;
        }
        if let Some(text) = &self.text {
            if let Some(text) = &text.text {
                if had_type {
                    f.write_str(": ")?;
                }
                f.write_str(text)?;
                had_text = true;
            }
        }
        if !had_type && !had_text {
            f.write_str("failure")?;
        }
        Ok(())
    }
}

impl FromElement for Failure {
    fn from_element(mut element: Element) -> peanuts::element::DeserializeResult<Self> {
        element.check_name("failure")?;
        element.check_namespace(XMLNS)?;

        let r#type = element.pop_child_opt()?;
        let text = element.pop_child_opt()?;

        Ok(Failure { r#type, text })
    }
}

/// The failure condition carried as a child element of a SASL `<failure/>`.
///
/// Each variant is parsed from the SASL-namespace element whose name is the
/// variant name in kebab-case; the `#[error]` strings are this type's
/// `Display` output.
#[derive(Error, Debug, Clone)]
pub enum FailureType {
    /// Parsed from `<aborted/>`.
    #[error("aborted")]
    Aborted,
    /// Parsed from `<account-disabled/>`.
    #[error("account disabled")]
    AccountDisabled,
    /// Parsed from `<credentials-expired/>`.
    #[error("credentials expired")]
    CredentialsExpired,
    /// Parsed from `<encryption-required/>`.
    #[error("encryption required")]
    EncryptionRequired,
    /// Parsed from `<incorrect-encoding/>`.
    #[error("incorrect encoding")]
    IncorrectEncoding,
    /// Parsed from `<invalid-authzid/>`.
    #[error("invalid authzid")]
    InvalidAuthzid,
    /// Parsed from `<invalid-mechanism/>`.
    #[error("invalid mechanism")]
    InvalidMechanism,
    /// Parsed from `<malformed-request/>`.
    #[error("malformed request")]
    MalformedRequest,
    /// Parsed from `<mechanism-too-weak/>`.
    #[error("mechanism too weak")]
    MechanismTooWeak,
    /// Parsed from `<not-authorized/>`.
    #[error("not authorized")]
    NotAuthorized,
    /// Parsed from `<temporary-auth-failure/>`.
    #[error("temporary auth failure")]
    TemporaryAuthFailure,
}

impl FromElement for FailureType {
    /// Maps a SASL-namespace element to its failure condition by name.
    ///
    /// Only the element's namespace and name are inspected. Anything not in
    /// the table below is rejected with
    /// [`DeserializeError::UnexpectedElement`].
    fn from_element(element: Element) -> peanuts::element::DeserializeResult<Self> {
        let kind = match element.identify() {
            (Some(XMLNS), "aborted") => Self::Aborted,
            (Some(XMLNS), "account-disabled") => Self::AccountDisabled,
            (Some(XMLNS), "credentials-expired") => Self::CredentialsExpired,
            (Some(XMLNS), "encryption-required") => Self::EncryptionRequired,
            (Some(XMLNS), "incorrect-encoding") => Self::IncorrectEncoding,
            (Some(XMLNS), "invalid-authzid") => Self::InvalidAuthzid,
            (Some(XMLNS), "invalid-mechanism") => Self::InvalidMechanism,
            (Some(XMLNS), "malformed-request") => Self::MalformedRequest,
            (Some(XMLNS), "mechanism-too-weak") => Self::MechanismTooWeak,
            (Some(XMLNS), "not-authorized") => Self::NotAuthorized,
            (Some(XMLNS), "temporary-auth-failure") => Self::TemporaryAuthFailure,
            _ => return Err(DeserializeError::UnexpectedElement(element)),
        };

        Ok(kind)
    }
}

/// The `<text/>` child of a SASL `<failure/>` element.
#[derive(Debug, Clone)]
pub struct Text {
    // Parsed from the `lang` attribute in the XML namespace but not read
    // anywhere in this file, hence the `dead_code` allowance.
    #[allow(dead_code)]
    lang: Option<String>,
    /// Value of the element, if it has one.
    text: Option<String>,
}

impl FromElement for Text {
    /// Parses a `<text/>` element in the SASL namespace.
    ///
    /// Reads the optional `lang` attribute from the XML namespace, then pops
    /// the element's optional value. Returns an error if the name or
    /// namespace does not match, or if either read fails.
    fn from_element(mut element: Element) -> peanuts::element::DeserializeResult<Self> {
        element.check_name("text")?;
        element.check_namespace(XMLNS)?;

        let lang: Option<String> = element.attribute_opt_namespaced("lang", peanuts::XML_NS)?;
        let text: Option<String> = element.pop_value_opt()?;

        Ok(Self { lang, text })
    }
}