aboutsummaryrefslogblamecommitdiffstats
path: root/jid/src/lib.rs
blob: 543d1bac9358156139b7605f39947d0d168e5a18 (plain) (tree)
1
2
3
4
5
6
7
8
9
10
                                                    
 

                 
                                                                      
                
                                              


                                     

 


























                                                                   




                           
                
                     
          


                      












                                                                                        
          




                                     





                                                    















                                                      


                      
                          



                                                      
                                        








                                                                        
                                                                   














                                                                      
                                                                   

                 
                                                           




                              
                            





                                                             







                                                           





                                                                        
                            























































                                                                             
use std::{error::Error, fmt::Display, str::FromStr};

use sqlx::Sqlite;

/// A JID (Jabber/XMPP address) held as its three separate parts.
///
/// The `Display` impl in this file renders it as
/// `[localpart@]domainpart[/resourcepart]`, and the `FromStr` impl parses that
/// same textual form back into a `JID`.
#[derive(PartialEq, Debug, Clone, sqlx::Type, sqlx::Encode, Eq, Hash)]
pub struct JID {
    // TODO: validate localpart (length, chars)
    /// Optional part written before the `@`; `None` when the address has none.
    pub localpart: Option<String>,
    /// The only mandatory part of the address.
    pub domainpart: String,
    /// Optional part written after the `/`; `None` for a bare JID.
    pub resourcepart: Option<String>,
}

// TODO: feature gate
/// Tells sqlx which SQLite type a [`JID`] maps to.
impl sqlx::Type<Sqlite> for JID {
    /// Returns the same type info sqlx reports for `&str`, matching the
    /// `Encode`/`Decode` impls in this file, which go through string values.
    fn type_info() -> <Sqlite as sqlx::Database>::TypeInfo {
        <&str as sqlx::Type<Sqlite>>::type_info()
    }
}

/// Reads a [`JID`] back out of a SQLite value.
impl sqlx::Decode<'_, Sqlite> for JID {
    /// Decodes the value as a string slice and parses it with `FromStr`.
    ///
    /// # Errors
    ///
    /// Returns the boxed sqlx error if the value cannot be decoded as `&str`,
    /// or the boxed `ParseError` if the text is not accepted by the parser.
    fn decode(
        value: <Sqlite as sqlx::Database>::ValueRef<'_>,
    ) -> Result<Self, sqlx::error::BoxDynError> {
        let text = <&str as sqlx::Decode<Sqlite>>::decode(value)?;
        // `?` boxes the `ParseError` into sqlx's dynamic error type.
        let jid: JID = text.parse()?;

        Ok(jid)
    }
}

/// Writes a [`JID`] into a SQLite argument buffer.
impl sqlx::Encode<'_, Sqlite> for JID {
    /// Renders the JID through its `Display` impl and hands the resulting
    /// owned `String` to sqlx's `String` encoder.
    fn encode_by_ref(
        &self,
        buf: &mut <Sqlite as sqlx::Database>::ArgumentBuffer<'_>,
    ) -> Result<sqlx::encode::IsNull, sqlx::error::BoxDynError> {
        let rendered: String = self.to_string();
        <String as sqlx::Encode<Sqlite>>::encode(rendered, buf)
    }
}

pub enum JIDError {
    NoResourcePart,
    ParseError(ParseError),
}

/// Reasons a string could not be turned into a [`JID`].
#[derive(Debug)]
pub enum ParseError {
    /// Intended to signal that the input string was empty.
    Empty,
    /// The input does not have the shape of a JID; the payload is the
    /// offending input text, which the `Display` message echoes back.
    Malformed(String),
}

impl Display for ParseError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            ParseError::Empty => f.write_str("JID parse error: Empty"),
            ParseError::Malformed(j) => {
                f.write_str(format!("JID parse error: malformed; got '{}'", j).as_str())
            }
        }
    }
}

// Empty impl: `ParseError` uses the default `Error` methods (no `source`).
// Being an `Error` is what lets `?` box it into a dynamic error type, as the
// sqlx `Decode` impl in this file does.
impl Error for ParseError {}

impl JID {
    /// Builds a JID from already-separated parts.
    ///
    /// The parts are stored exactly as given; no validation or normalisation
    /// is performed here.
    pub fn new(
        localpart: Option<String>,
        domainpart: String,
        resourcepart: Option<String>,
    ) -> Self {
        Self {
            localpart,
            // Moved in as-is: parsing a `String` into a `String` cannot fail
            // and would only allocate a copy.
            domainpart,
            resourcepart,
        }
    }

