aboutsummaryrefslogblamecommitdiffstats
path: root/stanza/src/stream.rs
blob: 732a826f6c1d3e5c2e6e5dc37981d60e6048b009 (plain) (tree)
1
2
3
4
5
6
7
8
9

                      
             

                                                                 
                     
 
                
 
                  
                                    
                                      
                                                      
 
                                                           

                                        

                                          
                




                            
                      





                             
















                                                                                                   



                             


                                                                       
                                                                             




                                                                                    
     

 
                 
               




                                
               
              

                 
               
                    
                 

         
 

                                                                                                                
                                                                                             



                         
                                             
                             
         

     


                     
                               

 



























                                                                     
                               

                                                                                      



                               








                                                                   


     
                       

                       
                     


            




                                                                
                                                              











                                                                                    
                                                  

                                                                     

                                                             


         
 
                              
                  

                           

 











                                                                        


















                                                                                        
use std::fmt::Display;

use jid::JID;
use peanuts::element::{ElementBuilder, FromElement, IntoElement};
use peanuts::Element;
use thiserror::Error;

use crate::bind;

use super::client;
use super::sasl::{self, Mechanisms};
use super::starttls::{self, StartTls};
use super::stream_error::{Error as StreamError, Text};

/// Namespace URI of the stream-level elements handled in this module
/// (`stream`, `features`, `error`).
pub const XMLNS: &str = "http://etherx.jabber.org/streams";

/// The `<stream:stream>` header element.
///
/// The element MUST be qualified by the stream namespace ([`XMLNS`]).
// Planned derive-based (de)serialization, not yet enabled:
// #[derive(XmlSerialize, XmlDeserialize)]
// #[peanuts(xmlns = XMLNS)]
#[derive(Debug)]
pub struct Stream {
    /// The `from` attribute.
    pub from: Option<JID>,
    /// The `to` attribute.
    to: Option<JID>,
    /// The `id` attribute.
    id: Option<String>,
    /// The `version` attribute (client headers built by `new_client` use `"1.0"`).
    version: Option<String>,
    /// The `xml:lang` attribute (local name `lang` in the XML namespace).
    // TODO: lang enum
    lang: Option<String>,
    // Planned element content, not yet modelled:
    // #[peanuts(content)]
    // content: Message,
}

impl FromElement for Stream {
    /// Parses a `<stream:stream>` header element.
    ///
    /// The element must be named `stream` in the streams namespace. `from`, `to`, `id`,
    /// `version` and `xml:lang` are all optional. Any failure while reading one of them is
    /// propagated as a deserialization error.
    fn from_element(mut element: Element) -> std::result::Result<Self, peanuts::DeserializeError> {
        element.check_namespace(XMLNS)?;
        element.check_name("stream")?;

        // Struct-literal fields are evaluated in source order, so attributes are read
        // in the same sequence as before: from, to, id, version, xml:lang.
        Ok(Self {
            from: element.attribute_opt("from")?,
            to: element.attribute_opt("to")?,
            id: element.attribute_opt("id")?,
            version: element.attribute_opt("version")?,
            lang: element.attribute_opt_namespaced("lang", peanuts::XML_NS)?,
        })
    }
}

impl IntoElement for Stream {
    /// Builds the `<stream:stream>` header element.
    ///
    /// The element is placed in [`XMLNS`] and bound to the `stream` prefix. `client::XMLNS` is
    /// declared as the unprefixed (default) namespace. Each optional attribute is passed
    /// through `push_attribute_opt`, so unset fields are left off.
    fn builder(&self) -> ElementBuilder {
        Element::builder("stream", Some(XMLNS.to_string()))
            .push_namespace_declaration_override(Some("stream"), XMLNS)
            .push_namespace_declaration_override(None::<&str>, client::XMLNS)
            .push_attribute_opt("to", self.to.clone())
            .push_attribute_opt("from", self.from.clone())
            .push_attribute_opt("id", self.id.clone())
            .push_attribute_opt("version", self.version.clone())
            // The stream language is `xml:lang`: local name `lang` in the XML namespace.
            // This matches where `FromElement for Stream` reads it back from.
            .push_attribute_opt_namespaced(peanuts::XML_NS, "lang", self.lang.clone())
    }
}

impl Stream {
    /// Creates a stream header from explicit attribute values.
    ///
    /// Every attribute is optional; `None` leaves that attribute unset.
    pub fn new(
        from: Option<JID>,
        to: Option<JID>,
        id: Option<String>,
        version: Option<String>,
        lang: Option<String>,
    ) -> Self {
        Self {
            from,
            to,
            id,
            version,
            lang,
        }
    }

    /// Creates the initial stream header a client sends, with `version` fixed to `"1.0"`.
    ///
    /// For initial stream headers, the initiating entity SHOULD include the 'xml:lang' attribute.
    /// For privacy, it is better to not set `from` when sending a client stanza over an unencrypted connection.
    pub fn new_client(from: Option<JID>, to: JID, id: Option<String>, lang: String) -> Self {
        Self::new(from, Some(to), id, Some("1.0".to_string()), Some(lang))
    }
}

