aboutsummaryrefslogblamecommitdiffstats
path: root/jid/src/lib.rs
blob: 3b40094cabec1961a4b8a605f748773affbf9048 (plain) (tree)
1
2
3
4
5
6
7
                                                                             
 

                           
 
                                                             
                                                                            










































                                                                

 








                                                                        

































                                                                            




                                                  



              
















                                                                                                    




                                                                            


     





                                                                                                    
 







                                                                            
 





                                                                                                    
 












                                                         
 
                       


                           
                         

 


                                                                        
                                         

                                                                            
                                                                                   





                          
                       
                     
          


                      












                                                                                        




































                                                                                             
          




                                     



                                                                         

         
 



                                                      


         





















                                                                  

         

 















                                                     
                      
                          



                                                      
                                        








                                                                        
                                                                   














                                                                      
                                                                   

                 
                                                           



         





                                        
                              
                            





                                                             







                                                           















































                                                                             




























                                                                      
use std::{borrow::Cow, error::Error, fmt::Display, ops::Deref, str::FromStr};

// #[cfg(feature = "sqlx")]
// use sqlx::Sqlite;

/// A JID of the form `[localpart@]domainpart[/resourcepart]`.
///
/// The variant records whether a resourcepart is present: [`JID::Full`] carries one,
/// [`JID::Bare`] does not.
#[derive(PartialEq, Debug, Clone, Eq, Hash, PartialOrd, Ord)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum JID {
    /// A JID that includes a resourcepart.
    Full(FullJID),
    /// A JID without a resourcepart.
    Bare(BareJID),
}

impl JID {
    /// Returns the resourcepart when this is a full JID, or `None` for a bare JID.
    pub fn resourcepart(&self) -> Option<&String> {
        if let JID::Full(full) = self {
            Some(&full.resourcepart)
        } else {
            None
        }
    }
}

impl From<FullJID> for JID {
    fn from(value: FullJID) -> Self {
        Self::Full(value)
    }
}

impl From<BareJID> for JID {
    fn from(value: BareJID) -> Self {
        Self::Bare(value)
    }
}

/// Lets a [`JID`] be used wherever a `&BareJID` is expected, whichever variant it is.
impl Deref for JID {
    type Target = BareJID;

    fn deref(&self) -> &Self::Target {
        match self {
            JID::Bare(bare) => bare,
            // A full JID dereferences to the bare JID it is built on.
            JID::Full(full) => &full.bare_jid,
        }
    }
}

/// Lets a [`FullJID`] be used wherever a `&BareJID` is expected.
impl Deref for FullJID {
    type Target = BareJID;

    fn deref(&self) -> &Self::Target {
        // Same reference as `&self.bare_jid`; goes through the inherent accessor.
        self.as_bare()
    }
}

impl<'a> Into<Cow<'a, str>> for &'a JID {
    fn into(self) -> Cow<'a, str> {
        let a = self.to_string();
        Cow::Owned(a)
    }
}

/// Formats the JID by delegating to the `Display` impl of the wrapped variant.
impl Display for JID {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            JID::Bare(bare) => Display::fmt(bare, f),
            JID::Full(full) => Display::fmt(full, f),
        }
    }
}

/// A JID that includes a resourcepart, displayed as `[localpart@]domainpart/resourcepart`.
#[derive(PartialEq, Debug, Clone, Eq, Hash, PartialOrd, Ord)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct FullJID {
    /// The `[localpart@]domainpart` portion.
    pub bare_jid: BareJID,
    /// The portion written after the `/` separator.
    pub resourcepart: String,
}

/// Formats as the bare JID, a `/`, then the resourcepart.
impl Display for FullJID {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}/{}", self.bare_jid, self.resourcepart)
    }
}

/// A JID without a resourcepart, displayed as `[localpart@]domainpart`.
#[derive(PartialEq, Debug, Clone, Eq, Hash, PartialOrd, Ord)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct BareJID {
    // TODO: validate and don't have public fields
    // TODO: validate localpart (length, chars)
    /// The optional portion written before the `@` separator.
    pub localpart: Option<String>,
    /// The domain portion; always present.
    pub domainpart: String,
}

