aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorLibravatar cel 🌸 <cel@bunny.garden>2024-11-23 22:39:44 +0000
committerLibravatar cel 🌸 <cel@bunny.garden>2024-11-23 22:39:44 +0000
commit40024d2dadba9e70edb2f3448204565ce3f68ab7 (patch)
tree3f08b61debf936c513f300c845d8a1cb29edd7c8
parent9f2546f6dadd916b0e7fc5be51e92d682ef2487b (diff)
downloadluz-40024d2dadba9e70edb2f3448204565ce3f68ab7.tar.gz
luz-40024d2dadba9e70edb2f3448204565ce3f68ab7.tar.bz2
luz-40024d2dadba9e70edb2f3448204565ce3f68ab7.zip
switch to using peanuts for xml
-rw-r--r--Cargo.toml4
-rw-r--r--src/connection.rs12
-rw-r--r--src/error.rs25
-rw-r--r--src/jabber.rs75
-rw-r--r--src/jid.rs28
-rw-r--r--src/lib.rs4
-rw-r--r--src/stanza/mod.rs8
-rw-r--r--src/stanza/stream.rs169
8 files changed, 184 insertions, 141 deletions
diff --git a/Cargo.toml b/Cargo.toml
index ea47616..f136e90 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -11,16 +11,14 @@ async-recursion = "1.0.4"
async-trait = "0.1.68"
lazy_static = "1.4.0"
nanoid = "0.4.0"
-quick-xml = { git = "https://github.com/tafia/quick-xml.git", features = ["async-tokio", "serialize"] }
# TODO: remove unneeded features
rsasl = { version = "2", default_features = true, features = ["provider_base64", "plain", "config_builder"] }
-serde = "1.0.180"
-serde_with = "3.4.0"
tokio = { version = "1.28", features = ["full"] }
tokio-native-tls = "0.3.1"
tracing = "0.1.40"
trust-dns-resolver = "0.22.0"
try_map = "0.3.1"
+peanuts = { version = "0.1.0", path = "../peanuts" }
[dev-dependencies]
test-log = { version = "0.2", features = ["trace"] }
diff --git a/src/connection.rs b/src/connection.rs
index b42711e..89f382f 100644
--- a/src/connection.rs
+++ b/src/connection.rs
@@ -8,8 +8,8 @@ use tokio_native_tls::native_tls::TlsConnector;
use tokio_native_tls::TlsStream;
use tracing::{debug, info, instrument, trace};
+use crate::Error;
use crate::Jabber;
-use crate::JabberError;
use crate::Result;
pub type Tls = TlsStream<TcpStream>;
@@ -75,7 +75,7 @@ impl Connection {
}
}
}
- Err(JabberError::Connection)
+ Err(Error::Connection)
}
#[instrument]
@@ -154,19 +154,19 @@ impl Connection {
pub async fn connect_tls(socket_addr: SocketAddr, domain_name: &str) -> Result<Tls> {
let socket = TcpStream::connect(socket_addr)
.await
- .map_err(|_| JabberError::Connection)?;
- let connector = TlsConnector::new().map_err(|_| JabberError::Connection)?;
+ .map_err(|_| Error::Connection)?;
+ let connector = TlsConnector::new().map_err(|_| Error::Connection)?;
tokio_native_tls::TlsConnector::from(connector)
.connect(domain_name, socket)
.await
- .map_err(|_| JabberError::Connection)
+ .map_err(|_| Error::Connection)
}
#[instrument]
pub async fn connect_unencrypted(socket_addr: SocketAddr) -> Result<Unencrypted> {
TcpStream::connect(socket_addr)
.await
- .map_err(|_| JabberError::Connection)
+ .map_err(|_| Error::Connection)
}
}
diff --git a/src/error.rs b/src/error.rs
index b12914c..c7c867c 100644
--- a/src/error.rs
+++ b/src/error.rs
@@ -1,12 +1,11 @@
use std::str::Utf8Error;
-use quick_xml::events::attributes::AttrError;
use rsasl::mechname::MechanismNameError;
use crate::jid::ParseError;
#[derive(Debug)]
-pub enum JabberError {
+pub enum Error {
Connection,
BadStream,
StartTlsUnavailable,
@@ -23,7 +22,7 @@ pub enum JabberError {
UnexpectedEnd,
UnexpectedElement,
UnexpectedText,
- XML(quick_xml::Error),
+ XML(peanuts::Error),
SASL(SASLError),
JID(ParseError),
}
@@ -36,43 +35,37 @@ pub enum SASLError {
NoSuccess,
}
-impl From<rsasl::prelude::SASLError> for JabberError {
+impl From<rsasl::prelude::SASLError> for Error {
fn from(e: rsasl::prelude::SASLError) -> Self {
Self::SASL(SASLError::SASL(e))
}
}
-impl From<MechanismNameError> for JabberError {
+impl From<MechanismNameError> for Error {
fn from(e: MechanismNameError) -> Self {
Self::SASL(SASLError::MechanismName(e))
}
}
-impl From<SASLError> for JabberError {
+impl From<SASLError> for Error {
fn from(e: SASLError) -> Self {
Self::SASL(e)
}
}
-impl From<Utf8Error> for JabberError {
+impl From<Utf8Error> for Error {
fn from(_e: Utf8Error) -> Self {
Self::Utf8Decode
}
}
-impl From<quick_xml::Error> for JabberError {
- fn from(e: quick_xml::Error) -> Self {
+impl From<peanuts::Error> for Error {
+ fn from(e: peanuts::Error) -> Self {
Self::XML(e)
}
}
-impl From<AttrError> for JabberError {
- fn from(e: AttrError) -> Self {
- Self::XML(e.into())
- }
-}
-
-impl From<ParseError> for JabberError {
+impl From<ParseError> for Error {
fn from(e: ParseError) -> Self {
Self::JID(e)
}
diff --git a/src/jabber.rs b/src/jabber.rs
index 1436bfa..afe840b 100644
--- a/src/jabber.rs
+++ b/src/jabber.rs
@@ -1,16 +1,15 @@
use std::str;
use std::sync::Arc;
-use quick_xml::{events::Event, se::Serializer, NsReader, Writer};
+use peanuts::{Reader, Writer};
use rsasl::prelude::SASLConfig;
-use serde::Serialize;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufReader, ReadHalf, WriteHalf};
use tracing::{debug, info, trace};
use crate::connection::{Tls, Unencrypted};
-use crate::error::JabberError;
+use crate::error::Error;
use crate::stanza::stream::Stream;
-use crate::stanza::DECLARATION;
+use crate::stanza::XML_VERSION;
use crate::Result;
use crate::JID;
@@ -18,8 +17,8 @@ pub struct Jabber<S>
where
S: AsyncRead + AsyncWrite + Unpin,
{
- reader: NsReader<BufReader<ReadHalf<S>>>,
- writer: WriteHalf<S>,
+ reader: Reader<ReadHalf<S>>,
+ writer: Writer<WriteHalf<S>>,
jid: Option<JID>,
auth: Option<Arc<SASLConfig>>,
server: String,
@@ -36,7 +35,8 @@ where
auth: Option<Arc<SASLConfig>>,
server: String,
) -> Self {
- let reader = NsReader::from_reader(BufReader::new(reader));
+ let reader = Reader::new(reader);
+ let writer = Writer::new(writer);
Self {
reader,
writer,
@@ -49,7 +49,7 @@ where
impl<S> Jabber<S>
where
- S: AsyncRead + AsyncWrite + Unpin,
+ S: AsyncRead + AsyncWrite + Unpin + Send,
{
// pub async fn negotiate(self) -> Result<Jabber<S>> {}
@@ -57,65 +57,26 @@ where
// client to server
// declaration
- let mut xmlwriter = Writer::new(&mut self.writer);
- xmlwriter.write_event_async(DECLARATION.clone()).await?;
+ self.writer.write_declaration(XML_VERSION).await?;
// opening stream element
- let server = &self.server.to_owned().try_into()?;
- let stream_element = Stream::new_client(None, server, None, "en");
+ let server = self.server.clone().try_into()?;
+ let stream = Stream::new_client(None, server, None, "en".to_string());
// TODO: nicer function to serialize to xml writer
- let mut buffer = String::new();
- let ser = Serializer::with_root(&mut buffer, Some("stream:stream")).expect("stream name");
- stream_element.serialize(ser).unwrap();
- trace!("sent: {}", buffer);
- self.writer.write_all(buffer.as_bytes()).await.unwrap();
+ self.writer.write_start(&stream).await?;
// server to client
// may or may not send a declaration
- let mut buf = Vec::new();
- let mut first_event = self.reader.read_resolved_event_into_async(&mut buf).await?;
- trace!("received: {:?}", first_event);
- match first_event {
- (quick_xml::name::ResolveResult::Unbound, Event::Decl(e)) => {
- if let Ok(version) = e.version() {
- if version.as_ref() == b"1.0" {
- first_event = self.reader.read_resolved_event_into_async(&mut buf).await?;
- trace!("received: {:?}", first_event);
- } else {
- // todo: error
- todo!()
- }
- } else {
- first_event = self.reader.read_resolved_event_into_async(&mut buf).await?;
- trace!("received: {:?}", first_event);
- }
- }
- _ => (),
- }
+ let decl = self.reader.read_prolog().await?;
// receive stream element and validate
- match first_event {
- (quick_xml::name::ResolveResult::Bound(ns), Event::Start(e)) => {
- if ns.0 == crate::stanza::stream::XMLNS.as_bytes() {
- e.attributes().try_for_each(|attr| -> Result<()> {
- let attr = attr?;
- match attr.key.into_inner() {
- b"from" => {
- self.server = str::from_utf8(&attr.value)?.to_owned();
- Ok(())
- }
- _ => Ok(()),
- }
- });
- return Ok(());
- } else {
- return Err(JabberError::BadStream);
- }
- }
- // TODO: errors for incorrect namespace
- _ => Err(JabberError::BadStream),
+ let stream: Stream = self.reader.read_start().await?;
+ if let Some(from) = stream.from {
+ self.server = from.to_string()
}
+
+ Ok(())
}
}
diff --git a/src/jid.rs b/src/jid.rs
index 65738dc..233227a 100644
--- a/src/jid.rs
+++ b/src/jid.rs
@@ -1,7 +1,5 @@
use std::str::FromStr;
-use serde::Serialize;
-
#[derive(PartialEq, Debug, Clone)]
pub struct JID {
// TODO: validate localpart (length, char]
@@ -10,15 +8,6 @@ pub struct JID {
pub resourcepart: Option<String>,
}
-impl Serialize for JID {
- fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
- where
- S: serde::Serializer,
- {
- serializer.serialize_str(&self.to_string())
- }
-}
-
pub enum JIDError {
NoResourcePart,
ParseError(ParseError),
@@ -27,7 +16,16 @@ pub enum JIDError {
#[derive(Debug)]
pub enum ParseError {
Empty,
- Malformed,
+ Malformed(String),
+}
+
+impl From<ParseError> for peanuts::Error {
+ fn from(e: ParseError) -> Self {
+ match e {
+ ParseError::Empty => peanuts::Error::DeserializeError("".to_string()),
+ ParseError::Malformed(e) => peanuts::Error::DeserializeError(e),
+ }
+ }
}
impl JID {
@@ -76,7 +74,7 @@ impl FromStr for JID {
split[0].to_string(),
Some(split[1].to_string()),
)),
- _ => Err(ParseError::Malformed),
+ _ => Err(ParseError::Malformed(s.to_string())),
}
}
2 => {
@@ -92,10 +90,10 @@ impl FromStr for JID {
split2[0].to_string(),
Some(split2[1].to_string()),
)),
- _ => Err(ParseError::Malformed),
+ _ => Err(ParseError::Malformed(s.to_string())),
}
}
- _ => Err(ParseError::Malformed),
+ _ => Err(ParseError::Malformed(s.to_string())),
}
}
}
diff --git a/src/lib.rs b/src/lib.rs
index a7f0494..306b0fd 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -12,11 +12,11 @@ pub mod stanza;
extern crate lazy_static;
pub use connection::Connection;
-pub use error::JabberError;
+pub use error::Error;
pub use jabber::Jabber;
pub use jid::JID;
-pub type Result<T> = std::result::Result<T, JabberError>;
+pub type Result<T> = std::result::Result<T, Error>;
pub async fn login<J: TryInto<JID>, P: AsRef<str>>(jid: J, password: P) -> Result<Connection> {
todo!()
diff --git a/src/stanza/mod.rs b/src/stanza/mod.rs
index e4f080f..4f1ce48 100644
--- a/src/stanza/mod.rs
+++ b/src/stanza/mod.rs
@@ -1,3 +1,5 @@
+use peanuts::declaration::VersionInfo;
+
pub mod bind;
pub mod iq;
pub mod message;
@@ -6,8 +8,4 @@ pub mod sasl;
pub mod starttls;
pub mod stream;
-use quick_xml::events::{BytesDecl, Event};
-
-lazy_static! {
- pub static ref DECLARATION: Event<'static> = Event::Decl(BytesDecl::new("1.0", None, None));
-}
+pub static XML_VERSION: VersionInfo = VersionInfo::One;
diff --git a/src/stanza/stream.rs b/src/stanza/stream.rs
index 9a21373..ac4badc 100644
--- a/src/stanza/stream.rs
+++ b/src/stanza/stream.rs
@@ -1,37 +1,141 @@
-use serde::Serialize;
+use std::collections::{HashMap, HashSet};
-use crate::JID;
+use peanuts::element::{Content, FromElement, IntoElement, NamespaceDeclaration};
+use peanuts::XML_NS;
+use peanuts::{element::Name, Element};
-pub static XMLNS: &str = "http://etherx.jabber.org/streams";
-pub static XMLNS_CLIENT: &str = "jabber:client";
+use crate::{Error, JID};
+
+pub const XMLNS: &str = "http://etherx.jabber.org/streams";
+pub const XMLNS_CLIENT: &str = "jabber:client";
// MUST be qualified by stream namespace
-#[derive(Serialize)]
-pub struct Stream<'s> {
- #[serde(rename = "@from")]
- from: Option<&'s JID>,
- #[serde(rename = "@to")]
- to: Option<&'s JID>,
- #[serde(rename = "@id")]
- id: Option<&'s str>,
- #[serde(rename = "@version")]
- version: Option<&'s str>,
+// #[derive(XmlSerialize, XmlDeserialize)]
+// #[peanuts(xmlns = XMLNS)]
+pub struct Stream {
+ pub from: Option<JID>,
+ to: Option<JID>,
+ id: Option<String>,
+ version: Option<String>,
// TODO: lang enum
- #[serde(rename = "@lang")]
- lang: Option<&'s str>,
- #[serde(rename = "@xmlns")]
- xmlns: &'s str,
- #[serde(rename = "@xmlns:stream")]
- xmlns_stream: &'s str,
+ lang: Option<String>,
+ // #[peanuts(content)]
+ // content: Message,
+}
+
+impl FromElement for Stream {
+ fn from_element(element: Element) -> peanuts::Result<Self> {
+ let Name {
+ namespace,
+ local_name,
+ } = element.name;
+ if namespace.as_deref() == Some(XMLNS) && &local_name == "stream" {
+ let (mut from, mut to, mut id, mut version, mut lang) = (None, None, None, None, None);
+ for (name, value) in element.attributes {
+ match (name.namespace.as_deref(), name.local_name.as_str()) {
+ (None, "from") => from = Some(value.try_into()?),
+ (None, "to") => to = Some(value.try_into()?),
+ (None, "id") => id = Some(value),
+ (None, "version") => version = Some(value),
+ (Some(XML_NS), "lang") => lang = Some(value),
+ _ => return Err(peanuts::Error::UnexpectedAttribute(name)),
+ }
+ }
+ return Ok(Stream {
+ from,
+ to,
+ id,
+ version,
+ lang,
+ });
+ } else {
+ return Err(peanuts::Error::IncorrectName(Name {
+ namespace,
+ local_name,
+ }));
+ }
+ }
+}
+
+impl IntoElement for Stream {
+ fn into_element(&self) -> Element {
+ let mut namespace_declarations = HashSet::new();
+ namespace_declarations.insert(NamespaceDeclaration {
+ prefix: Some("stream".to_string()),
+ namespace: XMLNS.to_string(),
+ });
+ namespace_declarations.insert(NamespaceDeclaration {
+ prefix: None,
+ // TODO: don't default to client
+ namespace: XMLNS_CLIENT.to_string(),
+ });
+
+ let mut attributes = HashMap::new();
+ self.from.as_ref().map(|from| {
+ attributes.insert(
+ Name {
+ namespace: None,
+ local_name: "from".to_string(),
+ },
+ from.to_string(),
+ );
+ });
+ self.to.as_ref().map(|to| {
+ attributes.insert(
+ Name {
+ namespace: None,
+ local_name: "to".to_string(),
+ },
+ to.to_string(),
+ );
+ });
+ self.id.as_ref().map(|id| {
+ attributes.insert(
+ Name {
+ namespace: None,
+ local_name: "version".to_string(),
+ },
+ id.clone(),
+ );
+ });
+ self.version.as_ref().map(|version| {
+ attributes.insert(
+ Name {
+ namespace: None,
+ local_name: "version".to_string(),
+ },
+ version.clone(),
+ );
+ });
+ self.lang.as_ref().map(|lang| {
+ attributes.insert(
+ Name {
+ namespace: Some(XML_NS.to_string()),
+ local_name: "lang".to_string(),
+ },
+ lang.to_string(),
+ );
+ });
+
+ Element {
+ name: Name {
+ namespace: Some(XMLNS.to_string()),
+ local_name: "stream".to_string(),
+ },
+ namespace_declarations,
+ attributes,
+ content: Vec::new(),
+ }
+ }
}
-impl<'s> Stream<'s> {
+impl<'s> Stream {
pub fn new(
- from: Option<&'s JID>,
- to: Option<&'s JID>,
- id: Option<&'s str>,
- version: Option<&'s str>,
- lang: Option<&'s str>,
+ from: Option<JID>,
+ to: Option<JID>,
+ id: Option<String>,
+ version: Option<String>,
+ lang: Option<String>,
) -> Self {
Self {
from,
@@ -39,27 +143,18 @@ impl<'s> Stream<'s> {
id,
version,
lang,
- xmlns: XMLNS_CLIENT,
- xmlns_stream: XMLNS,
}
}
/// For initial stream headers, the initiating entity SHOULD include the 'xml:lang' attribute.
/// For privacy, it is better to not set `from` when sending a client stanza over an unencrypted connection.
- pub fn new_client(
- from: Option<&'s JID>,
- to: &'s JID,
- id: Option<&'s str>,
- lang: &'s str,
- ) -> Self {
+ pub fn new_client(from: Option<JID>, to: JID, id: Option<String>, lang: String) -> Self {
Self {
from,
to: Some(to),
id,
- version: Some("1.0"),
+ version: Some("1.0".to_string()),
lang: Some(lang),
- xmlns: XMLNS_CLIENT,
- xmlns_stream: XMLNS,
}
}
}