summaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--src/client/encrypted.rs59
-rw-r--r--src/client/mod.rs40
-rw-r--r--src/client/unencrypted.rs135
-rw-r--r--src/error.rs7
-rw-r--r--src/jabber.rs131
-rw-r--r--src/lib.rs187
-rw-r--r--src/stanza/mod.rs1
-rw-r--r--src/stanza/stream.rs36
8 files changed, 437 insertions, 159 deletions
diff --git a/src/client/encrypted.rs b/src/client/encrypted.rs
new file mode 100644
index 0000000..08439b2
--- /dev/null
+++ b/src/client/encrypted.rs
@@ -0,0 +1,59 @@
+use quick_xml::{
+ events::{BytesDecl, BytesStart, Event},
+ Reader, Writer,
+};
+use tokio::io::{BufReader, ReadHalf, WriteHalf};
+use tokio::net::TcpStream;
+use tokio_native_tls::TlsStream;
+
+use crate::Jabber;
+use crate::Result;
+
+pub struct JabberClient<'j> {
+ reader: Reader<BufReader<ReadHalf<TlsStream<TcpStream>>>>,
+ writer: Writer<WriteHalf<TlsStream<TcpStream>>>,
+ jabber: &'j mut Jabber<'j>,
+}
+
+impl<'j> JabberClient<'j> {
+ pub fn new(
+ reader: Reader<BufReader<ReadHalf<TlsStream<TcpStream>>>>,
+ writer: Writer<WriteHalf<TlsStream<TcpStream>>>,
+ jabber: &'j mut Jabber<'j>,
+ ) -> Self {
+ Self {
+ reader,
+ writer,
+ jabber,
+ }
+ }
+
+ pub async fn start_stream(&mut self) -> Result<()> {
+ let declaration = BytesDecl::new("1.0", None, None);
+ let mut stream_element = BytesStart::new("stream:stream");
+ stream_element.push_attribute(("from".as_bytes(), self.jabber.jid.to_string().as_bytes()));
+ stream_element.push_attribute(("to".as_bytes(), self.jabber.server.as_bytes()));
+ stream_element.push_attribute(("version", "1.0"));
+ stream_element.push_attribute(("xml:lang", "en"));
+ stream_element.push_attribute(("xmlns", "jabber:client"));
+ stream_element.push_attribute(("xmlns:stream", "http://etherx.jabber.org/streams"));
+ self.writer
+ .write_event_async(Event::Decl(declaration))
+ .await;
+ self.writer
+ .write_event_async(Event::Start(stream_element))
+ .await
+ .unwrap();
+ let mut buf = Vec::new();
+ loop {
+ match self.reader.read_event_into_async(&mut buf).await.unwrap() {
+ Event::Start(e) => {
+ println!("{:?}", e);
+ break;
+ }
+ e => println!("decl: {:?}", e),
+ };
+ }
+ Ok(())
+ }
+}
diff --git a/src/client/mod.rs b/src/client/mod.rs
new file mode 100644
index 0000000..fe3dd34
--- /dev/null
+++ b/src/client/mod.rs
@@ -0,0 +1,40 @@
+pub mod encrypted;
+pub mod unencrypted;
+
+// use async_trait::async_trait;
+
+use crate::stanza::stream::StreamFeature;
+use crate::JabberError;
+use crate::Result;
+
+pub enum JabberClientType<'j> {
+ Encrypted(encrypted::JabberClient<'j>),
+ Unencrypted(unencrypted::JabberClient<'j>),
+}
+
+impl<'j> JabberClientType<'j> {
+ pub async fn ensure_tls(self) -> Result<encrypted::JabberClient<'j>> {
+ match self {
+ Self::Encrypted(mut c) => {
+ c.start_stream();
+ Ok(c)
+ }
+ Self::Unencrypted(mut c) => {
+ c.start_stream().await?;
+ let features = c.get_features().await?;
+ if features.contains(&StreamFeature::StartTls) {
+ Ok(c.starttls().await?)
+ } else {
+ Err(JabberError::StartTlsUnavailable)
+ }
+ }
+ }
+ }
+}
+
+// TODO: jabber client trait over both client types
+// #[async_trait]
+// pub trait JabberTrait {
+// async fn start_stream(&mut self) -> Result<()>;
+// async fn get_features(&self) -> Result<Vec<StreamFeatures>>;
+// }
diff --git a/src/client/unencrypted.rs b/src/client/unencrypted.rs
new file mode 100644
index 0000000..7528b14
--- /dev/null
+++ b/src/client/unencrypted.rs
@@ -0,0 +1,135 @@
+use std::str;
+
+use quick_xml::{
+ de::Deserializer,
+ events::{BytesDecl, BytesStart, Event},
+ name::QName,
+ Reader, Writer,
+};
+use serde::Deserialize;
+use tokio::io::{BufReader, ReadHalf, WriteHalf};
+use tokio::net::TcpStream;
+use tokio_native_tls::native_tls::TlsConnector;
+
+use crate::Result;
+use crate::{error::JabberError, stanza::stream::StreamFeature};
+use crate::{stanza::stream::StreamFeatures, Jabber};
+
+pub struct JabberClient<'j> {
+ reader: Reader<BufReader<ReadHalf<TcpStream>>>,
+ writer: Writer<WriteHalf<TcpStream>>,
+ jabber: &'j mut Jabber<'j>,
+}
+
+impl<'j> JabberClient<'j> {
+ pub fn new(
+ reader: Reader<BufReader<ReadHalf<TcpStream>>>,
+ writer: Writer<WriteHalf<TcpStream>>,
+ jabber: &'j mut Jabber<'j>,
+ ) -> Self {
+ Self {
+ reader,
+ writer,
+ jabber,
+ }
+ }
+
+ pub async fn start_stream(&mut self) -> Result<()> {
+ let declaration = BytesDecl::new("1.0", None, None);
+ let mut stream_element = BytesStart::new("stream:stream");
+ stream_element.push_attribute(("from".as_bytes(), self.jabber.jid.to_string().as_bytes()));
+ stream_element.push_attribute(("to".as_bytes(), self.jabber.server.as_bytes()));
+ stream_element.push_attribute(("version", "1.0"));
+ stream_element.push_attribute(("xml:lang", "en"));
+ stream_element.push_attribute(("xmlns", "jabber:client"));
+ stream_element.push_attribute(("xmlns:stream", "http://etherx.jabber.org/streams"));
+ self.writer
+ .write_event_async(Event::Decl(declaration))
+ .await;
+ self.writer
+ .write_event_async(Event::Start(stream_element))
+ .await
+ .unwrap();
+ let mut buf = Vec::new();
+ loop {
+ match self.reader.read_event_into_async(&mut buf).await.unwrap() {
+ Event::Start(e) => {
+ println!("{:?}", e);
+ break;
+ }
+ Event::Decl(e) => println!("decl: {:?}", e),
+ _ => return Err(JabberError::BadStream),
+ }
+ }
+ Ok(())
+ }
+
+ pub async fn get_features(&mut self) -> Result<Vec<StreamFeature>> {
+ let mut buf = Vec::new();
+ let mut txt = Vec::new();
+ let mut loop_end = false;
+ while !loop_end {
+ match self.reader.read_event_into_async(&mut buf).await.unwrap() {
+ Event::End(e) => {
+ if e.name() == QName(b"stream:features") {
+ loop_end = true;
+ }
+ }
+ _ => (),
+ }
+ txt.push(b'<');
+ txt = txt
+ .into_iter()
+ .chain(buf.to_owned())
+ .chain(vec![b'>'])
+ .collect();
+ buf.clear();
+ }
+ println!("{:?}", txt);
+ let decoded = str::from_utf8(&txt).unwrap();
+ println!("decoded: {:?}", decoded);
+ let mut deserializer = Deserializer::from_str(decoded);
+ // let mut deserializer = Deserializer::from_str(txt);
+ let features = StreamFeatures::deserialize(&mut deserializer).unwrap();
+ println!("{:?}", features);
+ Ok(features.features)
+ }
+
+ pub async fn starttls(mut self) -> Result<super::encrypted::JabberClient<'j>> {
+ let mut starttls_element = BytesStart::new("starttls");
+ starttls_element.push_attribute(("xmlns", "urn:ietf:params:xml:ns:xmpp-tls"));
+ self.writer
+ .write_event_async(Event::Empty(starttls_element))
+ .await
+ .unwrap();
+ let mut buf = Vec::new();
+ match self.reader.read_event_into_async(&mut buf).await.unwrap() {
+ Event::Empty(e) => match e.name() {
+ QName(b"proceed") => {
+ let connector = TlsConnector::new().unwrap();
+ let stream = self
+ .reader
+ .into_inner()
+ .into_inner()
+ .unsplit(self.writer.into_inner());
+ if let Ok(tlsstream) = tokio_native_tls::TlsConnector::from(connector)
+ .connect(&self.jabber.server, stream)
+ .await
+ {
+ let (read, write) = tokio::io::split(tlsstream);
+ let reader = Reader::from_reader(BufReader::new(read));
+ let writer = Writer::new(write);
+ return Ok(super::encrypted::JabberClient::new(
+ reader,
+ writer,
+ self.jabber,
+ ));
+ }
+ }
+ QName(_) => return Err(JabberError::TlsNegotiation),
+ },
+ _ => return Err(JabberError::TlsNegotiation),
+ }
+ Err(JabberError::TlsNegotiation)
+ }
+}
diff --git a/src/error.rs b/src/error.rs
new file mode 100644
index 0000000..a632537
--- /dev/null
+++ b/src/error.rs
@@ -0,0 +1,7 @@
+#[derive(Debug)]
+pub enum JabberError {
+ ConnectionError,
+ BadStream,
+ StartTlsUnavailable,
+ TlsNegotiation,
+}
diff --git a/src/jabber.rs b/src/jabber.rs
new file mode 100644
index 0000000..a1f6272
--- /dev/null
+++ b/src/jabber.rs
@@ -0,0 +1,131 @@
+use std::marker::PhantomData;
+use std::net::{IpAddr, SocketAddr};
+use std::str::FromStr;
+
+use quick_xml::{Reader, Writer};
+use tokio::io::BufReader;
+use tokio::net::TcpStream;
+use tokio_native_tls::native_tls::TlsConnector;
+
+use crate::client;
+use crate::client::JabberClientType;
+use crate::jid::JID;
+use crate::{JabberError, Result};
+
+pub struct Jabber<'j> {
+ pub jid: JID,
+ pub password: String,
+ pub server: String,
+ _marker: PhantomData<&'j ()>,
+}
+
+impl<'j> Jabber<'j> {
+ pub fn new(jid: JID, password: String) -> Self {
+ let server = jid.domainpart.clone();
+ Self {
+ jid,
+ password,
+ server,
+ _marker: PhantomData,
+ }
+ }
+
+ async fn get_sockets(&self) -> Vec<(SocketAddr, bool)> {
+ let mut socket_addrs = Vec::new();
+
+ // if it's a socket/ip then just return that
+
+ // socket
+ if let Ok(socket_addr) = SocketAddr::from_str(&self.jid.domainpart) {
+ match socket_addr.port() {
+ 5223 => socket_addrs.push((socket_addr, true)),
+ _ => socket_addrs.push((socket_addr, false)),
+ }
+
+ return socket_addrs;
+ }
+ // ip
+ if let Ok(ip) = IpAddr::from_str(&self.jid.domainpart) {
+ socket_addrs.push((SocketAddr::new(ip, 5222), false));
+ socket_addrs.push((SocketAddr::new(ip, 5223), true));
+ return socket_addrs;
+ }
+
+ // otherwise resolve
+ if let Ok(resolver) = trust_dns_resolver::AsyncResolver::tokio_from_system_conf() {
+ if let Ok(lookup) = resolver
+ .srv_lookup(format!("_xmpp-client._tcp.{}", self.jid.domainpart))
+ .await
+ {
+ for srv in lookup {
+ resolver
+ .lookup_ip(srv.target().to_owned())
+ .await
+ .map(|ips| {
+ for ip in ips {
+ socket_addrs.push((SocketAddr::new(ip, srv.port()), false))
+ }
+ });
+ }
+ }
+ if let Ok(lookup) = resolver
+ .srv_lookup(format!("_xmpps-client._tcp.{}", self.jid.domainpart))
+ .await
+ {
+ for srv in lookup {
+ resolver
+ .lookup_ip(srv.target().to_owned())
+ .await
+ .map(|ips| {
+ for ip in ips {
+ socket_addrs.push((SocketAddr::new(ip, srv.port()), true))
+ }
+ });
+ }
+ }
+
+ // in case cannot connect through SRV records
+ resolver.lookup_ip(&self.jid.domainpart).await.map(|ips| {
+ for ip in ips {
+ socket_addrs.push((SocketAddr::new(ip, 5222), false));
+ socket_addrs.push((SocketAddr::new(ip, 5223), true));
+ }
+ });
+ }
+ socket_addrs
+ }
+
+ pub async fn connect(&'j mut self) -> Result<JabberClientType> {
+ for (socket_addr, is_tls) in self.get_sockets().await {
+ println!("trying {}", socket_addr);
+ match is_tls {
+ true => {
+ let socket = TcpStream::connect(socket_addr).await.unwrap();
+ let connector = TlsConnector::new().unwrap();
+ if let Ok(stream) = tokio_native_tls::TlsConnector::from(connector)
+ .connect(&self.server, socket)
+ .await
+ {
+ let (read, write) = tokio::io::split(stream);
+ let reader = Reader::from_reader(BufReader::new(read));
+ let writer = Writer::new(write);
+ return Ok(JabberClientType::Encrypted(
+ client::encrypted::JabberClient::new(reader, writer, self),
+ ));
+ }
+ }
+ false => {
+ if let Ok(stream) = TcpStream::connect(socket_addr).await {
+ let (read, write) = tokio::io::split(stream);
+ let reader = Reader::from_reader(BufReader::new(read));
+ let writer = Writer::new(write);
+ return Ok(JabberClientType::Unencrypted(
+ client::unencrypted::JabberClient::new(reader, writer, self),
+ ));
+ }
+ }
+ }
+ }
+ Err(JabberError::ConnectionError)
+ }
+}
diff --git a/src/lib.rs b/src/lib.rs
index 10c7172..7f1433d 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,174 +1,43 @@
-// TODO: logging (dropped errors)
#![allow(unused_must_use)]
-use std::{
- net::{IpAddr, SocketAddr},
- str::FromStr,
-};
-
-use jid::JID;
-use quick_xml::{Reader, Writer};
-use tokio::net::{
- tcp::{OwnedReadHalf, OwnedWriteHalf},
- TcpStream,
-};
+// TODO: logging (dropped errors)
+pub mod client;
+pub mod error;
+pub mod jabber;
pub mod jid;
+pub mod stanza;
-pub struct JabberData {
- jid: jid::JID,
- password: String,
-}
-
-impl JabberData {
- pub fn new(jid: JID, password: String) -> Self {
- Self { jid, password }
- }
-
- async fn get_sockets(&self) -> Vec<SocketAddr> {
- let mut socket_addrs = Vec::new();
-
- // if it's a socket/ip then just return that
-
- // socket
- if let Ok(socket_addr) = SocketAddr::from_str(&self.jid.domainpart) {
- socket_addrs.push(socket_addr);
- return socket_addrs;
- }
- // ip
- if let Ok(ip) = IpAddr::from_str(&self.jid.domainpart) {
- socket_addrs.push(SocketAddr::new(ip, 5222));
- socket_addrs.push(SocketAddr::new(ip, 5223));
- return socket_addrs;
- }
-
- // if port specified return name resolutions with specified port
-
- // otherwise resolve
- if let Ok(resolver) = trust_dns_resolver::AsyncResolver::tokio_from_system_conf() {
- if let Ok(lookup) = resolver
- .srv_lookup(format!("_xmpp-client._tcp.{}", self.jid.domainpart))
- .await
- {
- for srv in lookup {
- resolver
- .lookup_ip(srv.target().to_owned())
- .await
- .map(|ips| {
- for ip in ips {
- socket_addrs.push(SocketAddr::new(ip, srv.port()))
- }
- });
- }
- }
- if let Ok(lookup) = resolver
- .srv_lookup(format!("_xmpps-client._tcp.{}", self.jid.domainpart))
- .await
- {
- for srv in lookup {
- resolver
- .lookup_ip(srv.target().to_owned())
- .await
- .map(|ips| {
- for ip in ips {
- socket_addrs.push(SocketAddr::new(ip, srv.port()))
- }
- });
- }
- }
-
- // in case cannot connect through SRV records
- resolver.lookup_ip(&self.jid.domainpart).await.map(|ips| {
- for ip in ips {
- socket_addrs.push(SocketAddr::new(ip, 5222));
- socket_addrs.push(SocketAddr::new(ip, 5223));
- }
- });
- }
-
- socket_addrs
- }
-}
-
-pub struct Jabber {
- reader: Reader<OwnedReadHalf>,
- writer: Writer<OwnedWriteHalf>,
- data: JabberData,
-}
-
-#[derive(Debug)]
-pub enum JabberError {
- NotConnected,
-}
+pub use client::encrypted::JabberClient;
+pub use error::JabberError;
+pub use jabber::Jabber;
+pub use jid::JID;
-impl Jabber {
- pub async fn connect(data: JabberData) -> Result<Self, JabberError> {
- for socket_addr in data.get_sockets().await {
- println!("trying {}", socket_addr);
- if let Ok(stream) = TcpStream::connect(socket_addr).await {
- println!("connected to {}", socket_addr);
- let (read, write) = stream.into_split();
- return Ok(Self {
- reader: Reader::from_reader(read),
- writer: Writer::new(write),
- data,
- });
- }
- }
- Err(JabberError::NotConnected)
- }
-
- async fn reconnect(&mut self) {
- for socket_addr in self.data.get_sockets().await {
- println!("trying {}", socket_addr);
- if let Ok(stream) = TcpStream::connect(socket_addr).await {
- println!("connected to {}", socket_addr);
- let (read, write) = stream.into_split();
- self.reader = Reader::from_reader(read);
- self.writer = Writer::new(write);
- return;
- }
- }
- println!("could not connect")
- }
-
- async fn begin_stream(&mut self) -> Result<(), JabberError> {
- todo!()
- }
-
- async fn starttls() -> Result<(), JabberError> {
- todo!()
- }
-
- async fn directtls() -> Result<(), JabberError> {
- todo!()
- }
-
- async fn auth(&mut self) -> Result<(), JabberError> {
- todo!()
- }
-
- async fn close(&mut self) {}
-}
+pub type Result<T> = std::result::Result<T, JabberError>;
#[cfg(test)]
mod tests {
- use crate::jid::JID;
+ use std::str::FromStr;
- use super::*;
+ use crate::Jabber;
+ use crate::JID;
- #[tokio::test]
- async fn get_sockets() {
- let data = JabberData::new(JID::from_str("cel@blos.sm").unwrap(), "password".to_owned());
- println!("{:?}", data.get_sockets().await)
- }
+ // #[tokio::test]
+ // async fn get_sockets() {
+ // let jabber = Jabber::new(JID::from_str("cel@blos.sm").unwrap(), "password".to_owned());
+ // println!("{:?}", jabber.get_sockets().await)
+ // }
#[tokio::test]
async fn connect() {
- Jabber::connect(JabberData::new(
- JID::from_str("cel@blos.sm").unwrap(),
- "password".to_owned(),
- ))
- .await
- .unwrap();
+ Jabber::new(JID::from_str("cel@blos.sm").unwrap(), "password".to_owned())
+ .connect()
+ .await
+ .unwrap()
+ .ensure_tls()
+ .await
+ .unwrap()
+ .start_stream()
+ .await
+ .unwrap();
}
}
diff --git a/src/stanza/mod.rs b/src/stanza/mod.rs
new file mode 100644
index 0000000..baf29e0
--- /dev/null
+++ b/src/stanza/mod.rs
@@ -0,0 +1 @@
+pub mod stream;
diff --git a/src/stanza/stream.rs b/src/stanza/stream.rs
new file mode 100644
index 0000000..dde741d
--- /dev/null
+++ b/src/stanza/stream.rs
@@ -0,0 +1,36 @@
+use serde::{Deserialize, Serialize};
+
+#[derive(Serialize, Deserialize)]
+#[serde(rename = "stream:stream")]
+struct Stream {
+ #[serde(rename = "@from")]
+ from: Option<String>,
+ #[serde(rename = "@id")]
+ id: Option<String>,
+ #[serde(rename = "@to")]
+ to: Option<String>,
+ #[serde(rename = "@version")]
+ version: Option<f32>,
+ #[serde(rename = "@xml:lang")]
+ lang: Option<String>,
+ #[serde(rename = "@xmlns")]
+ namespace: Option<String>,
+ #[serde(rename = "@xmlns:stream")]
+ stream_namespace: Option<String>,
+}
+
+#[derive(Deserialize, Debug)]
+#[serde(rename = "stream:features")]
+pub struct StreamFeatures {
+ #[serde(rename = "$value")]
+ pub features: Vec<StreamFeature>,
+}
+
+#[derive(Deserialize, PartialEq, Debug)]
+pub enum StreamFeature {
+ #[serde(rename = "starttls")]
+ StartTls,
+ // TODO: other stream features
+ Sasl,
+ Bind,
+}