/// Formats as `localpart@domainpart`, or just `domainpart` when there is no localpart.
impl Display for BareJID {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match &self.localpart {
            Some(localpart) => write!(f, "{}@{}", localpart, self.domainpart),
            None => f.write_str(&self.domainpart),
        }
    }
}

#[cfg(feature = "rusqlite")]
impl rusqlite::ToSql for JID {
    /// Encodes the JID as an owned text value, using its `Display` form.
    fn to_sql(&self) -> rusqlite::Result<rusqlite::types::ToSqlOutput<'_>> {
        let text = self.to_string();
        let value = rusqlite::types::Value::Text(text);
        Ok(rusqlite::types::ToSqlOutput::Owned(value))
    }
}

#[cfg(feature = "rusqlite")]
impl rusqlite::types::FromSql for JID {
    /// Reads the column as text and parses it as a [`JID`].
    fn column_result(value: rusqlite::types::ValueRef<'_>) -> rusqlite::types::FromSqlResult<Self> {
        let text = value.as_str()?;
        let jid = text.parse::<JID>()?;
        Ok(jid)
    }
}

#[cfg(feature = "rusqlite")]
impl rusqlite::ToSql for FullJID {
    /// Encodes the full JID as an owned text value, using its `Display` form.
    fn to_sql(&self) -> rusqlite::Result<rusqlite::types::ToSqlOutput<'_>> {
        let text = self.to_string();
        let value = rusqlite::types::Value::Text(text);
        Ok(rusqlite::types::ToSqlOutput::Owned(value))
    }
}

#[cfg(feature = "rusqlite")]
impl rusqlite::types::FromSql for FullJID {
    /// Reads the column as text, parses it as a [`JID`], and requires it to be a full JID.
    fn column_result(value: rusqlite::types::ValueRef<'_>) -> rusqlite::types::FromSqlResult<Self> {
        let text = value.as_str()?;
        let jid = text.parse::<JID>()?;
        let full_jid = FullJID::try_from(jid)?;
        Ok(full_jid)
    }
}

#[cfg(feature = "rusqlite")]
impl rusqlite::ToSql for BareJID {
    /// Encodes the bare JID as an owned text value, using its `Display` form.
    fn to_sql(&self) -> rusqlite::Result<rusqlite::types::ToSqlOutput<'_>> {
        let text = self.to_string();
        let value = rusqlite::types::Value::Text(text);
        Ok(rusqlite::types::ToSqlOutput::Owned(value))
    }
}

#[cfg(feature = "rusqlite")]
impl rusqlite::types::FromSql for BareJID {
    /// Reads the column as text, parses it as a [`JID`], and requires it to be a bare JID.
    fn column_result(value: rusqlite::types::ValueRef<'_>) -> rusqlite::types::FromSqlResult<Self> {
        let text = value.as_str()?;
        let jid = text.parse::<JID>()?;
        let bare_jid = BareJID::try_from(jid)?;
        Ok(bare_jid)
    }
}

#[cfg(feature = "rusqlite")]
impl From<ParseError> for rusqlite::types::FromSqlError {
    /// Boxes the parse error into the `Other` variant.
    fn from(parse_error: ParseError) -> Self {
        rusqlite::types::FromSqlError::Other(Box::new(parse_error))
    }
}

#[cfg(feature = "rusqlite")]
impl From<JIDError> for rusqlite::types::FromSqlError {
    /// Boxes the JID error into the `Other` variant.
    fn from(jid_error: JIDError) -> Self {
        rusqlite::types::FromSqlError::Other(Box::new(jid_error))
    }
}

/// Errors from parsing a JID or converting between its full and bare forms.
#[derive(Debug, Clone)]
pub enum JIDError {
    /// A full JID was required, but the JID has no resourcepart.
    NoResourcePart,
    /// The input string could not be parsed as a JID.
    ParseError(ParseError),
    /// A bare JID was required, but the JID has a resourcepart.
    ContainsResourcepart,
}

/// Renders a short description; parse errors reuse the wrapped error's own message.
impl Display for JIDError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // TODO: separate jid errors?
        let message = match self {
            JIDError::ParseError(parse_error) => return Display::fmt(parse_error, f),
            JIDError::NoResourcePart => "resourcepart missing",
            JIDError::ContainsResourcepart => "contains resourcepart",
        };
        f.write_str(message)
    }
}

