diff options
Diffstat (limited to '')
-rw-r--r-- | Cargo.toml | 1 | ||||
-rw-r--r-- | src/connection.rs | 9 | ||||
-rw-r--r-- | src/jabber.rs | 132 | ||||
-rw-r--r-- | src/lib.rs | 2 |
4 files changed, 74 insertions, 70 deletions
@@ -15,6 +15,7 @@ quick-xml = { git = "https://github.com/tafia/quick-xml.git", features = ["async # TODO: remove unneeded features rsasl = { version = "2", default_features = true, features = ["provider_base64", "plain", "config_builder"] } serde = "1.0.180" +serde_with = "3.4.0" tokio = { version = "1.28", features = ["full"] } tokio-native-tls = "0.3.1" tracing = "0.1.40" diff --git a/src/connection.rs b/src/connection.rs index ccc2ae7..b42711e 100644 --- a/src/connection.rs +++ b/src/connection.rs @@ -15,16 +15,21 @@ use crate::Result; pub type Tls = TlsStream<TcpStream>; pub type Unencrypted = TcpStream; +#[derive(Debug)] pub enum Connection { Encrypted(Jabber<Tls>), Unencrypted(Jabber<Unencrypted>), } impl Connection { + #[instrument] pub async fn ensure_tls(self) -> Result<Jabber<Tls>> { match self { Connection::Encrypted(j) => Ok(j), - Connection::Unencrypted(j) => Ok(j.starttls().await?), + Connection::Unencrypted(mut j) => { + info!("upgrading connection to tls"); + Ok(j.starttls().await?) + } } } @@ -36,7 +41,7 @@ impl Connection { // } #[instrument] - async fn connect(server: &str) -> Result<Self> { + pub async fn connect(server: &str) -> Result<Self> { info!("connecting to {}", server); let sockets = Self::get_sockets(&server).await; debug!("discovered sockets: {:?}", sockets); diff --git a/src/jabber.rs b/src/jabber.rs index 3583d19..1436bfa 100644 --- a/src/jabber.rs +++ b/src/jabber.rs @@ -1,9 +1,11 @@ +use std::str; use std::sync::Arc; use quick_xml::{events::Event, se::Serializer, NsReader, Writer}; use rsasl::prelude::SASLConfig; use serde::Serialize; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufReader, ReadHalf, WriteHalf}; +use tracing::{debug, info, trace}; use crate::connection::{Tls, Unencrypted}; use crate::error::JabberError; @@ -17,7 +19,7 @@ where S: AsyncRead + AsyncWrite + Unpin, { reader: NsReader<BufReader<ReadHalf<S>>>, - writer: Writer<WriteHalf<S>>, + writer: WriteHalf<S>, jid: Option<JID>, auth: Option<Arc<SASLConfig>>, server: String, @@ -35,7 +37,6 @@ where server: String, ) -> Self { let reader = NsReader::from_reader(BufReader::new(reader)); - let writer = Writer::new(writer); Self { reader, writer, @@ -49,112 +50,71 @@ where impl<S> Jabber<S> where S: AsyncRead + AsyncWrite + Unpin, - Writer<tokio::io::WriteHalf<S>>: AsyncWriteExt, - Writer<tokio::io::WriteHalf<S>>: AsyncWrite, { + // pub async fn negotiate(self) -> Result<Jabber<S>> {} + pub async fn start_stream(&mut self) -> Result<()> { // client to server // declaration - self.writer.write_event_async(DECLARATION.clone()).await?; + let mut xmlwriter = Writer::new(&mut self.writer); + xmlwriter.write_event_async(DECLARATION.clone()).await?; // opening stream element let server = &self.server.to_owned().try_into()?; let stream_element = Stream::new_client(None, server, None, "en"); // TODO: nicer function to serialize to xml writer let mut buffer = String::new(); - let ser = Serializer::new(&mut buffer); + let ser = Serializer::with_root(&mut buffer, Some("stream:stream")).expect("stream name"); stream_element.serialize(ser).unwrap(); - self.writer.write_all(buffer.as_bytes()); + trace!("sent: {}", buffer); + self.writer.write_all(buffer.as_bytes()).await.unwrap(); // server to client // may or may not send a declaration let mut buf = Vec::new(); let mut first_event = self.reader.read_resolved_event_into_async(&mut buf).await?; + trace!("received: {:?}", first_event); match first_event { (quick_xml::name::ResolveResult::Unbound, Event::Decl(e)) => { if let Ok(version) = e.version() { if version.as_ref() == b"1.0" { - first_event = self.reader.read_resolved_event_into_async(&mut buf).await? + first_event = self.reader.read_resolved_event_into_async(&mut buf).await?; + trace!("received: {:?}", first_event); } else { // todo: error todo!() } } else { - first_event = self.reader.read_resolved_event_into_async(&mut buf).await? + first_event = self.reader.read_resolved_event_into_async(&mut buf).await?; + trace!("received: {:?}", first_event); } } _ => (), } // receive stream element and validate - let stream_response: Stream; match first_event { (quick_xml::name::ResolveResult::Bound(ns), Event::Start(e)) => { if ns.0 == crate::stanza::stream::XMLNS.as_bytes() { - // stream_response = Stream::new( - // e.try_get_attribute("from")?.try_map(|attribute| { - // str::from_utf8(attribute.value.as_ref())? - // .try_into()? - // .as_ref() - // })?, - // e.try_get_attribute("to")?.try_map(|attribute| { - // str::from_utf8(attribute.value.as_ref())? - // .try_into()? - // .as_ref() - // })?, - // e.try_get_attribute("id")?.try_map(|attribute| { - // str::from_utf8(attribute.value.as_ref())? - // .try_into()? - // .as_ref() - // })?, - // e.try_get_attribute("version")?.try_map(|attribute| { - // str::from_utf8(attribute.value.as_ref())? - // .try_into()? - // .as_ref() - // })?, - // e.try_get_attribute("lang")?.try_map(|attribute| { - // str::from_utf8(attribute.value.as_ref())? - // .try_into()? - // .as_ref() - // })?, - // ); + e.attributes().try_for_each(|attr| -> Result<()> { + let attr = attr?; + match attr.key.into_inner() { + b"from" => { + self.server = str::from_utf8(&attr.value)?.to_owned(); + Ok(()) + } + _ => Ok(()), + } + }); return Ok(()); } else { return Err(JabberError::BadStream); } } // TODO: errors for incorrect namespace - (quick_xml::name::ResolveResult::Unbound, Event::Decl(_)) => todo!(), - (quick_xml::name::ResolveResult::Unknown(_), Event::Start(_)) => todo!(), - (quick_xml::name::ResolveResult::Unknown(_), Event::End(_)) => todo!(), - (quick_xml::name::ResolveResult::Unknown(_), Event::Empty(_)) => todo!(), - (quick_xml::name::ResolveResult::Unknown(_), Event::Text(_)) => todo!(), - (quick_xml::name::ResolveResult::Unknown(_), Event::CData(_)) => todo!(), - (quick_xml::name::ResolveResult::Unknown(_), Event::Comment(_)) => todo!(), - (quick_xml::name::ResolveResult::Unknown(_), Event::Decl(_)) => todo!(), - (quick_xml::name::ResolveResult::Unknown(_), Event::PI(_)) => todo!(), - (quick_xml::name::ResolveResult::Unknown(_), Event::DocType(_)) => todo!(), - (quick_xml::name::ResolveResult::Unknown(_), Event::Eof) => todo!(), - (quick_xml::name::ResolveResult::Unbound, Event::Start(_)) => todo!(), - (quick_xml::name::ResolveResult::Unbound, Event::End(_)) => todo!(), - (quick_xml::name::ResolveResult::Unbound, Event::Empty(_)) => todo!(), - (quick_xml::name::ResolveResult::Unbound, Event::Text(_)) => todo!(), - (quick_xml::name::ResolveResult::Unbound, Event::CData(_)) => todo!(), - (quick_xml::name::ResolveResult::Unbound, Event::Comment(_)) => todo!(), - (quick_xml::name::ResolveResult::Unbound, Event::PI(_)) => todo!(), - (quick_xml::name::ResolveResult::Unbound, Event::DocType(_)) => todo!(), - (quick_xml::name::ResolveResult::Unbound, Event::Eof) => todo!(), - (quick_xml::name::ResolveResult::Bound(_), Event::End(_)) => todo!(), - (quick_xml::name::ResolveResult::Bound(_), Event::Empty(_)) => todo!(), - (quick_xml::name::ResolveResult::Bound(_), Event::Text(_)) => todo!(), - (quick_xml::name::ResolveResult::Bound(_), Event::CData(_)) => todo!(), - (quick_xml::name::ResolveResult::Bound(_), Event::Comment(_)) => todo!(), - (quick_xml::name::ResolveResult::Bound(_), Event::Decl(_)) => todo!(), - (quick_xml::name::ResolveResult::Bound(_), Event::PI(_)) => todo!(), - (quick_xml::name::ResolveResult::Bound(_), Event::DocType(_)) => todo!(), - (quick_xml::name::ResolveResult::Bound(_), Event::Eof) => todo!(), + _ => Err(JabberError::BadStream), } } } @@ -164,7 +124,7 @@ where // } impl Jabber<Unencrypted> { - pub async fn starttls(mut self) -> Result<Jabber<Tls>> { + pub async fn starttls(&mut self) -> Result<Jabber<Tls>> { todo!() } // let mut starttls_element = BytesStart::new("starttls"); @@ -203,3 +163,41 @@ impl Jabber<Unencrypted> { // Err(JabberError::TlsNegotiation) // } } + +impl std::fmt::Debug for Jabber<Tls> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Jabber") + .field("connection", &"tls") + .field("jid", &self.jid) + .field("auth", &self.auth) + .field("server", &self.server) + .finish() + } +} + +impl std::fmt::Debug for Jabber<Unencrypted> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Jabber") + .field("connection", &"unencrypted") + .field("jid", &self.jid) + .field("auth", &self.auth) + .field("server", &self.server) + .finish() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::connection::Connection; + use test_log::test; + + #[test(tokio::test)] + async fn start_stream() { + let connection = Connection::connect("blos.sm").await.unwrap(); + match connection { + Connection::Encrypted(mut c) => c.start_stream().await.unwrap(), + Connection::Unencrypted(mut c) => c.start_stream().await.unwrap(), + } + } +} @@ -1,5 +1,5 @@ #![allow(unused_must_use)] -#![feature(let_chains)] +// #![feature(let_chains)] // TODO: logging (dropped errors) pub mod connection; |