diff options
Diffstat (limited to '')
| -rw-r--r-- | Cargo.toml | 4 | ||||
| -rw-r--r-- | src/connection.rs | 12 | ||||
| -rw-r--r-- | src/error.rs | 25 | ||||
| -rw-r--r-- | src/jabber.rs | 75 | ||||
| -rw-r--r-- | src/jid.rs | 28 | ||||
| -rw-r--r-- | src/lib.rs | 4 | ||||
| -rw-r--r-- | src/stanza/mod.rs | 8 | ||||
| -rw-r--r-- | src/stanza/stream.rs | 169 | 
8 files changed, 184 insertions, 141 deletions
| @@ -11,16 +11,14 @@ async-recursion = "1.0.4"  async-trait = "0.1.68"  lazy_static = "1.4.0"  nanoid = "0.4.0" -quick-xml = { git = "https://github.com/tafia/quick-xml.git", features = ["async-tokio", "serialize"] }  # TODO: remove unneeded features  rsasl = { version = "2", default_features = true, features = ["provider_base64", "plain", "config_builder"] } -serde = "1.0.180" -serde_with = "3.4.0"  tokio = { version = "1.28", features = ["full"] }  tokio-native-tls = "0.3.1"  tracing = "0.1.40"  trust-dns-resolver = "0.22.0"  try_map = "0.3.1" +peanuts = { version = "0.1.0", path = "../peanuts" }  [dev-dependencies]  test-log = { version = "0.2", features = ["trace"] } diff --git a/src/connection.rs b/src/connection.rs index b42711e..89f382f 100644 --- a/src/connection.rs +++ b/src/connection.rs @@ -8,8 +8,8 @@ use tokio_native_tls::native_tls::TlsConnector;  use tokio_native_tls::TlsStream;  use tracing::{debug, info, instrument, trace}; +use crate::Error;  use crate::Jabber; -use crate::JabberError;  use crate::Result;  pub type Tls = TlsStream<TcpStream>; @@ -75,7 +75,7 @@ impl Connection {                  }              }          } -        Err(JabberError::Connection) +        Err(Error::Connection)      }      #[instrument] @@ -154,19 +154,19 @@ impl Connection {      pub async fn connect_tls(socket_addr: SocketAddr, domain_name: &str) -> Result<Tls> {          let socket = TcpStream::connect(socket_addr)              .await -            .map_err(|_| JabberError::Connection)?; -        let connector = TlsConnector::new().map_err(|_| JabberError::Connection)?; +            .map_err(|_| Error::Connection)?; +        let connector = TlsConnector::new().map_err(|_| Error::Connection)?;          tokio_native_tls::TlsConnector::from(connector)              .connect(domain_name, socket)              .await -            .map_err(|_| JabberError::Connection) +            .map_err(|_| Error::Connection)      }      #[instrument]      pub async fn connect_unencrypted(socket_addr: SocketAddr) -> Result<Unencrypted> {          TcpStream::connect(socket_addr)              .await -            .map_err(|_| JabberError::Connection) +            .map_err(|_| Error::Connection)      }  } diff --git a/src/error.rs b/src/error.rs index b12914c..c7c867c 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,12 +1,11 @@  use std::str::Utf8Error; -use quick_xml::events::attributes::AttrError;  use rsasl::mechname::MechanismNameError;  use crate::jid::ParseError;  #[derive(Debug)] -pub enum JabberError { +pub enum Error {      Connection,      BadStream,      StartTlsUnavailable, @@ -23,7 +22,7 @@ pub enum JabberError {      UnexpectedEnd,      UnexpectedElement,      UnexpectedText, -    XML(quick_xml::Error), +    XML(peanuts::Error),      SASL(SASLError),      JID(ParseError),  } @@ -36,43 +35,37 @@ pub enum SASLError {      NoSuccess,  } -impl From<rsasl::prelude::SASLError> for JabberError { +impl From<rsasl::prelude::SASLError> for Error {      fn from(e: rsasl::prelude::SASLError) -> Self {          Self::SASL(SASLError::SASL(e))      }  } -impl From<MechanismNameError> for JabberError { +impl From<MechanismNameError> for Error {      fn from(e: MechanismNameError) -> Self {          Self::SASL(SASLError::MechanismName(e))      }  } -impl From<SASLError> for JabberError { +impl From<SASLError> for Error {      fn from(e: SASLError) -> Self {          Self::SASL(e)      }  } -impl From<Utf8Error> for JabberError { +impl From<Utf8Error> for Error {      fn from(_e: Utf8Error) -> Self {          Self::Utf8Decode      }  } -impl From<quick_xml::Error> for JabberError { -    fn from(e: quick_xml::Error) -> Self { +impl From<peanuts::Error> for Error { +    fn from(e: peanuts::Error) -> Self {          Self::XML(e)      }  } -impl From<AttrError> for JabberError { -    fn from(e: AttrError) -> Self { -        Self::XML(e.into()) -    } -} - -impl From<ParseError> for JabberError { +impl From<ParseError> for Error {      fn from(e: ParseError) -> Self {          Self::JID(e)      } diff --git a/src/jabber.rs b/src/jabber.rs index 1436bfa..afe840b 100644 --- a/src/jabber.rs +++ b/src/jabber.rs @@ -1,16 +1,15 @@  use std::str;  use std::sync::Arc; -use quick_xml::{events::Event, se::Serializer, NsReader, Writer}; +use peanuts::{Reader, Writer};  use rsasl::prelude::SASLConfig; -use serde::Serialize;  use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufReader, ReadHalf, WriteHalf};  use tracing::{debug, info, trace};  use crate::connection::{Tls, Unencrypted}; -use crate::error::JabberError; +use crate::error::Error;  use crate::stanza::stream::Stream; -use crate::stanza::DECLARATION; +use crate::stanza::XML_VERSION;  use crate::Result;  use crate::JID; @@ -18,8 +17,8 @@ pub struct Jabber<S>  where      S: AsyncRead + AsyncWrite + Unpin,  { -    reader: NsReader<BufReader<ReadHalf<S>>>, -    writer: WriteHalf<S>, +    reader: Reader<ReadHalf<S>>, +    writer: Writer<WriteHalf<S>>,      jid: Option<JID>,      auth: Option<Arc<SASLConfig>>,      server: String, @@ -36,7 +35,8 @@ where          auth: Option<Arc<SASLConfig>>,          server: String,      ) -> Self { -        let reader = NsReader::from_reader(BufReader::new(reader)); +        let reader = Reader::new(reader); +        let writer = Writer::new(writer);          Self {              reader,              writer, @@ -49,7 +49,7 @@ where  impl<S> Jabber<S>  where -    S: AsyncRead + AsyncWrite + Unpin, +    S: AsyncRead + AsyncWrite + Unpin + Send,  {      // pub async fn negotiate(self) -> Result<Jabber<S>> {} @@ -57,65 +57,26 @@ where          // client to server          // declaration -        let mut xmlwriter = Writer::new(&mut self.writer); -        xmlwriter.write_event_async(DECLARATION.clone()).await?; +        self.writer.write_declaration(XML_VERSION).await?;          // opening stream element -        let server = &self.server.to_owned().try_into()?; -        let stream_element = Stream::new_client(None, server, None, "en"); +        let server = self.server.clone().try_into()?; +        let stream = Stream::new_client(None, server, None, "en".to_string());          // TODO: nicer function to serialize to xml writer -        let mut buffer = String::new(); -        let ser = Serializer::with_root(&mut buffer, Some("stream:stream")).expect("stream name"); -        stream_element.serialize(ser).unwrap(); -        trace!("sent: {}", buffer); -        self.writer.write_all(buffer.as_bytes()).await.unwrap(); +        self.writer.write_start(&stream).await?;          // server to client          // may or may not send a declaration -        let mut buf = Vec::new(); -        let mut first_event = self.reader.read_resolved_event_into_async(&mut buf).await?; -        trace!("received: {:?}", first_event); -        match first_event { -            (quick_xml::name::ResolveResult::Unbound, Event::Decl(e)) => { -                if let Ok(version) = e.version() { -                    if version.as_ref() == b"1.0" { -                        first_event = self.reader.read_resolved_event_into_async(&mut buf).await?; -                        trace!("received: {:?}", first_event); -                    } else { -                        // todo: error -                        todo!() -                    } -                } else { -                    first_event = self.reader.read_resolved_event_into_async(&mut buf).await?; -                    trace!("received: {:?}", first_event); -                } -            } -            _ => (), -        } +        let decl = self.reader.read_prolog().await?;          // receive stream element and validate -        match first_event { -            (quick_xml::name::ResolveResult::Bound(ns), Event::Start(e)) => { -                if ns.0 == crate::stanza::stream::XMLNS.as_bytes() { -                    e.attributes().try_for_each(|attr| -> Result<()> { -                        let attr = attr?; -                        match attr.key.into_inner() { -                            b"from" => { -                                self.server = str::from_utf8(&attr.value)?.to_owned(); -                                Ok(()) -                            } -                            _ => Ok(()), -                        } -                    }); -                    return Ok(()); -                } else { -                    return Err(JabberError::BadStream); -                } -            } -            // TODO: errors for incorrect namespace -            _ => Err(JabberError::BadStream), +        let stream: Stream = self.reader.read_start().await?; +        if let Some(from) = stream.from { +            self.server = from.to_string()          } + +        Ok(())      }  } @@ -1,7 +1,5 @@  use std::str::FromStr; -use serde::Serialize; -  #[derive(PartialEq, Debug, Clone)]  pub struct JID {      // TODO: validate localpart (length, char] @@ -10,15 +8,6 @@ pub struct JID {      pub resourcepart: Option<String>,  } -impl Serialize for JID { -    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> -    where -        S: serde::Serializer, -    { -        serializer.serialize_str(&self.to_string()) -    } -} -  pub enum JIDError {      NoResourcePart,      ParseError(ParseError), @@ -27,7 +16,16 @@ pub enum JIDError {  #[derive(Debug)]  pub enum ParseError {      Empty, -    Malformed, +    Malformed(String), +} + +impl From<ParseError> for peanuts::Error { +    fn from(e: ParseError) -> Self { +        match e { +            ParseError::Empty => peanuts::Error::DeserializeError("".to_string()), +            ParseError::Malformed(e) => peanuts::Error::DeserializeError(e), +        } +    }  }  impl JID { @@ -76,7 +74,7 @@ impl FromStr for JID {                          split[0].to_string(),                          Some(split[1].to_string()),                      )), -                    _ => Err(ParseError::Malformed), +                    _ => Err(ParseError::Malformed(s.to_string())),                  }              }              2 => { @@ -92,10 +90,10 @@ impl FromStr for JID {                          split2[0].to_string(),                          Some(split2[1].to_string()),                      )), -                    _ => Err(ParseError::Malformed), +                    _ => Err(ParseError::Malformed(s.to_string())),                  }              } -            _ => Err(ParseError::Malformed), +            _ => Err(ParseError::Malformed(s.to_string())),          }      }  } @@ -12,11 +12,11 @@ pub mod stanza;  extern crate lazy_static;  pub use connection::Connection; -pub use error::JabberError; +pub use error::Error;  pub use jabber::Jabber;  pub use jid::JID; -pub type Result<T> = std::result::Result<T, JabberError>; +pub type Result<T> = std::result::Result<T, Error>;  pub async fn login<J: TryInto<JID>, P: AsRef<str>>(jid: J, password: P) -> Result<Connection> {      todo!() diff --git a/src/stanza/mod.rs b/src/stanza/mod.rs index e4f080f..4f1ce48 100644 --- a/src/stanza/mod.rs +++ b/src/stanza/mod.rs @@ -1,3 +1,5 @@ +use peanuts::declaration::VersionInfo; +  pub mod bind;  pub mod iq;  pub mod message; @@ -6,8 +8,4 @@ pub mod sasl;  pub mod starttls;  pub mod stream; -use quick_xml::events::{BytesDecl, Event}; - -lazy_static! { -    pub static ref DECLARATION: Event<'static> = Event::Decl(BytesDecl::new("1.0", None, None)); -} +pub static XML_VERSION: VersionInfo = VersionInfo::One; diff --git a/src/stanza/stream.rs b/src/stanza/stream.rs index 9a21373..ac4badc 100644 --- a/src/stanza/stream.rs +++ b/src/stanza/stream.rs @@ -1,37 +1,141 @@ -use serde::Serialize; +use std::collections::{HashMap, HashSet}; -use crate::JID; +use peanuts::element::{Content, FromElement, IntoElement, NamespaceDeclaration}; +use peanuts::XML_NS; +use peanuts::{element::Name, Element}; -pub static XMLNS: &str = "http://etherx.jabber.org/streams"; -pub static XMLNS_CLIENT: &str = "jabber:client"; +use crate::{Error, JID}; + +pub const XMLNS: &str = "http://etherx.jabber.org/streams"; +pub const XMLNS_CLIENT: &str = "jabber:client";  // MUST be qualified by stream namespace -#[derive(Serialize)] -pub struct Stream<'s> { -    #[serde(rename = "@from")] -    from: Option<&'s JID>, -    #[serde(rename = "@to")] -    to: Option<&'s JID>, -    #[serde(rename = "@id")] -    id: Option<&'s str>, -    #[serde(rename = "@version")] -    version: Option<&'s str>, +// #[derive(XmlSerialize, XmlDeserialize)] +// #[peanuts(xmlns = XMLNS)] +pub struct Stream { +    pub from: Option<JID>, +    to: Option<JID>, +    id: Option<String>, +    version: Option<String>,      // TODO: lang enum -    #[serde(rename = "@lang")] -    lang: Option<&'s str>, -    #[serde(rename = "@xmlns")] -    xmlns: &'s str, -    #[serde(rename = "@xmlns:stream")] -    xmlns_stream: &'s str, +    lang: Option<String>, +    // #[peanuts(content)] +    // content: Message, +} + +impl FromElement for Stream { +    fn from_element(element: Element) -> peanuts::Result<Self> { +        let Name { +            namespace, +            local_name, +        } = element.name; +        if namespace.as_deref() == Some(XMLNS) && &local_name == "stream" { +            let (mut from, mut to, mut id, mut version, mut lang) = (None, None, None, None, None); +            for (name, value) in element.attributes { +                match (name.namespace.as_deref(), name.local_name.as_str()) { +                    (None, "from") => from = Some(value.try_into()?), +                    (None, "to") => to = Some(value.try_into()?), +                    (None, "id") => id = Some(value), +                    (None, "version") => version = Some(value), +                    (Some(XML_NS), "lang") => lang = Some(value), +                    _ => return Err(peanuts::Error::UnexpectedAttribute(name)), +                } +            } +            return Ok(Stream { +                from, +                to, +                id, +                version, +                lang, +            }); +        } else { +            return Err(peanuts::Error::IncorrectName(Name { +                namespace, +                local_name, +            })); +        } +    } +} + +impl IntoElement for Stream { +    fn into_element(&self) -> Element { +        let mut namespace_declarations = HashSet::new(); +        namespace_declarations.insert(NamespaceDeclaration { +            prefix: Some("stream".to_string()), +            namespace: XMLNS.to_string(), +        }); +        namespace_declarations.insert(NamespaceDeclaration { +            prefix: None, +            // TODO: don't default to client +            namespace: XMLNS_CLIENT.to_string(), +        }); + +        let mut attributes = HashMap::new(); +        self.from.as_ref().map(|from| { +            attributes.insert( +                Name { +                    namespace: None, +                    local_name: "from".to_string(), +                }, +                from.to_string(), +            ); +        }); +        self.to.as_ref().map(|to| { +            attributes.insert( +                Name { +                    namespace: None, +                    local_name: "to".to_string(), +                }, +                to.to_string(), +            ); +        }); +        self.id.as_ref().map(|id| { +            attributes.insert( +                Name { +                    namespace: None, +                    local_name: "version".to_string(), +                }, +                id.clone(), +            ); +        }); +        self.version.as_ref().map(|version| { +            attributes.insert( +                Name { +                    namespace: None, +                    local_name: "version".to_string(), +                }, +                version.clone(), +            ); +        }); +        self.lang.as_ref().map(|lang| { +            attributes.insert( +                Name { +                    namespace: Some(XML_NS.to_string()), +                    local_name: "lang".to_string(), +                }, +                lang.to_string(), +            ); +        }); + +        Element { +            name: Name { +                namespace: Some(XMLNS.to_string()), +                local_name: "stream".to_string(), +            }, +            namespace_declarations, +            attributes, +            content: Vec::new(), +        } +    }  } -impl<'s> Stream<'s> { +impl<'s> Stream {      pub fn new( -        from: Option<&'s JID>, -        to: Option<&'s JID>, -        id: Option<&'s str>, -        version: Option<&'s str>, -        lang: Option<&'s str>, +        from: Option<JID>, +        to: Option<JID>, +        id: Option<String>, +        version: Option<String>, +        lang: Option<String>,      ) -> Self {          Self {              from, @@ -39,27 +143,18 @@ impl<'s> Stream<'s> {              id,              version,              lang, -            xmlns: XMLNS_CLIENT, -            xmlns_stream: XMLNS,          }      }      /// For initial stream headers, the initiating entity SHOULD include the 'xml:lang' attribute.      /// For privacy, it is better to not set `from` when sending a client stanza over an unencrypted connection. -    pub fn new_client( -        from: Option<&'s JID>, -        to: &'s JID, -        id: Option<&'s str>, -        lang: &'s str, -    ) -> Self { +    pub fn new_client(from: Option<JID>, to: JID, id: Option<String>, lang: String) -> Self {          Self {              from,              to: Some(to),              id, -            version: Some("1.0"), +            version: Some("1.0".to_string()),              lang: Some(lang), -            xmlns: XMLNS_CLIENT, -            xmlns_stream: XMLNS,          }      }  } | 