/// The `<stream:features/>` element: the set of stream features offered for negotiation.
#[derive(Debug)]
pub struct Features {
    /// One entry per child element; unrecognised children are kept as `Feature::Unknown`.
    pub features: Vec<Feature>,
}

impl Features {
    /// Selects the next feature to negotiate, consuming the advertised set.
    ///
    /// Kinds are tried in fixed priority order: STARTTLS, then SASL mechanisms, then
    /// resource binding. For the highest-priority kind that is present, the first
    /// advertised entry of that kind is returned. It is moved out rather than cloned,
    /// since `self` is consumed.
    ///
    /// Returns `None` when none of those kinds were advertised. `Feature::Unknown` is never
    /// selected.
    pub fn negotiate(self) -> Option<Feature> {
        // Predicates in priority order; the earliest predicate with a match wins.
        let priority: [fn(&Feature) -> bool; 3] = [
            |feature| matches!(feature, Feature::StartTls(_)),
            |feature| matches!(feature, Feature::Sasl(_)),
            |feature| matches!(feature, Feature::Bind),
        ];

        let mut features = self.features;
        priority.iter().find_map(|wanted| {
            let index = features.iter().position(|feature| wanted(feature))?;
            // The remaining features are dropped, so their order does not matter and
            // O(1) `swap_remove` is fine.
            Some(features.swap_remove(index))
        })
    }
}

impl IntoElement for Features {
    /// Builds `<features/>` in the streams namespace, with one child per advertised feature.
    fn builder(&self) -> ElementBuilder {
        let children = self.features.clone();
        let features = Element::builder("features", Some(XMLNS));
        features.push_children(children)
    }
}

impl FromElement for Features {
    /// Parses `<stream:features/>`, converting every child element into a [`Feature`].
    ///
    /// Fails if the element is not `features` in the streams namespace, or if a child
    /// cannot be converted.
    fn from_element(
        mut element: Element,
    ) -> std::result::Result<Features, peanuts::DeserializeError> {
        element.check_namespace(XMLNS)?;
        element.check_name("features")?;

        Ok(Features {
            features: element.children()?,
        })
    }
}

/// A single child of `<stream:features/>`.
#[derive(Debug, Clone)]
pub enum Feature {
    /// `<starttls/>` in `starttls::XMLNS`.
    StartTls(StartTls),
    /// `<mechanisms/>` in `sasl::XMLNS`.
    Sasl(Mechanisms),
    /// `<bind/>` in `bind::XMLNS`; the element's contents are not retained.
    Bind,
    /// Any other feature element; its contents are discarded during parsing.
    Unknown,
}

impl IntoElement for Feature {
    fn builder(&self) -> ElementBuilder {
        match self {
            Feature::StartTls(start_tls) => start_tls.builder(),
            Feature::Sasl(mechanisms) => mechanisms.builder(),
            Feature::Bind => todo!(),
            Feature::Unknown => todo!(),
        }
    }
}

impl FromElement for Feature {
    /// Classifies one feature element by its namespace and local name.
    ///
    /// STARTTLS and SASL mechanisms are parsed in full, and `<bind/>` is recognised without
    /// inspecting its contents. Anything else becomes [`Feature::Unknown`] instead of an
    /// error, so unrecognised features do not abort parsing.
    fn from_element(element: Element) -> peanuts::element::DeserializeResult<Self> {
        match element.identify() {
            (Some(starttls::XMLNS), "starttls") => {
                StartTls::from_element(element).map(Feature::StartTls)
            }
            (Some(sasl::XMLNS), "mechanisms") => {
                Mechanisms::from_element(element).map(Feature::Sasl)
            }
            (Some(bind::XMLNS), "bind") => Ok(Self::Bind),
            _ => Ok(Self::Unknown),
        }
    }
}

/// A `<stream:error/>` element: an error condition plus optional descriptive text.
///
/// `Display` is implemented by hand (there is no `#[error(...)]` attribute).
#[derive(Error, Debug, Clone)]
pub struct Error {
    /// The error-condition child element.
    pub error: StreamError,
    /// Optional text child; its content, if present, is appended to the `Display` output.
    pub text: Option<Text>,
}

impl Display for Error {
    /// Writes the condition, followed by `": <text>"` when descriptive text is present.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Both the `Text` child and its inner text are optional.
        let description = self.text.as_ref().and_then(|text| text.text.as_ref());
        match description {
            Some(description) => write!(f, "{}: {}", self.error, description),
            None => write!(f, "{}", self.error),
        }
    }
}

impl FromElement for Error {
    fn from_element(mut element: Element) -> peanuts::element::DeserializeResult<Self> {
        element.check_name("error")?;
        element.check_namespace(XMLNS)?;

        let error = element.pop_child_one()?;
        let text = element.pop_child_opt()?;

        Ok(Error { error, text })
    }
}

impl IntoElement for Error {
    /// Builds `<error/>` in the streams namespace: the condition first, then the text if set.
    fn builder(&self) -> ElementBuilder {
        let condition = self.error.clone();
        let description = self.text.clone();
        Element::builder("error", Some(XMLNS))
            .push_child(condition)
            .push_child_opt(description)
    }
}