diff options
author | 2024-11-23 22:39:44 +0000 | |
---|---|---|
committer | 2024-11-23 22:39:44 +0000 | |
commit | 40024d2dadba9e70edb2f3448204565ce3f68ab7 (patch) | |
tree | 3f08b61debf936c513f300c845d8a1cb29edd7c8 | |
parent | 9f2546f6dadd916b0e7fc5be51e92d682ef2487b (diff) | |
download | luz-40024d2dadba9e70edb2f3448204565ce3f68ab7.tar.gz luz-40024d2dadba9e70edb2f3448204565ce3f68ab7.tar.bz2 luz-40024d2dadba9e70edb2f3448204565ce3f68ab7.zip |
switch to using peanuts for xml
-rw-r--r-- | Cargo.toml | 4 | ||||
-rw-r--r-- | src/connection.rs | 12 | ||||
-rw-r--r-- | src/error.rs | 25 | ||||
-rw-r--r-- | src/jabber.rs | 75 | ||||
-rw-r--r-- | src/jid.rs | 28 | ||||
-rw-r--r-- | src/lib.rs | 4 | ||||
-rw-r--r-- | src/stanza/mod.rs | 8 | ||||
-rw-r--r-- | src/stanza/stream.rs | 169 |
8 files changed, 184 insertions, 141 deletions
@@ -11,16 +11,14 @@ async-recursion = "1.0.4" async-trait = "0.1.68" lazy_static = "1.4.0" nanoid = "0.4.0" -quick-xml = { git = "https://github.com/tafia/quick-xml.git", features = ["async-tokio", "serialize"] } # TODO: remove unneeded features rsasl = { version = "2", default_features = true, features = ["provider_base64", "plain", "config_builder"] } -serde = "1.0.180" -serde_with = "3.4.0" tokio = { version = "1.28", features = ["full"] } tokio-native-tls = "0.3.1" tracing = "0.1.40" trust-dns-resolver = "0.22.0" try_map = "0.3.1" +peanuts = { version = "0.1.0", path = "../peanuts" } [dev-dependencies] test-log = { version = "0.2", features = ["trace"] } diff --git a/src/connection.rs b/src/connection.rs index b42711e..89f382f 100644 --- a/src/connection.rs +++ b/src/connection.rs @@ -8,8 +8,8 @@ use tokio_native_tls::native_tls::TlsConnector; use tokio_native_tls::TlsStream; use tracing::{debug, info, instrument, trace}; +use crate::Error; use crate::Jabber; -use crate::JabberError; use crate::Result; pub type Tls = TlsStream<TcpStream>; @@ -75,7 +75,7 @@ impl Connection { } } } - Err(JabberError::Connection) + Err(Error::Connection) } #[instrument] @@ -154,19 +154,19 @@ impl Connection { pub async fn connect_tls(socket_addr: SocketAddr, domain_name: &str) -> Result<Tls> { let socket = TcpStream::connect(socket_addr) .await - .map_err(|_| JabberError::Connection)?; - let connector = TlsConnector::new().map_err(|_| JabberError::Connection)?; + .map_err(|_| Error::Connection)?; + let connector = TlsConnector::new().map_err(|_| Error::Connection)?; tokio_native_tls::TlsConnector::from(connector) .connect(domain_name, socket) .await - .map_err(|_| JabberError::Connection) + .map_err(|_| Error::Connection) } #[instrument] pub async fn connect_unencrypted(socket_addr: SocketAddr) -> Result<Unencrypted> { TcpStream::connect(socket_addr) .await - .map_err(|_| JabberError::Connection) + .map_err(|_| Error::Connection) } } diff --git a/src/error.rs b/src/error.rs index b12914c..c7c867c 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,12 +1,11 @@ use std::str::Utf8Error; -use quick_xml::events::attributes::AttrError; use rsasl::mechname::MechanismNameError; use crate::jid::ParseError; #[derive(Debug)] -pub enum JabberError { +pub enum Error { Connection, BadStream, StartTlsUnavailable, @@ -23,7 +22,7 @@ pub enum JabberError { UnexpectedEnd, UnexpectedElement, UnexpectedText, - XML(quick_xml::Error), + XML(peanuts::Error), SASL(SASLError), JID(ParseError), } @@ -36,43 +35,37 @@ pub enum SASLError { NoSuccess, } -impl From<rsasl::prelude::SASLError> for JabberError { +impl From<rsasl::prelude::SASLError> for Error { fn from(e: rsasl::prelude::SASLError) -> Self { Self::SASL(SASLError::SASL(e)) } } -impl From<MechanismNameError> for JabberError { +impl From<MechanismNameError> for Error { fn from(e: MechanismNameError) -> Self { Self::SASL(SASLError::MechanismName(e)) } } -impl From<SASLError> for JabberError { +impl From<SASLError> for Error { fn from(e: SASLError) -> Self { Self::SASL(e) } } -impl From<Utf8Error> for JabberError { +impl From<Utf8Error> for Error { fn from(_e: Utf8Error) -> Self { Self::Utf8Decode } } -impl From<quick_xml::Error> for JabberError { - fn from(e: quick_xml::Error) -> Self { +impl From<peanuts::Error> for Error { + fn from(e: peanuts::Error) -> Self { Self::XML(e) } } -impl From<AttrError> for JabberError { - fn from(e: AttrError) -> Self { - Self::XML(e.into()) - } -} - -impl From<ParseError> for JabberError { +impl From<ParseError> for Error { fn from(e: ParseError) -> Self { Self::JID(e) } diff --git a/src/jabber.rs b/src/jabber.rs index 1436bfa..afe840b 100644 --- a/src/jabber.rs +++ b/src/jabber.rs @@ -1,16 +1,15 @@ use std::str; use std::sync::Arc; -use quick_xml::{events::Event, se::Serializer, NsReader, Writer}; +use peanuts::{Reader, Writer}; use rsasl::prelude::SASLConfig; -use serde::Serialize; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufReader, ReadHalf, WriteHalf}; use tracing::{debug, info, trace}; use crate::connection::{Tls, Unencrypted}; -use crate::error::JabberError; +use crate::error::Error; use crate::stanza::stream::Stream; -use crate::stanza::DECLARATION; +use crate::stanza::XML_VERSION; use crate::Result; use crate::JID; @@ -18,8 +17,8 @@ pub struct Jabber<S> where S: AsyncRead + AsyncWrite + Unpin, { - reader: NsReader<BufReader<ReadHalf<S>>>, - writer: WriteHalf<S>, + reader: Reader<ReadHalf<S>>, + writer: Writer<WriteHalf<S>>, jid: Option<JID>, auth: Option<Arc<SASLConfig>>, server: String, @@ -36,7 +35,8 @@ where auth: Option<Arc<SASLConfig>>, server: String, ) -> Self { - let reader = NsReader::from_reader(BufReader::new(reader)); + let reader = Reader::new(reader); + let writer = Writer::new(writer); Self { reader, writer, @@ -49,7 +49,7 @@ where impl<S> Jabber<S> where - S: AsyncRead + AsyncWrite + Unpin, + S: AsyncRead + AsyncWrite + Unpin + Send, { // pub async fn negotiate(self) -> Result<Jabber<S>> {} @@ -57,65 +57,26 @@ where // client to server // declaration - let mut xmlwriter = Writer::new(&mut self.writer); - xmlwriter.write_event_async(DECLARATION.clone()).await?; + self.writer.write_declaration(XML_VERSION).await?; // opening stream element - let server = &self.server.to_owned().try_into()?; - let stream_element = Stream::new_client(None, server, None, "en"); + let server = self.server.clone().try_into()?; + let stream = Stream::new_client(None, server, None, "en".to_string()); // TODO: nicer function to serialize to xml writer - let mut buffer = String::new(); - let ser = Serializer::with_root(&mut buffer, Some("stream:stream")).expect("stream name"); - stream_element.serialize(ser).unwrap(); - trace!("sent: {}", buffer); - self.writer.write_all(buffer.as_bytes()).await.unwrap(); + self.writer.write_start(&stream).await?; // server to client // may or may not send a declaration - let mut buf = Vec::new(); - let mut first_event = self.reader.read_resolved_event_into_async(&mut buf).await?; - trace!("received: {:?}", first_event); - match first_event { - (quick_xml::name::ResolveResult::Unbound, Event::Decl(e)) => { - if let Ok(version) = e.version() { - if version.as_ref() == b"1.0" { - first_event = self.reader.read_resolved_event_into_async(&mut buf).await?; - trace!("received: {:?}", first_event); - } else { - // todo: error - todo!() - } - } else { - first_event = self.reader.read_resolved_event_into_async(&mut buf).await?; - trace!("received: {:?}", first_event); - } - } - _ => (), - } + let decl = self.reader.read_prolog().await?; // receive stream element and validate - match first_event { - (quick_xml::name::ResolveResult::Bound(ns), Event::Start(e)) => { - if ns.0 == crate::stanza::stream::XMLNS.as_bytes() { - e.attributes().try_for_each(|attr| -> Result<()> { - let attr = attr?; - match attr.key.into_inner() { - b"from" => { - self.server = str::from_utf8(&attr.value)?.to_owned(); - Ok(()) - } - _ => Ok(()), - } - }); - return Ok(()); - } else { - return Err(JabberError::BadStream); - } - } - // TODO: errors for incorrect namespace - _ => Err(JabberError::BadStream), + let stream: Stream = self.reader.read_start().await?; + if let Some(from) = stream.from { + self.server = from.to_string() } + + Ok(()) } } @@ -1,7 +1,5 @@ use std::str::FromStr; -use serde::Serialize; - #[derive(PartialEq, Debug, Clone)] pub struct JID { // TODO: validate localpart (length, char] @@ -10,15 +8,6 @@ pub struct JID { pub resourcepart: Option<String>, } -impl Serialize for JID { - fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> - where - S: serde::Serializer, - { - serializer.serialize_str(&self.to_string()) - } -} - pub enum JIDError { NoResourcePart, ParseError(ParseError), @@ -27,7 +16,16 @@ pub enum JIDError { #[derive(Debug)] pub enum ParseError { Empty, - Malformed, + Malformed(String), +} + +impl From<ParseError> for peanuts::Error { + fn from(e: ParseError) -> Self { + match e { + ParseError::Empty => peanuts::Error::DeserializeError("".to_string()), + ParseError::Malformed(e) => peanuts::Error::DeserializeError(e), + } + } } impl JID { @@ -76,7 +74,7 @@ impl FromStr for JID { split[0].to_string(), Some(split[1].to_string()), )), - _ => Err(ParseError::Malformed), + _ => Err(ParseError::Malformed(s.to_string())), } } 2 => { @@ -92,10 +90,10 @@ impl FromStr for JID { split2[0].to_string(), Some(split2[1].to_string()), )), - _ => Err(ParseError::Malformed), + _ => Err(ParseError::Malformed(s.to_string())), } } - _ => Err(ParseError::Malformed), + _ => Err(ParseError::Malformed(s.to_string())), } } } @@ -12,11 +12,11 @@ pub mod stanza; extern crate lazy_static; pub use connection::Connection; -pub use error::JabberError; +pub use error::Error; pub use jabber::Jabber; pub use jid::JID; -pub type Result<T> = std::result::Result<T, JabberError>; +pub type Result<T> = std::result::Result<T, Error>; pub async fn login<J: TryInto<JID>, P: AsRef<str>>(jid: J, password: P) -> Result<Connection> { todo!() diff --git a/src/stanza/mod.rs b/src/stanza/mod.rs index e4f080f..4f1ce48 100644 --- a/src/stanza/mod.rs +++ b/src/stanza/mod.rs @@ -1,3 +1,5 @@ +use peanuts::declaration::VersionInfo; + pub mod bind; pub mod iq; pub mod message; @@ -6,8 +8,4 @@ pub mod sasl; pub mod starttls; pub mod stream; -use quick_xml::events::{BytesDecl, Event}; - -lazy_static! { - pub static ref DECLARATION: Event<'static> = Event::Decl(BytesDecl::new("1.0", None, None)); -} +pub static XML_VERSION: VersionInfo = VersionInfo::One; diff --git a/src/stanza/stream.rs b/src/stanza/stream.rs index 9a21373..ac4badc 100644 --- a/src/stanza/stream.rs +++ b/src/stanza/stream.rs @@ -1,37 +1,141 @@ -use serde::Serialize; +use std::collections::{HashMap, HashSet}; -use crate::JID; +use peanuts::element::{Content, FromElement, IntoElement, NamespaceDeclaration}; +use peanuts::XML_NS; +use peanuts::{element::Name, Element}; -pub static XMLNS: &str = "http://etherx.jabber.org/streams"; -pub static XMLNS_CLIENT: &str = "jabber:client"; +use crate::{Error, JID}; + +pub const XMLNS: &str = "http://etherx.jabber.org/streams"; +pub const XMLNS_CLIENT: &str = "jabber:client"; // MUST be qualified by stream namespace -#[derive(Serialize)] -pub struct Stream<'s> { - #[serde(rename = "@from")] - from: Option<&'s JID>, - #[serde(rename = "@to")] - to: Option<&'s JID>, - #[serde(rename = "@id")] - id: Option<&'s str>, - #[serde(rename = "@version")] - version: Option<&'s str>, +// #[derive(XmlSerialize, XmlDeserialize)] +// #[peanuts(xmlns = XMLNS)] +pub struct Stream { + pub from: Option<JID>, + to: Option<JID>, + id: Option<String>, + version: Option<String>, // TODO: lang enum - #[serde(rename = "@lang")] - lang: Option<&'s str>, - #[serde(rename = "@xmlns")] - xmlns: &'s str, - #[serde(rename = "@xmlns:stream")] - xmlns_stream: &'s str, + lang: Option<String>, + // #[peanuts(content)] + // content: Message, +} + +impl FromElement for Stream { + fn from_element(element: Element) -> peanuts::Result<Self> { + let Name { + namespace, + local_name, + } = element.name; + if namespace.as_deref() == Some(XMLNS) && &local_name == "stream" { + let (mut from, mut to, mut id, mut version, mut lang) = (None, None, None, None, None); + for (name, value) in element.attributes { + match (name.namespace.as_deref(), name.local_name.as_str()) { + (None, "from") => from = Some(value.try_into()?), + (None, "to") => to = Some(value.try_into()?), + (None, "id") => id = Some(value), + (None, "version") => version = Some(value), + (Some(XML_NS), "lang") => lang = Some(value), + _ => return Err(peanuts::Error::UnexpectedAttribute(name)), + } + } + return Ok(Stream { + from, + to, + id, + version, + lang, + }); + } else { + return Err(peanuts::Error::IncorrectName(Name { + namespace, + local_name, + })); + } + } +} + +impl IntoElement for Stream { + fn into_element(&self) -> Element { + let mut namespace_declarations = HashSet::new(); + namespace_declarations.insert(NamespaceDeclaration { + prefix: Some("stream".to_string()), + namespace: XMLNS.to_string(), + }); + namespace_declarations.insert(NamespaceDeclaration { + prefix: None, + // TODO: don't default to client + namespace: XMLNS_CLIENT.to_string(), + }); + + let mut attributes = HashMap::new(); + self.from.as_ref().map(|from| { + attributes.insert( + Name { + namespace: None, + local_name: "from".to_string(), + }, + from.to_string(), + ); + }); + self.to.as_ref().map(|to| { + attributes.insert( + Name { + namespace: None, + local_name: "to".to_string(), + }, + to.to_string(), + ); + }); + self.id.as_ref().map(|id| { + attributes.insert( + Name { + namespace: None, + local_name: "version".to_string(), + }, + id.clone(), + ); + }); + self.version.as_ref().map(|version| { + attributes.insert( + Name { + namespace: None, + local_name: "version".to_string(), + }, + version.clone(), + ); + }); + self.lang.as_ref().map(|lang| { + attributes.insert( + Name { + namespace: Some(XML_NS.to_string()), + local_name: "lang".to_string(), + }, + lang.to_string(), + ); + }); + + Element { + name: Name { + namespace: Some(XMLNS.to_string()), + local_name: "stream".to_string(), + }, + namespace_declarations, + attributes, + content: Vec::new(), + } + } } -impl<'s> Stream<'s> { +impl<'s> Stream { pub fn new( - from: Option<&'s JID>, - to: Option<&'s JID>, - id: Option<&'s str>, - version: Option<&'s str>, - lang: Option<&'s str>, + from: Option<JID>, + to: Option<JID>, + id: Option<String>, + version: Option<String>, + lang: Option<String>, ) -> Self { Self { from, @@ -39,27 +143,18 @@ impl<'s> Stream<'s> { id, version, lang, - xmlns: XMLNS_CLIENT, - xmlns_stream: XMLNS, } } /// For initial stream headers, the initiating entity SHOULD include the 'xml:lang' attribute. /// For privacy, it is better to not set `from` when sending a client stanza over an unencrypted connection. - pub fn new_client( - from: Option<&'s JID>, - to: &'s JID, - id: Option<&'s str>, - lang: &'s str, - ) -> Self { + pub fn new_client(from: Option<JID>, to: JID, id: Option<String>, lang: String) -> Self { Self { from, to: Some(to), id, - version: Some("1.0"), + version: Some("1.0".to_string()), lang: Some(lang), - xmlns: XMLNS_CLIENT, - xmlns_stream: XMLNS, } } } |