diff options
Diffstat (limited to 'jid/src/lib.rs')
-rw-r--r-- | jid/src/lib.rs | 291 |
1 files changed, 238 insertions, 53 deletions
diff --git a/jid/src/lib.rs b/jid/src/lib.rs index 47ca497..3b40094 100644 --- a/jid/src/lib.rs +++ b/jid/src/lib.rs @@ -1,15 +1,53 @@ -use std::{borrow::Cow, error::Error, fmt::Display, str::FromStr}; +use std::{borrow::Cow, error::Error, fmt::Display, ops::Deref, str::FromStr}; // #[cfg(feature = "sqlx")] // use sqlx::Sqlite; -#[derive(PartialEq, Debug, Clone, Eq, Hash)] +#[derive(PartialEq, Debug, Clone, Eq, Hash, PartialOrd, Ord)] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] -pub struct JID { - // TODO: validate localpart (length, char] - pub localpart: Option<String>, - pub domainpart: String, - pub resourcepart: Option<String>, +pub enum JID { + Full(FullJID), + Bare(BareJID), +} + +impl JID { + pub fn resourcepart(&self) -> Option<&String> { + match self { + JID::Full(full_jid) => Some(&full_jid.resourcepart), + JID::Bare(_bare_jid) => None, + } + } +} + +impl From<FullJID> for JID { + fn from(value: FullJID) -> Self { + Self::Full(value) + } +} + +impl From<BareJID> for JID { + fn from(value: BareJID) -> Self { + Self::Bare(value) + } +} + +impl Deref for JID { + type Target = BareJID; + + fn deref(&self) -> &Self::Target { + match self { + JID::Full(full_jid) => full_jid.as_bare(), + JID::Bare(bare_jid) => bare_jid, + } + } +} + +impl Deref for FullJID { + type Target = BareJID; + + fn deref(&self) -> &Self::Target { + &self.bare_jid + } } impl<'a> Into<Cow<'a, str>> for &'a JID { @@ -21,15 +59,45 @@ impl<'a> Into<Cow<'a, str>> for &'a JID { impl Display for JID { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + JID::Full(full_jid) => full_jid.fmt(f), + JID::Bare(bare_jid) => bare_jid.fmt(f), + } + } +} + +#[derive(PartialEq, Debug, Clone, Eq, Hash, PartialOrd, Ord)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +pub struct FullJID { + pub bare_jid: BareJID, + pub resourcepart: String, +} + +impl Display for FullJID { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.bare_jid.fmt(f)?; + f.write_str("/")?; + f.write_str(&self.resourcepart)?; + Ok(()) + } +} + +#[derive(PartialEq, Debug, Clone, Eq, Hash, PartialOrd, Ord)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +pub struct BareJID { + // TODO: validate and don't have public fields + // TODO: validate localpart (length, char] + pub localpart: Option<String>, + pub domainpart: String, +} + +impl Display for BareJID { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { if let Some(localpart) = &self.localpart { f.write_str(localpart)?; f.write_str("@")?; } f.write_str(&self.domainpart)?; - if let Some(resourcepart) = &self.resourcepart { - f.write_str("/")?; - f.write_str(resourcepart)?; - } Ok(()) } } @@ -51,52 +119,65 @@ impl rusqlite::types::FromSql for JID { } #[cfg(feature = "rusqlite")] -impl From<ParseError> for rusqlite::types::FromSqlError { - fn from(value: ParseError) -> Self { - Self::Other(Box::new(value)) +impl rusqlite::ToSql for FullJID { + fn to_sql(&self) -> rusqlite::Result<rusqlite::types::ToSqlOutput<'_>> { + Ok(rusqlite::types::ToSqlOutput::Owned( + rusqlite::types::Value::Text(self.to_string()), + )) } } -// #[cfg(feature = "sqlx")] -// impl sqlx::Type<Sqlite> for JID { -// fn type_info() -> <Sqlite as sqlx::Database>::TypeInfo { -// <&str as sqlx::Type<Sqlite>>::type_info() -// } -// } +#[cfg(feature = "rusqlite")] +impl rusqlite::types::FromSql for FullJID { + fn column_result(value: rusqlite::types::ValueRef<'_>) -> rusqlite::types::FromSqlResult<Self> { + Ok(JID::from_str(value.as_str()?)?.try_into()?) + } +} -// #[cfg(feature = "sqlx")] -// impl sqlx::Decode<'_, Sqlite> for JID { -// fn decode( -// value: <Sqlite as sqlx::Database>::ValueRef<'_>, -// ) -> Result<Self, sqlx::error::BoxDynError> { -// let value = <&str as sqlx::Decode<Sqlite>>::decode(value)?; +#[cfg(feature = "rusqlite")] +impl rusqlite::ToSql for BareJID { + fn to_sql(&self) -> rusqlite::Result<rusqlite::types::ToSqlOutput<'_>> { + Ok(rusqlite::types::ToSqlOutput::Owned( + rusqlite::types::Value::Text(self.to_string()), + )) + } +} -// Ok(value.parse()?) -// } -// } +#[cfg(feature = "rusqlite")] +impl rusqlite::types::FromSql for BareJID { + fn column_result(value: rusqlite::types::ValueRef<'_>) -> rusqlite::types::FromSqlResult<Self> { + Ok(JID::from_str(value.as_str()?)?.try_into()?) + } +} -// #[cfg(feature = "sqlx")] -// impl sqlx::Encode<'_, Sqlite> for JID { -// fn encode_by_ref( -// &self, -// buf: &mut <Sqlite as sqlx::Database>::ArgumentBuffer<'_>, -// ) -> Result<sqlx::encode::IsNull, sqlx::error::BoxDynError> { -// let jid = self.to_string(); -// <String as sqlx::Encode<Sqlite>>::encode(jid, buf) -// } -// } +#[cfg(feature = "rusqlite")] +impl From<ParseError> for rusqlite::types::FromSqlError { + fn from(value: ParseError) -> Self { + Self::Other(Box::new(value)) + } +} + +#[cfg(feature = "rusqlite")] +impl From<JIDError> for rusqlite::types::FromSqlError { + fn from(value: JIDError) -> Self { + Self::Other(Box::new(value)) + } +} #[derive(Debug, Clone)] pub enum JIDError { NoResourcePart, ParseError(ParseError), + ContainsResourcepart, } impl Display for JIDError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { + // TODO: separate jid errors? JIDError::NoResourcePart => f.write_str("resourcepart missing"), JIDError::ParseError(parse_error) => parse_error.fmt(f), + JIDError::ContainsResourcepart => f.write_str("contains resourcepart"), } } } @@ -122,36 +203,105 @@ impl Display for ParseError { impl Error for ParseError {} +impl FullJID { + pub fn new(localpart: Option<String>, domainpart: String, resourcepart: String) -> Self { + Self { + bare_jid: BareJID::new(localpart, domainpart), + resourcepart, + } + } + + pub fn as_bare(&self) -> &BareJID { + &self.bare_jid + } + + pub fn to_bare(&self) -> BareJID { + self.bare_jid.clone() + } +} + +impl BareJID { + pub fn new(localpart: Option<String>, domainpart: String) -> Self { + Self { + localpart, + domainpart, + } + } +} + +impl TryFrom<JID> for BareJID { + type Error = JIDError; + + fn try_from(value: JID) -> Result<Self, Self::Error> { + match value { + JID::Full(_full_jid) => Err(JIDError::ContainsResourcepart), + JID::Bare(bare_jid) => Ok(bare_jid), + } + } +} + impl JID { pub fn new( localpart: Option<String>, domainpart: String, resourcepart: Option<String>, ) -> Self { - Self { - localpart, - domainpart: domainpart.parse().unwrap(), - resourcepart, + if let Some(resourcepart) = resourcepart { + Self::Full(FullJID::new(localpart, domainpart, resourcepart)) + } else { + Self::Bare(BareJID::new(localpart, domainpart)) } } - pub fn as_bare(&self) -> Self { - Self { - localpart: self.localpart.clone(), - domainpart: self.domainpart.clone(), - resourcepart: None, + pub fn as_bare(&self) -> &BareJID { + match self { + JID::Full(full_jid) => full_jid.as_bare(), + JID::Bare(bare_jid) => &bare_jid, } } - pub fn as_full(&self) -> Result<&Self, JIDError> { - if let Some(_) = self.resourcepart { - Ok(&self) - } else { - Err(JIDError::NoResourcePart) + pub fn to_bare(&self) -> BareJID { + match self { + JID::Full(full_jid) => full_jid.to_bare(), + JID::Bare(bare_jid) => bare_jid.clone(), + } + } + + pub fn as_full(&self) -> Result<&FullJID, JIDError> { + match self { + JID::Full(full_jid) => Ok(full_jid), + JID::Bare(_bare_jid) => Err(JIDError::NoResourcePart), + } + } +} + +impl TryFrom<JID> for FullJID { + type Error = JIDError; + + fn try_from(value: JID) -> Result<Self, Self::Error> { + match value { + JID::Full(full_jid) => Ok(full_jid), + JID::Bare(_bare_jid) => Err(JIDError::NoResourcePart), } } } +impl FromStr for BareJID { + type Err = JIDError; + + fn from_str(s: &str) -> Result<Self, Self::Err> { + Ok(JID::from_str(s)?.try_into()?) + } +} + +impl FromStr for FullJID { + type Err = JIDError; + + fn from_str(s: &str) -> Result<Self, Self::Err> { + Ok(JID::from_str(s)?.try_into()?) + } +} + impl FromStr for JID { type Err = ParseError; @@ -192,6 +342,12 @@ impl FromStr for JID { } } +impl From<ParseError> for JIDError { + fn from(value: ParseError) -> Self { + JIDError::ParseError(value) + } +} + impl TryFrom<String> for JID { type Error = ParseError; @@ -256,3 +412,32 @@ mod tests { ) } } + +// #[cfg(feature = "sqlx")] +// impl sqlx::Type<Sqlite> for JID { +// fn type_info() -> <Sqlite as sqlx::Database>::TypeInfo { +// <&str as sqlx::Type<Sqlite>>::type_info() +// } +// } + +// #[cfg(feature = "sqlx")] +// impl sqlx::Decode<'_, Sqlite> for JID { +// fn decode( +// value: <Sqlite as sqlx::Database>::ValueRef<'_>, +// ) -> Result<Self, sqlx::error::BoxDynError> { +// let value = <&str as sqlx::Decode<Sqlite>>::decode(value)?; + +// Ok(value.parse()?) +// } +// } + +// #[cfg(feature = "sqlx")] +// impl sqlx::Encode<'_, Sqlite> for JID { +// fn encode_by_ref( +// &self, +// buf: &mut <Sqlite as sqlx::Database>::ArgumentBuffer<'_>, +// ) -> Result<sqlx::encode::IsNull, sqlx::error::BoxDynError> { +// let jid = self.to_string(); +// <String as sqlx::Encode<Sqlite>>::encode(jid, buf) +// } +// } |