// Marks `JIDError` as a standard error type; all trait methods keep their defaults.
impl Error for JIDError {}

/// Error describing why a string could not be parsed as a JID.
#[derive(Debug, Clone)]
pub enum ParseError {
    /// There was no input to parse.
    Empty,
    /// The input did not have the shape of a JID; carries the text that was rejected.
    Malformed(String),
}

/// Renders a human-readable message, including the rejected text for `Malformed`.
impl Display for ParseError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            ParseError::Empty => f.write_str("JID parse error: Empty"),
            // `write!` formats straight into `f`, so no temporary `String` is allocated.
            ParseError::Malformed(j) => write!(f, "JID parse error: malformed; got '{}'", j),
        }
    }
}

impl Error for ParseError {}

impl FullJID {
    pub fn new(localpart: Option<String>, domainpart: String, resourcepart: String) -> Self {
        Self {
            bare_jid: BareJID::new(localpart, domainpart),
            resourcepart,
        }
    }

    pub fn as_bare(&self) -> &BareJID {
        &self.bare_jid
    }

    pub fn to_bare(&self) -> BareJID {
        self.bare_jid.clone()
    }
}

impl BareJID {
    pub fn new(localpart: Option<String>, domainpart: String) -> Self {
        Self {
            localpart,
            domainpart,
        }
    }
}

impl TryFrom<JID> for BareJID {
    type Error = JIDError;

    fn try_from(value: JID) -> Result<Self, Self::Error> {
        match value {
            JID::Full(_full_jid) => Err(JIDError::ContainsResourcepart),
            JID::Bare(bare_jid) => Ok(bare_jid),
        }
    }
}

impl JID {
    /// Builds a JID from its parts: [`JID::Full`] when a resourcepart is given,
    /// [`JID::Bare`] otherwise. No validation is performed.
    pub fn new(
        localpart: Option<String>,
        domainpart: String,
        resourcepart: Option<String>,
    ) -> Self {
        match resourcepart {
            Some(resourcepart) => JID::Full(FullJID::new(localpart, domainpart, resourcepart)),
            None => JID::Bare(BareJID::new(localpart, domainpart)),
        }
    }

    /// Borrows the bare JID, ignoring any resourcepart.
    pub fn as_bare(&self) -> &BareJID {
        match self {
            JID::Bare(bare) => bare,
            JID::Full(full) => full.as_bare(),
        }
    }

    /// Returns an owned copy of the bare JID, ignoring any resourcepart.
    pub fn to_bare(&self) -> BareJID {
        self.as_bare().clone()
    }

    /// Borrows this JID as a [`FullJID`].
    ///
    /// # Errors
    ///
    /// Returns [`JIDError::NoResourcePart`] if this is a bare JID.
    pub fn as_full(&self) -> Result<&FullJID, JIDError> {
        if let JID::Full(full) = self {
            Ok(full)
        } else {
            Err(JIDError::NoResourcePart)
        }
    }
}

impl TryFrom<JID> for FullJID {
    type Error = JIDError;

    fn try_from(value: JID) -> Result<Self, Self::Error> {
        match value {
            JID::Full(full_jid) => Ok(full_jid),
            JID::Bare(_bare_jid) => Err(JIDError::NoResourcePart),
        }
    }
}

impl FromStr for BareJID {
    type Err = JIDError;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        Ok(JID::from_str(s)?.try_into()?)
    }
}

impl FromStr for FullJID {
    type Err = JIDError;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        Ok(JID::from_str(s)?.try_into()?)
    }
}

impl FromStr for JID {
    type Err = ParseError;

