aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorLibravatar cel 🌸 <cel@bunny.garden>2024-11-24 02:04:45 +0000
committerLibravatar cel 🌸 <cel@bunny.garden>2024-11-24 02:04:45 +0000
commit35f164cdb6324c6dfb635f8de93a8221861a5991 (patch)
treef858e55999007046e511acce17b9e35bc1585ba4
parent40024d2dadba9e70edb2f3448204565ce3f68ab7 (diff)
downloadluz-35f164cdb6324c6dfb635f8de93a8221861a5991.tar.gz
luz-35f164cdb6324c6dfb635f8de93a8221861a5991.tar.bz2
luz-35f164cdb6324c6dfb635f8de93a8221861a5991.zip
implement starttls
-rw-r--r--src/connection.rs15
-rw-r--r--src/jabber.rs85
-rw-r--r--src/lib.rs3
-rw-r--r--src/stanza/starttls.rs162
-rw-r--r--src/stanza/stream.rs73
5 files changed, 290 insertions, 48 deletions
diff --git a/src/connection.rs b/src/connection.rs
index 89f382f..2b70747 100644
--- a/src/connection.rs
+++ b/src/connection.rs
@@ -27,8 +27,11 @@ impl Connection {
match self {
Connection::Encrypted(j) => Ok(j),
Connection::Unencrypted(mut j) => {
+ j.start_stream().await?;
info!("upgrading connection to tls");
- Ok(j.starttls().await?)
+ j.get_features().await?;
+ let j = j.starttls().await?;
+ Ok(j)
}
}
}
@@ -179,4 +182,14 @@ mod tests {
async fn connect() {
Connection::connect("blos.sm").await.unwrap();
}
+
+ #[test(tokio::test)]
+ async fn test_tls() {
+ Connection::connect("blos.sm")
+ .await
+ .unwrap()
+ .ensure_tls()
+ .await
+ .unwrap();
+ }
}
diff --git a/src/jabber.rs b/src/jabber.rs
index afe840b..87a2b44 100644
--- a/src/jabber.rs
+++ b/src/jabber.rs
@@ -1,14 +1,18 @@
use std::str;
use std::sync::Arc;
+use peanuts::element::{FromElement, IntoElement};
use peanuts::{Reader, Writer};
use rsasl::prelude::SASLConfig;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufReader, ReadHalf, WriteHalf};
+use tokio_native_tls::native_tls::TlsConnector;
use tracing::{debug, info, trace};
+use trust_dns_resolver::proto::rr::domain::IntoLabel;
use crate::connection::{Tls, Unencrypted};
use crate::error::Error;
-use crate::stanza::stream::Stream;
+use crate::stanza::starttls::{Proceed, StartTls};
+use crate::stanza::stream::{Features, Stream};
use crate::stanza::XML_VERSION;
use crate::Result;
use crate::JID;
@@ -62,7 +66,6 @@ where
// opening stream element
let server = self.server.clone().try_into()?;
let stream = Stream::new_client(None, server, None, "en".to_string());
- // TODO: nicer function to serialize to xml writer
self.writer.write_start(&stream).await?;
// server to client
@@ -72,57 +75,53 @@ where
// receive stream element and validate
let stream: Stream = self.reader.read_start().await?;
+ debug!("got stream: {:?}", stream);
if let Some(from) = stream.from {
self.server = from.to_string()
}
Ok(())
}
-}
-// pub async fn get_features(&mut self) -> Result<Vec<StreamFeature>> {
-// Element::read(&mut self.reader).await?.try_into()
-// }
+ pub async fn get_features(&mut self) -> Result<Features> {
+ debug!("getting features");
+ let features: Features = self.reader.read().await?;
+ debug!("got features: {:?}", features);
+ Ok(features)
+ }
+
+ pub fn into_inner(self) -> S {
+ self.reader.into_inner().unsplit(self.writer.into_inner())
+ }
+}
impl Jabber<Unencrypted> {
- pub async fn starttls(&mut self) -> Result<Jabber<Tls>> {
- todo!()
+ pub async fn starttls(mut self) -> Result<Jabber<Tls>> {
+ self.writer
+ .write_full(&StartTls { required: false })
+ .await?;
+ let proceed: Proceed = self.reader.read().await?;
+ debug!("got proceed: {:?}", proceed);
+ let connector = TlsConnector::new().unwrap();
+ let stream = self.reader.into_inner().unsplit(self.writer.into_inner());
+ if let Ok(tlsstream) = tokio_native_tls::TlsConnector::from(connector)
+ .connect(&self.server, stream)
+ .await
+ {
+ let (read, write) = tokio::io::split(tlsstream);
+ let mut client = Jabber::new(
+ read,
+ write,
+ self.jid.to_owned(),
+ self.auth.to_owned(),
+ self.server.to_owned(),
+ );
+ client.start_stream().await?;
+ return Ok(client);
+ } else {
+ return Err(Error::Connection);
+ }
}
- // let mut starttls_element = BytesStart::new("starttls");
- // starttls_element.push_attribute(("xmlns", "urn:ietf:params:xml:ns:xmpp-tls"));
- // self.writer
- // .write_event_async(Event::Empty(starttls_element))
- // .await
- // .unwrap();
- // let mut buf = Vec::new();
- // match self.reader.read_event_into_async(&mut buf).await.unwrap() {
- // Event::Empty(e) => match e.name() {
- // QName(b"proceed") => {
- // let connector = TlsConnector::new().unwrap();
- // let stream = self
- // .reader
- // .into_inner()
- // .into_inner()
- // .unsplit(self.writer.into_inner());
- // if let Ok(tlsstream) = tokio_native_tls::TlsConnector::from(connector)
- // .connect(&self.jabber.server, stream)
- // .await
- // {
- // let (read, write) = tokio::io::split(tlsstream);
- // let reader = Reader::from_reader(BufReader::new(read));
- // let writer = Writer::new(write);
- // let mut client =
- // super::encrypted::JabberClient::new(reader, writer, self.jabber);
- // client.start_stream().await?;
- // return Ok(client);
- // }
- // }
- // QName(_) => return Err(JabberError::TlsNegotiation),
- // },
- // _ => return Err(JabberError::TlsNegotiation),
- // }
- // Err(JabberError::TlsNegotiation)
- // }
}
impl std::fmt::Debug for Jabber<Tls> {
diff --git a/src/lib.rs b/src/lib.rs
index 306b0fd..88b91a6 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -8,9 +8,6 @@ pub mod jabber;
pub mod jid;
pub mod stanza;
-#[macro_use]
-extern crate lazy_static;
-
pub use connection::Connection;
pub use error::Error;
pub use jabber::Jabber;
diff --git a/src/stanza/starttls.rs b/src/stanza/starttls.rs
index 8b13789..874ae66 100644
--- a/src/stanza/starttls.rs
+++ b/src/stanza/starttls.rs
@@ -1 +1,163 @@
+use std::collections::{HashMap, HashSet};
+use peanuts::{
+ element::{Content, FromElement, IntoElement, Name, NamespaceDeclaration},
+ Element,
+};
+
+pub const XMLNS: &str = "urn:ietf:params:xml:ns:xmpp-tls";
+
+#[derive(Debug)]
+pub struct StartTls {
+ pub required: bool,
+}
+
+impl IntoElement for StartTls {
+ fn into_element(&self) -> peanuts::Element {
+ let content;
+ if self.required == true {
+ let element = Content::Element(Element {
+ name: Name {
+ namespace: Some(XMLNS.to_string()),
+ local_name: "required".to_string(),
+ },
+ namespace_declarations: HashSet::new(),
+ attributes: HashMap::new(),
+ content: Vec::new(),
+ });
+ content = vec![element];
+ } else {
+ content = Vec::new();
+ }
+ let mut namespace_declarations = HashSet::new();
+ namespace_declarations.insert(NamespaceDeclaration {
+ prefix: None,
+ namespace: XMLNS.to_string(),
+ });
+ Element {
+ name: Name {
+ namespace: Some(XMLNS.to_string()),
+ local_name: "starttls".to_string(),
+ },
+ namespace_declarations,
+ attributes: HashMap::new(),
+ content,
+ }
+ }
+}
+
+impl FromElement for StartTls {
+ fn from_element(element: peanuts::Element) -> peanuts::Result<Self> {
+ let Name {
+ namespace,
+ local_name,
+ } = element.name;
+ if namespace.as_deref() == Some(XMLNS) && &local_name == "starttls" {
+ let mut required = false;
+ if element.content.len() == 1 {
+ match element.content.first().unwrap() {
+ Content::Element(element) => {
+ let Name {
+ namespace,
+ local_name,
+ } = &element.name;
+
+ if namespace.as_deref() == Some(XMLNS) && local_name == "required" {
+ required = true
+ } else {
+ return Err(peanuts::Error::UnexpectedElement(element.name.clone()));
+ }
+ }
+ c => return Err(peanuts::Error::UnexpectedContent((*c).clone())),
+ }
+ } else {
+ return Err(peanuts::Error::UnexpectedNumberOfContents(
+ element.content.len(),
+ ));
+ }
+ return Ok(StartTls { required });
+ } else {
+ return Err(peanuts::Error::IncorrectName(Name {
+ namespace,
+ local_name,
+ }));
+ }
+ }
+}
+
+#[derive(Debug)]
+pub struct Proceed;
+
+impl IntoElement for Proceed {
+ fn into_element(&self) -> Element {
+ let mut namespace_declarations = HashSet::new();
+ namespace_declarations.insert(NamespaceDeclaration {
+ prefix: None,
+ namespace: XMLNS.to_string(),
+ });
+ Element {
+ name: Name {
+ namespace: Some(XMLNS.to_string()),
+ local_name: "proceed".to_string(),
+ },
+ namespace_declarations,
+ attributes: HashMap::new(),
+ content: Vec::new(),
+ }
+ }
+}
+
+impl FromElement for Proceed {
+ fn from_element(element: Element) -> peanuts::Result<Self> {
+ let Name {
+ namespace,
+ local_name,
+ } = element.name;
+ if namespace.as_deref() == Some(XMLNS) && &local_name == "proceed" {
+ return Ok(Proceed);
+ } else {
+ return Err(peanuts::Error::IncorrectName(Name {
+ namespace,
+ local_name,
+ }));
+ }
+ }
+}
+
+pub struct Failure;
+
+impl IntoElement for Failure {
+ fn into_element(&self) -> Element {
+ let mut namespace_declarations = HashSet::new();
+ namespace_declarations.insert(NamespaceDeclaration {
+ prefix: None,
+ namespace: XMLNS.to_string(),
+ });
+ Element {
+ name: Name {
+ namespace: Some(XMLNS.to_string()),
+ local_name: "failure".to_string(),
+ },
+ namespace_declarations,
+ attributes: HashMap::new(),
+ content: Vec::new(),
+ }
+ }
+}
+
+impl FromElement for Failure {
+ fn from_element(element: Element) -> peanuts::Result<Self> {
+ let Name {
+ namespace,
+ local_name,
+ } = element.name;
+ if namespace.as_deref() == Some(XMLNS) && &local_name == "failure" {
+ return Ok(Failure);
+ } else {
+ return Err(peanuts::Error::IncorrectName(Name {
+ namespace,
+ local_name,
+ }));
+ }
+ }
+}
diff --git a/src/stanza/stream.rs b/src/stanza/stream.rs
index ac4badc..4516682 100644
--- a/src/stanza/stream.rs
+++ b/src/stanza/stream.rs
@@ -6,12 +6,15 @@ use peanuts::{element::Name, Element};
use crate::{Error, JID};
+use super::starttls::StartTls;
+
pub const XMLNS: &str = "http://etherx.jabber.org/streams";
pub const XMLNS_CLIENT: &str = "jabber:client";
// MUST be qualified by stream namespace
// #[derive(XmlSerialize, XmlDeserialize)]
// #[peanuts(xmlns = XMLNS)]
+#[derive(Debug)]
pub struct Stream {
pub from: Option<JID>,
to: Option<JID>,
@@ -93,7 +96,7 @@ impl IntoElement for Stream {
attributes.insert(
Name {
namespace: None,
- local_name: "version".to_string(),
+ local_name: "id".to_string(),
},
id.clone(),
);
@@ -158,3 +161,71 @@ impl<'s> Stream {
}
}
}
+
+#[derive(Debug)]
+pub struct Features {
+ features: Vec<Feature>,
+}
+
+impl IntoElement for Features {
+ fn into_element(&self) -> Element {
+ let mut content = Vec::new();
+ for feature in &self.features {
+ match feature {
+ Feature::StartTls(start_tls) => {
+ content.push(Content::Element(start_tls.into_element()))
+ }
+ Feature::Sasl => {}
+ Feature::Bind => {}
+ Feature::Unknown => {}
+ }
+ }
+ Element {
+ name: Name {
+ namespace: Some(XMLNS.to_string()),
+ local_name: "features".to_string(),
+ },
+ namespace_declarations: HashSet::new(),
+ attributes: HashMap::new(),
+ content,
+ }
+ }
+}
+
+impl FromElement for Features {
+ fn from_element(element: Element) -> peanuts::Result<Self> {
+ let Name {
+ namespace,
+ local_name,
+ } = element.name;
+ if namespace.as_deref() == Some(XMLNS) && &local_name == "features" {
+ let mut features = Vec::new();
+ for feature in element.content {
+ match feature {
+ Content::Element(element) => {
+ if let Ok(start_tls) = FromElement::from_element(element) {
+ features.push(Feature::StartTls(start_tls))
+ } else {
+ features.push(Feature::Unknown)
+ }
+ }
+ c => return Err(peanuts::Error::UnexpectedContent(c.clone())),
+ }
+ }
+ return Ok(Self { features });
+ } else {
+ return Err(peanuts::Error::IncorrectName(Name {
+ namespace,
+ local_name,
+ }));
+ }
+ }
+}
+
+#[derive(Debug)]
+pub enum Feature {
+ StartTls(StartTls),
+ Sasl,
+ Bind,
+ Unknown,
+}