use std::{error::Error, fmt::Display, str::FromStr}; use sqlx::Sqlite; #[derive(PartialEq, Debug, Clone, sqlx::Type, sqlx::Encode, Eq, Hash)] pub struct JID { // TODO: validate localpart (length, char] pub localpart: Option, pub domainpart: String, pub resourcepart: Option, } // TODO: feature gate impl sqlx::Type for JID { fn type_info() -> ::TypeInfo { <&str as sqlx::Type>::type_info() } } impl sqlx::Decode<'_, Sqlite> for JID { fn decode( value: ::ValueRef<'_>, ) -> Result { let value = <&str as sqlx::Decode>::decode(value)?; Ok(value.parse()?) } } impl sqlx::Encode<'_, Sqlite> for JID { fn encode_by_ref( &self, buf: &mut ::ArgumentBuffer<'_>, ) -> Result { let jid = self.to_string(); >::encode(jid, buf) } } pub enum JIDError { NoResourcePart, ParseError(ParseError), } #[derive(Debug)] pub enum ParseError { Empty, Malformed(String), } impl Display for ParseError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { ParseError::Empty => f.write_str("JID parse error: Empty"), ParseError::Malformed(j) => { f.write_str(format!("JID parse error: malformed; got '{}'", j).as_str()) } } } } impl Error for ParseError {} impl JID { pub fn new( localpart: Option, domainpart: String, resourcepart: Option, ) -> Self { Self { localpart, domainpart: domainpart.parse().unwrap(), resourcepart, } } pub fn as_bare(&self) -> Self { Self { localpart: self.localpart.clone(), domainpart: self.domainpart.clone(), resourcepart: None, } } pub fn as_full(&self) -> Result<&Self, JIDError> { if let Some(_) = self.resourcepart { Ok(&self) } else { Err(JIDError::NoResourcePart) } } } impl FromStr for JID { type Err = ParseError; fn from_str(s: &str) -> Result { let split: Vec<&str> = s.split('@').collect(); match split.len() { 0 => Err(ParseError::Empty), 1 => { let split: Vec<&str> = split[0].split('/').collect(); match split.len() { 1 => Ok(JID::new(None, split[0].to_string(), None)), 2 => Ok(JID::new( None, split[0].to_string(), Some(split[1].to_string()), )), _ => Err(ParseError::Malformed(s.to_string())), } } 2 => { let split2: Vec<&str> = split[1].split('/').collect(); match split2.len() { 1 => Ok(JID::new( Some(split[0].to_string()), split2[0].to_string(), None, )), 2 => Ok(JID::new( Some(split[0].to_string()), split2[0].to_string(), Some(split2[1].to_string()), )), _ => Err(ParseError::Malformed(s.to_string())), } } _ => Err(ParseError::Malformed(s.to_string())), } } } impl TryFrom for JID { type Error = ParseError; fn try_from(value: String) -> Result { value.parse() } } impl TryFrom<&str> for JID { type Error = ParseError; fn try_from(value: &str) -> Result { value.parse() } } impl std::fmt::Display for JID { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!( f, "{}{}{}", self.localpart.clone().map(|l| l + "@").unwrap_or_default(), self.domainpart, self.resourcepart .clone() .map(|r| "/".to_owned() + &r) .unwrap_or_default() ) } } #[cfg(test)] mod tests { use super::*; #[test] fn jid_to_string() { assert_eq!( JID::new(Some("cel".into()), "blos.sm".into(), None).to_string(), "cel@blos.sm".to_owned() ); } #[test] fn parse_full_jid() { assert_eq!( "cel@blos.sm/greenhouse".parse::().unwrap(), JID::new( Some("cel".into()), "blos.sm".into(), Some("greenhouse".into()) ) ) } #[test] fn parse_bare_jid() { assert_eq!( "cel@blos.sm".parse::().unwrap(), JID::new(Some("cel".into()), "blos.sm".into(), None) ) } #[test] fn parse_domain_jid() { assert_eq!( "component.blos.sm".parse::().unwrap(), JID::new(None, "component.blos.sm".into(), None) ) } #[test] fn parse_full_domain_jid() { assert_eq!( "component.blos.sm/bot".parse::().unwrap(), JID::new(None, "component.blos.sm".into(), Some("bot".into())) ) } }