From 6a5e39c60ad74c1cba84daa7c845c8f0237a5d28 Mon Sep 17 00:00:00 2001 From: cel 🌸 Date: Mon, 19 Jun 2023 19:23:54 +0100 Subject: implement starttls --- src/client/encrypted.rs | 59 ++++++++++++++++++++ src/client/mod.rs | 40 ++++++++++++++ src/client/unencrypted.rs | 135 ++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 234 insertions(+) create mode 100644 src/client/encrypted.rs create mode 100644 src/client/mod.rs create mode 100644 src/client/unencrypted.rs (limited to 'src/client') diff --git a/src/client/encrypted.rs b/src/client/encrypted.rs new file mode 100644 index 0000000..08439b2 --- /dev/null +++ b/src/client/encrypted.rs @@ -0,0 +1,59 @@ +use quick_xml::{ + events::{BytesDecl, BytesStart, Event}, + Reader, Writer, +}; +use tokio::io::{BufReader, ReadHalf, WriteHalf}; +use tokio::net::TcpStream; +use tokio_native_tls::TlsStream; + +use crate::Jabber; +use crate::Result; + +pub struct JabberClient<'j> { + reader: Reader>>>, + writer: Writer>>, + jabber: &'j mut Jabber<'j>, +} + +impl<'j> JabberClient<'j> { + pub fn new( + reader: Reader>>>, + writer: Writer>>, + jabber: &'j mut Jabber<'j>, + ) -> Self { + Self { + reader, + writer, + jabber, + } + } + + pub async fn start_stream(&mut self) -> Result<()> { + let declaration = BytesDecl::new("1.0", None, None); + let mut stream_element = BytesStart::new("stream:stream"); + stream_element.push_attribute(("from".as_bytes(), self.jabber.jid.to_string().as_bytes())); + stream_element.push_attribute(("to".as_bytes(), self.jabber.server.as_bytes())); + stream_element.push_attribute(("version", "1.0")); + stream_element.push_attribute(("xml:lang", "en")); + stream_element.push_attribute(("xmlns", "jabber:client")); + stream_element.push_attribute(("xmlns:stream", "http://etherx.jabber.org/streams")); + self.writer + .write_event_async(Event::Decl(declaration)) + .await; + self.writer + .write_event_async(Event::Start(stream_element)) + .await + .unwrap(); + let mut buf = Vec::new(); + loop { + match self.reader.read_event_into_async(&mut buf).await.unwrap() { + Event::Start(e) => { + println!("{:?}", e); + break; + } + e => println!("decl: {:?}", e), + }; + } + Ok(()) + } +} diff --git a/src/client/mod.rs b/src/client/mod.rs new file mode 100644 index 0000000..fe3dd34 --- /dev/null +++ b/src/client/mod.rs @@ -0,0 +1,40 @@ +pub mod encrypted; +pub mod unencrypted; + +// use async_trait::async_trait; + +use crate::stanza::stream::StreamFeature; +use crate::JabberError; +use crate::Result; + +pub enum JabberClientType<'j> { + Encrypted(encrypted::JabberClient<'j>), + Unencrypted(unencrypted::JabberClient<'j>), +} + +impl<'j> JabberClientType<'j> { + pub async fn ensure_tls(self) -> Result> { + match self { + Self::Encrypted(mut c) => { + c.start_stream(); + Ok(c) + } + Self::Unencrypted(mut c) => { + c.start_stream().await?; + let features = c.get_features().await?; + if features.contains(&StreamFeature::StartTls) { + Ok(c.starttls().await?) + } else { + Err(JabberError::StartTlsUnavailable) + } + } + } + } +} + +// TODO: jabber client trait over both client types +// #[async_trait] +// pub trait JabberTrait { +// async fn start_stream(&mut self) -> Result<()>; +// async fn get_features(&self) -> Result>; +// } diff --git a/src/client/unencrypted.rs b/src/client/unencrypted.rs new file mode 100644 index 0000000..7528b14 --- /dev/null +++ b/src/client/unencrypted.rs @@ -0,0 +1,135 @@ +use std::str; + +use quick_xml::{ + de::Deserializer, + events::{BytesDecl, BytesStart, Event}, + name::QName, + Reader, Writer, +}; +use serde::Deserialize; +use tokio::io::{BufReader, ReadHalf, WriteHalf}; +use tokio::net::TcpStream; +use tokio_native_tls::native_tls::TlsConnector; + +use crate::Result; +use crate::{error::JabberError, stanza::stream::StreamFeature}; +use crate::{stanza::stream::StreamFeatures, Jabber}; + +pub struct JabberClient<'j> { + reader: Reader>>, + writer: Writer>, + jabber: &'j mut Jabber<'j>, +} + +impl<'j> JabberClient<'j> { + pub fn new( + reader: Reader>>, + writer: Writer>, + jabber: &'j mut Jabber<'j>, + ) -> Self { + Self { + reader, + writer, + jabber, + } + } + + pub async fn start_stream(&mut self) -> Result<()> { + let declaration = BytesDecl::new("1.0", None, None); + let mut stream_element = BytesStart::new("stream:stream"); + stream_element.push_attribute(("from".as_bytes(), self.jabber.jid.to_string().as_bytes())); + stream_element.push_attribute(("to".as_bytes(), self.jabber.server.as_bytes())); + stream_element.push_attribute(("version", "1.0")); + stream_element.push_attribute(("xml:lang", "en")); + stream_element.push_attribute(("xmlns", "jabber:client")); + stream_element.push_attribute(("xmlns:stream", "http://etherx.jabber.org/streams")); + self.writer + .write_event_async(Event::Decl(declaration)) + .await; + self.writer + .write_event_async(Event::Start(stream_element)) + .await + .unwrap(); + let mut buf = Vec::new(); + loop { + match self.reader.read_event_into_async(&mut buf).await.unwrap() { + Event::Start(e) => { + println!("{:?}", e); + break; + } + Event::Decl(e) => println!("decl: {:?}", e), + _ => return Err(JabberError::BadStream), + } + } + Ok(()) + } + + pub async fn get_features(&mut self) -> Result> { + let mut buf = Vec::new(); + let mut txt = Vec::new(); + let mut loop_end = false; + while !loop_end { + match self.reader.read_event_into_async(&mut buf).await.unwrap() { + Event::End(e) => { + if e.name() == QName(b"stream:features") { + loop_end = true; + } + } + _ => (), + } + txt.push(b'<'); + txt = txt + .into_iter() + .chain(buf.to_owned()) + .chain(vec![b'>']) + .collect(); + buf.clear(); + } + println!("{:?}", txt); + let decoded = str::from_utf8(&txt).unwrap(); + println!("decoded: {:?}", decoded); + let mut deserializer = Deserializer::from_str(decoded); + // let mut deserializer = Deserializer::from_str(txt); + let features = StreamFeatures::deserialize(&mut deserializer).unwrap(); + println!("{:?}", features); + Ok(features.features) + } + + pub async fn starttls(mut self) -> Result> { + let mut starttls_element = BytesStart::new("starttls"); + starttls_element.push_attribute(("xmlns", "urn:ietf:params:xml:ns:xmpp-tls")); + self.writer + .write_event_async(Event::Empty(starttls_element)) + .await + .unwrap(); + let mut buf = Vec::new(); + match self.reader.read_event_into_async(&mut buf).await.unwrap() { + Event::Empty(e) => match e.name() { + QName(b"proceed") => { + let connector = TlsConnector::new().unwrap(); + let stream = self + .reader + .into_inner() + .into_inner() + .unsplit(self.writer.into_inner()); + if let Ok(tlsstream) = tokio_native_tls::TlsConnector::from(connector) + .connect(&self.jabber.server, stream) + .await + { + let (read, write) = tokio::io::split(tlsstream); + let reader = Reader::from_reader(BufReader::new(read)); + let writer = Writer::new(write); + return Ok(super::encrypted::JabberClient::new( + reader, + writer, + self.jabber, + )); + } + } + QName(_) => return Err(JabberError::TlsNegotiation), + }, + _ => return Err(JabberError::TlsNegotiation), + } + Err(JabberError::TlsNegotiation) + } +} -- cgit