aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorLibravatar cel 🌸 <cel@bunny.garden>2025-05-31 16:38:42 +0100
committerLibravatar cel 🌸 <cel@bunny.garden>2025-06-01 13:31:30 +0100
commitc196aecbe7b3b29fbfff9e997478688e6833b7f3 (patch)
tree74d0854a6498b43b3f678681c0ac5a6649ae577c
parent28ab91329f0df964d6f960eda3b216a5df0921c9 (diff)
downloadluz-c196aecbe7b3b29fbfff9e997478688e6833b7f3.tar.gz
luz-c196aecbe7b3b29fbfff9e997478688e6833b7f3.tar.bz2
luz-c196aecbe7b3b29fbfff9e997478688e6833b7f3.zip
refactor(jid): more type safety
-rw-r--r--jid/src/lib.rs291
1 files changed, 238 insertions, 53 deletions
diff --git a/jid/src/lib.rs b/jid/src/lib.rs
index 47ca497..3b40094 100644
--- a/jid/src/lib.rs
+++ b/jid/src/lib.rs
@@ -1,15 +1,53 @@
-use std::{borrow::Cow, error::Error, fmt::Display, str::FromStr};
+use std::{borrow::Cow, error::Error, fmt::Display, ops::Deref, str::FromStr};
// #[cfg(feature = "sqlx")]
// use sqlx::Sqlite;
-#[derive(PartialEq, Debug, Clone, Eq, Hash)]
+#[derive(PartialEq, Debug, Clone, Eq, Hash, PartialOrd, Ord)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
-pub struct JID {
- // TODO: validate localpart (length, char]
- pub localpart: Option<String>,
- pub domainpart: String,
- pub resourcepart: Option<String>,
+pub enum JID {
+ Full(FullJID),
+ Bare(BareJID),
+}
+
+impl JID {
+ pub fn resourcepart(&self) -> Option<&String> {
+ match self {
+ JID::Full(full_jid) => Some(&full_jid.resourcepart),
+ JID::Bare(_bare_jid) => None,
+ }
+ }
+}
+
+impl From<FullJID> for JID {
+ fn from(value: FullJID) -> Self {
+ Self::Full(value)
+ }
+}
+
+impl From<BareJID> for JID {
+ fn from(value: BareJID) -> Self {
+ Self::Bare(value)
+ }
+}
+
+impl Deref for JID {
+ type Target = BareJID;
+
+ fn deref(&self) -> &Self::Target {
+ match self {
+ JID::Full(full_jid) => full_jid.as_bare(),
+ JID::Bare(bare_jid) => bare_jid,
+ }
+ }
+}
+
+impl Deref for FullJID {
+ type Target = BareJID;
+
+ fn deref(&self) -> &Self::Target {
+ &self.bare_jid
+ }
}
impl<'a> Into<Cow<'a, str>> for &'a JID {
@@ -21,15 +59,45 @@ impl<'a> Into<Cow<'a, str>> for &'a JID {
impl Display for JID {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ match self {
+ JID::Full(full_jid) => full_jid.fmt(f),
+ JID::Bare(bare_jid) => bare_jid.fmt(f),
+ }
+ }
+}
+
+#[derive(PartialEq, Debug, Clone, Eq, Hash, PartialOrd, Ord)]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+pub struct FullJID {
+ pub bare_jid: BareJID,
+ pub resourcepart: String,
+}
+
+impl Display for FullJID {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ self.bare_jid.fmt(f)?;
+ f.write_str("/")?;
+ f.write_str(&self.resourcepart)?;
+ Ok(())
+ }
+}
+
+#[derive(PartialEq, Debug, Clone, Eq, Hash, PartialOrd, Ord)]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+pub struct BareJID {
+ // TODO: validate and don't have public fields
+ // TODO: validate localpart (length, char]
+ pub localpart: Option<String>,
+ pub domainpart: String,
+}
+
+impl Display for BareJID {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
if let Some(localpart) = &self.localpart {
f.write_str(localpart)?;
f.write_str("@")?;
}
f.write_str(&self.domainpart)?;
- if let Some(resourcepart) = &self.resourcepart {
- f.write_str("/")?;
- f.write_str(resourcepart)?;
- }
Ok(())
}
}
@@ -51,52 +119,65 @@ impl rusqlite::types::FromSql for JID {
}
#[cfg(feature = "rusqlite")]
-impl From<ParseError> for rusqlite::types::FromSqlError {
- fn from(value: ParseError) -> Self {
- Self::Other(Box::new(value))
+impl rusqlite::ToSql for FullJID {
+ fn to_sql(&self) -> rusqlite::Result<rusqlite::types::ToSqlOutput<'_>> {
+ Ok(rusqlite::types::ToSqlOutput::Owned(
+ rusqlite::types::Value::Text(self.to_string()),
+ ))
}
}
-// #[cfg(feature = "sqlx")]
-// impl sqlx::Type<Sqlite> for JID {
-// fn type_info() -> <Sqlite as sqlx::Database>::TypeInfo {
-// <&str as sqlx::Type<Sqlite>>::type_info()
-// }
-// }
+#[cfg(feature = "rusqlite")]
+impl rusqlite::types::FromSql for FullJID {
+ fn column_result(value: rusqlite::types::ValueRef<'_>) -> rusqlite::types::FromSqlResult<Self> {
+ Ok(JID::from_str(value.as_str()?)?.try_into()?)
+ }
+}
-// #[cfg(feature = "sqlx")]
-// impl sqlx::Decode<'_, Sqlite> for JID {
-// fn decode(
-// value: <Sqlite as sqlx::Database>::ValueRef<'_>,
-// ) -> Result<Self, sqlx::error::BoxDynError> {
-// let value = <&str as sqlx::Decode<Sqlite>>::decode(value)?;
+#[cfg(feature = "rusqlite")]
+impl rusqlite::ToSql for BareJID {
+ fn to_sql(&self) -> rusqlite::Result<rusqlite::types::ToSqlOutput<'_>> {
+ Ok(rusqlite::types::ToSqlOutput::Owned(
+ rusqlite::types::Value::Text(self.to_string()),
+ ))
+ }
+}
-// Ok(value.parse()?)
-// }
-// }
+#[cfg(feature = "rusqlite")]
+impl rusqlite::types::FromSql for BareJID {
+ fn column_result(value: rusqlite::types::ValueRef<'_>) -> rusqlite::types::FromSqlResult<Self> {
+ Ok(JID::from_str(value.as_str()?)?.try_into()?)
+ }
+}
-// #[cfg(feature = "sqlx")]
-// impl sqlx::Encode<'_, Sqlite> for JID {
-// fn encode_by_ref(
-// &self,
-// buf: &mut <Sqlite as sqlx::Database>::ArgumentBuffer<'_>,
-// ) -> Result<sqlx::encode::IsNull, sqlx::error::BoxDynError> {
-// let jid = self.to_string();
-// <String as sqlx::Encode<Sqlite>>::encode(jid, buf)
-// }
-// }
+#[cfg(feature = "rusqlite")]
+impl From<ParseError> for rusqlite::types::FromSqlError {
+ fn from(value: ParseError) -> Self {
+ Self::Other(Box::new(value))
+ }
+}
+
+#[cfg(feature = "rusqlite")]
+impl From<JIDError> for rusqlite::types::FromSqlError {
+ fn from(value: JIDError) -> Self {
+ Self::Other(Box::new(value))
+ }
+}
#[derive(Debug, Clone)]
pub enum JIDError {
NoResourcePart,
ParseError(ParseError),
+ ContainsResourcepart,
}
impl Display for JIDError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
+ // TODO: separate jid errors?
JIDError::NoResourcePart => f.write_str("resourcepart missing"),
JIDError::ParseError(parse_error) => parse_error.fmt(f),
+ JIDError::ContainsResourcepart => f.write_str("contains resourcepart"),
}
}
}
@@ -122,36 +203,105 @@ impl Display for ParseError {
impl Error for ParseError {}
+impl FullJID {
+ pub fn new(localpart: Option<String>, domainpart: String, resourcepart: String) -> Self {
+ Self {
+ bare_jid: BareJID::new(localpart, domainpart),
+ resourcepart,
+ }
+ }
+
+ pub fn as_bare(&self) -> &BareJID {
+ &self.bare_jid
+ }
+
+ pub fn to_bare(&self) -> BareJID {
+ self.bare_jid.clone()
+ }
+}
+
+impl BareJID {
+ pub fn new(localpart: Option<String>, domainpart: String) -> Self {
+ Self {
+ localpart,
+ domainpart,
+ }
+ }
+}
+
+impl TryFrom<JID> for BareJID {
+ type Error = JIDError;
+
+ fn try_from(value: JID) -> Result<Self, Self::Error> {
+ match value {
+ JID::Full(_full_jid) => Err(JIDError::ContainsResourcepart),
+ JID::Bare(bare_jid) => Ok(bare_jid),
+ }
+ }
+}
+
impl JID {
pub fn new(
localpart: Option<String>,
domainpart: String,
resourcepart: Option<String>,
) -> Self {
- Self {
- localpart,
- domainpart: domainpart.parse().unwrap(),
- resourcepart,
+ if let Some(resourcepart) = resourcepart {
+ Self::Full(FullJID::new(localpart, domainpart, resourcepart))
+ } else {
+ Self::Bare(BareJID::new(localpart, domainpart))
}
}
- pub fn as_bare(&self) -> Self {
- Self {
- localpart: self.localpart.clone(),
- domainpart: self.domainpart.clone(),
- resourcepart: None,
+ pub fn as_bare(&self) -> &BareJID {
+ match self {
+ JID::Full(full_jid) => full_jid.as_bare(),
+ JID::Bare(bare_jid) => &bare_jid,
}
}
- pub fn as_full(&self) -> Result<&Self, JIDError> {
- if let Some(_) = self.resourcepart {
- Ok(&self)
- } else {
- Err(JIDError::NoResourcePart)
+ pub fn to_bare(&self) -> BareJID {
+ match self {
+ JID::Full(full_jid) => full_jid.to_bare(),
+ JID::Bare(bare_jid) => bare_jid.clone(),
+ }
+ }
+
+ pub fn as_full(&self) -> Result<&FullJID, JIDError> {
+ match self {
+ JID::Full(full_jid) => Ok(full_jid),
+ JID::Bare(_bare_jid) => Err(JIDError::NoResourcePart),
+ }
+ }
+}
+
+impl TryFrom<JID> for FullJID {
+ type Error = JIDError;
+
+ fn try_from(value: JID) -> Result<Self, Self::Error> {
+ match value {
+ JID::Full(full_jid) => Ok(full_jid),
+ JID::Bare(_bare_jid) => Err(JIDError::NoResourcePart),
}
}
}
+impl FromStr for BareJID {
+ type Err = JIDError;
+
+ fn from_str(s: &str) -> Result<Self, Self::Err> {
+ Ok(JID::from_str(s)?.try_into()?)
+ }
+}
+
+impl FromStr for FullJID {
+ type Err = JIDError;
+
+ fn from_str(s: &str) -> Result<Self, Self::Err> {
+ Ok(JID::from_str(s)?.try_into()?)
+ }
+}
+
impl FromStr for JID {
type Err = ParseError;
@@ -192,6 +342,12 @@ impl FromStr for JID {
}
}
+impl From<ParseError> for JIDError {
+ fn from(value: ParseError) -> Self {
+ JIDError::ParseError(value)
+ }
+}
+
impl TryFrom<String> for JID {
type Error = ParseError;
@@ -256,3 +412,32 @@ mod tests {
)
}
}
+
+// #[cfg(feature = "sqlx")]
+// impl sqlx::Type<Sqlite> for JID {
+// fn type_info() -> <Sqlite as sqlx::Database>::TypeInfo {
+// <&str as sqlx::Type<Sqlite>>::type_info()
+// }
+// }
+
+// #[cfg(feature = "sqlx")]
+// impl sqlx::Decode<'_, Sqlite> for JID {
+// fn decode(
+// value: <Sqlite as sqlx::Database>::ValueRef<'_>,
+// ) -> Result<Self, sqlx::error::BoxDynError> {
+// let value = <&str as sqlx::Decode<Sqlite>>::decode(value)?;
+
+// Ok(value.parse()?)
+// }
+// }
+
+// #[cfg(feature = "sqlx")]
+// impl sqlx::Encode<'_, Sqlite> for JID {
+// fn encode_by_ref(
+// &self,
+// buf: &mut <Sqlite as sqlx::Database>::ArgumentBuffer<'_>,
+// ) -> Result<sqlx::encode::IsNull, sqlx::error::BoxDynError> {
+// let jid = self.to_string();
+// <String as sqlx::Encode<Sqlite>>::encode(jid, buf)
+// }
+// }