    /// Parses a JID of the form `[localpart@]domainpart[/resourcepart]`.
    ///
    /// The resourcepart is split off at the first `/`, and the localpart at the first
    /// `@` of what remains (the order given in RFC 7622, section 3.2). A resourcepart
    /// may therefore itself contain `/` or `@`.
    ///
    /// # Errors
    ///
    /// Returns [`ParseError::Empty`] if `s` is empty. Returns [`ParseError::Malformed`]
    /// (carrying `s`) if a part that is present is empty, or if the domainpart
    /// contains an `@`.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        if s.is_empty() {
            return Err(ParseError::Empty);
        }

        // Everything after the first '/' is the resourcepart, whatever it contains.
        let (address, resourcepart) = match s.split_once('/') {
            Some((address, resourcepart)) => (address, Some(resourcepart)),
            None => (s, None),
        };
        // Within the remaining address, the first '@' ends the localpart.
        let (localpart, domainpart) = match address.split_once('@') {
            Some((localpart, domainpart)) => (Some(localpart), domainpart),
            None => (None, address),
        };

        let has_empty_part = domainpart.is_empty()
            || matches!(localpart, Some(""))
            || matches!(resourcepart, Some(""));
        // A second '@' before the resourcepart (e.g. "a@b@c") is not a valid address.
        if has_empty_part || domainpart.contains('@') {
            return Err(ParseError::Malformed(s.to_string()));
        }

        Ok(JID::new(
            localpart.map(String::from),
            domainpart.to_string(),
            resourcepart.map(String::from),
        ))
    }
}

impl From<ParseError> for JIDError {
    fn from(value: ParseError) -> Self {
        JIDError::ParseError(value)
    }
}

impl TryFrom<String> for JID {
    type Error = ParseError;

    fn try_from(value: String) -> Result<Self, Self::Error> {
        value.parse()
    }
}

impl TryFrom<&str> for JID {
    type Error = ParseError;

    fn try_from(value: &str) -> Result<Self, Self::Error> {
        value.parse()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the expected [`JID`] from string slices to keep the assertions short.
    fn jid(localpart: Option<&str>, domainpart: &str, resourcepart: Option<&str>) -> JID {
        JID::new(
            localpart.map(String::from),
            domainpart.to_owned(),
            resourcepart.map(String::from),
        )
    }

    #[test]
    fn jid_to_string() {
        let rendered = jid(Some("cel"), "blos.sm", None).to_string();
        assert_eq!(rendered, "cel@blos.sm");
    }

    #[test]
    fn parse_full_jid() {
        let parsed: JID = "cel@blos.sm/greenhouse".parse().unwrap();
        assert_eq!(parsed, jid(Some("cel"), "blos.sm", Some("greenhouse")));
    }

    #[test]
    fn parse_bare_jid() {
        let parsed: JID = "cel@blos.sm".parse().unwrap();
        assert_eq!(parsed, jid(Some("cel"), "blos.sm", None));
    }

    #[test]
    fn parse_domain_jid() {
        let parsed: JID = "component.blos.sm".parse().unwrap();
        assert_eq!(parsed, jid(None, "component.blos.sm", None));
    }

    #[test]
    fn parse_full_domain_jid() {
        let parsed: JID = "component.blos.sm/bot".parse().unwrap();
        assert_eq!(parsed, jid(None, "component.blos.sm", Some("bot")));
    }
}

// #[cfg(feature = "sqlx")]
// impl sqlx::Type<Sqlite> for JID {
//     fn type_info() -> <Sqlite as sqlx::Database>::TypeInfo {
//         <&str as sqlx::Type<Sqlite>>::type_info()
//     }
// }

// #[cfg(feature = "sqlx")]
// impl sqlx::Decode<'_, Sqlite> for JID {
//     fn decode(
//         value: <Sqlite as sqlx::Database>::ValueRef<'_>,
//     ) -> Result<Self, sqlx::error::BoxDynError> {
//         let value = <&str as sqlx::Decode<Sqlite>>::decode(value)?;

//         Ok(value.parse()?)
//     }
// }

// #[cfg(feature = "sqlx")]
// impl sqlx::Encode<'_, Sqlite> for JID {
//     fn encode_by_ref(
//         &self,
//         buf: &mut <Sqlite as sqlx::Database>::ArgumentBuffer<'_>,
//     ) -> Result<sqlx::encode::IsNull, sqlx::error::BoxDynError> {
//         let jid = self.to_string();
//         <String as sqlx::Encode<Sqlite>>::encode(jid, buf)
//     }
// }