    /// Returns a copy of this JID with the resourcepart removed (a bare JID).
    pub fn as_bare(&self) -> Self {
        Self {
            localpart: self.localpart.clone(),
            domainpart: self.domainpart.clone(),
            resourcepart: None,
        }
    }

    /// Returns `self` if it is a full JID, i.e. it has a resourcepart.
    ///
    /// # Errors
    ///
    /// Returns [`JIDError::NoResourcePart`] when `resourcepart` is `None`.
    pub fn as_full(&self) -> Result<&Self, JIDError> {
        if self.resourcepart.is_some() {
            Ok(self)
        } else {
            Err(JIDError::NoResourcePart)
        }
    }
}

/// Parses the textual form `[localpart@]domainpart[/resourcepart]`.
impl FromStr for JID {
    type Err = ParseError;

    /// Splits `s` into its parts and builds a [`JID`] from them.
    ///
    /// Following the XMPP address format (RFC 7622, section 3), the
    /// resourcepart is cut off at the *first* `/` before anything else,
    /// because a resourcepart may itself contain `/` and `@`. The remainder is
    /// then split at the first `@` into localpart and domainpart.
    ///
    /// # Errors
    ///
    /// * [`ParseError::Empty`] if `s` is the empty string.
    /// * [`ParseError::Malformed`] (carrying `s`) if any part that is present
    ///   is empty — e.g. `"@host"`, `"user@"`, `"host/"` — or if the
    ///   domainpart still contains an `@`, e.g. `"a@b@c"`.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // `str::split` always yields at least one item, so emptiness has to be
        // checked on the input itself.
        if s.is_empty() {
            return Err(ParseError::Empty);
        }

        // Resourcepart first: everything after the first '/' belongs to it.
        let (bare, resourcepart) = match s.split_once('/') {
            Some((bare, resource)) => (bare, Some(resource)),
            None => (s, None),
        };

        // What is left is `[localpart@]domainpart`.
        let (localpart, domainpart) = match bare.split_once('@') {
            Some((local, domain)) => (Some(local), domain),
            None => (None, bare),
        };

        let has_empty_part =
            domainpart.is_empty() || localpart == Some("") || resourcepart == Some("");
        if has_empty_part || domainpart.contains('@') {
            return Err(ParseError::Malformed(s.to_string()));
        }

        Ok(JID::new(
            localpart.map(str::to_string),
            domainpart.to_string(),
            resourcepart.map(str::to_string),
        ))
    }
}

impl TryFrom<String> for JID {
    type Error = ParseError;

    fn try_from(value: String) -> Result<Self, Self::Error> {
        value.parse()
    }
}

impl TryFrom<&str> for JID {
    type Error = ParseError;

    fn try_from(value: &str) -> Result<Self, Self::Error> {
        value.parse()
    }
}

impl std::fmt::Display for JID {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "{}{}{}",
            self.localpart.clone().map(|l| l + "@").unwrap_or_default(),
            self.domainpart,
            self.resourcepart
                .clone()
                .map(|r| "/".to_owned() + &r)
                .unwrap_or_default()
        )
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// A JID with a localpart and no resourcepart renders as `local@domain`.
    #[test]
    fn jid_to_string() {
        let jid = JID::new(Some("cel".to_owned()), "blos.sm".to_owned(), None);

        assert_eq!(jid.to_string(), "cel@blos.sm".to_owned());
    }

    /// `local@domain/resource` fills all three parts.
    #[test]
    fn parse_full_jid() {
        let parsed: JID = "cel@blos.sm/greenhouse".parse().unwrap();
        let expected = JID::new(
            Some("cel".to_owned()),
            "blos.sm".to_owned(),
            Some("greenhouse".to_owned()),
        );

        assert_eq!(parsed, expected)
    }

    /// `local@domain` leaves the resourcepart empty.
    #[test]
    fn parse_bare_jid() {
        let parsed: JID = "cel@blos.sm".parse().unwrap();
        let expected = JID::new(Some("cel".to_owned()), "blos.sm".to_owned(), None);

        assert_eq!(parsed, expected)
    }

    /// A lone domain has neither localpart nor resourcepart.
    #[test]
    fn parse_domain_jid() {
        let parsed: JID = "component.blos.sm".parse().unwrap();
        let expected = JID::new(None, "component.blos.sm".to_owned(), None);

        assert_eq!(parsed, expected)
    }

    /// `domain/resource` has a resourcepart but no localpart.
    #[test]
    fn parse_full_domain_jid() {
        let parsed: JID = "component.blos.sm/bot".parse().unwrap();
        let expected = JID::new(
            None,
            "component.blos.sm".to_owned(),
            Some("bot".to_owned()),
        );

        assert_eq!(parsed, expected)
    }
}