aboutsummaryrefslogtreecommitdiffstats
path: root/jabber/src
diff options
context:
space:
mode:
Diffstat (limited to 'jabber/src')
-rw-r--r--jabber/src/client.rs217
-rw-r--r--jabber/src/connection.rs184
-rw-r--r--jabber/src/error.rs78
-rw-r--r--jabber/src/jabber_stream.rs393
-rw-r--r--jabber/src/lib.rs34
5 files changed, 906 insertions, 0 deletions
diff --git a/jabber/src/client.rs b/jabber/src/client.rs
new file mode 100644
index 0000000..c8b0b73
--- /dev/null
+++ b/jabber/src/client.rs
@@ -0,0 +1,217 @@
+use std::{pin::pin, sync::Arc, task::Poll};
+
+use futures::{Sink, Stream, StreamExt};
+use jid::ParseError;
+use rsasl::config::SASLConfig;
+use stanza::{
+ client::Stanza,
+ sasl::Mechanisms,
+ stream::{Feature, Features},
+};
+
+use crate::{
+ connection::{Tls, Unencrypted},
+ Connection, Error, JabberStream, Result, JID,
+};
+
+// feed it client stanzas, receive client stanzas
+pub struct JabberClient {
+ connection: ConnectionState,
+ jid: JID,
+ password: Arc<SASLConfig>,
+ server: String,
+}
+
+impl JabberClient {
+ pub fn new(
+ jid: impl TryInto<JID, Error = ParseError>,
+ password: impl ToString,
+ ) -> Result<JabberClient> {
+ let jid = jid.try_into()?;
+ let sasl_config = SASLConfig::with_credentials(
+ None,
+ jid.localpart.clone().ok_or(Error::NoLocalpart)?,
+ password.to_string(),
+ )?;
+ Ok(JabberClient {
+ connection: ConnectionState::Disconnected,
+ jid: jid.clone(),
+ password: sasl_config,
+ server: jid.domainpart,
+ })
+ }
+
+ pub async fn connect(&mut self) -> Result<()> {
+ match &self.connection {
+ ConnectionState::Disconnected => {
+ // TODO: actually set the self.connection as it is connecting, make more asynchronous (mutex while connecting?)
+ // perhaps use take_mut?
+ self.connection = ConnectionState::Disconnected
+ .connect(&mut self.jid, self.password.clone(), &mut self.server)
+ .await?;
+ Ok(())
+ }
+ ConnectionState::Connecting(_connecting) => Err(Error::AlreadyConnecting),
+ ConnectionState::Connected(_jabber_stream) => Ok(()),
+ }
+ }
+
+ pub async fn send_stanza(&mut self, stanza: &Stanza) -> Result<()> {
+ match &mut self.connection {
+ ConnectionState::Disconnected => return Err(Error::Disconnected),
+ ConnectionState::Connecting(_connecting) => return Err(Error::Connecting),
+ ConnectionState::Connected(jabber_stream) => {
+ Ok(jabber_stream.send_stanza(stanza).await?)
+ }
+ }
+ }
+}
+
+impl Stream for JabberClient {
+ type Item = Result<Stanza>;
+
+ fn poll_next(
+ self: std::pin::Pin<&mut Self>,
+ cx: &mut std::task::Context<'_>,
+ ) -> std::task::Poll<Option<Self::Item>> {
+ let mut client = pin!(self);
+ match &mut client.connection {
+ ConnectionState::Disconnected => Poll::Pending,
+ ConnectionState::Connecting(_connecting) => Poll::Pending,
+ ConnectionState::Connected(jabber_stream) => jabber_stream.poll_next_unpin(cx),
+ }
+ }
+}
+
+pub enum ConnectionState {
+ Disconnected,
+ Connecting(Connecting),
+ Connected(JabberStream<Tls>),
+}
+
+impl ConnectionState {
+ pub async fn connect(
+ mut self,
+ jid: &mut JID,
+ auth: Arc<SASLConfig>,
+ server: &mut String,
+ ) -> Result<Self> {
+ loop {
+ match self {
+ ConnectionState::Disconnected => {
+ self = ConnectionState::Connecting(Connecting::start(&server).await?);
+ }
+ ConnectionState::Connecting(connecting) => match connecting {
+ Connecting::InsecureConnectionEstablised(tcp_stream) => {
+ self = ConnectionState::Connecting(Connecting::InsecureStreamStarted(
+ JabberStream::start_stream(tcp_stream, server).await?,
+ ))
+ }
+ Connecting::InsecureStreamStarted(jabber_stream) => {
+ self = ConnectionState::Connecting(Connecting::InsecureGotFeatures(
+ jabber_stream.get_features().await?,
+ ))
+ }
+ Connecting::InsecureGotFeatures((features, jabber_stream)) => {
+ match features.negotiate().ok_or(Error::Negotiation)? {
+ Feature::StartTls(_start_tls) => {
+ self =
+ ConnectionState::Connecting(Connecting::StartTls(jabber_stream))
+ }
+ // TODO: better error
+ _ => return Err(Error::TlsRequired),
+ }
+ }
+ Connecting::StartTls(jabber_stream) => {
+ self = ConnectionState::Connecting(Connecting::ConnectionEstablished(
+ jabber_stream.starttls(&server).await?,
+ ))
+ }
+ Connecting::ConnectionEstablished(tls_stream) => {
+ self = ConnectionState::Connecting(Connecting::StreamStarted(
+ JabberStream::start_stream(tls_stream, server).await?,
+ ))
+ }
+ Connecting::StreamStarted(jabber_stream) => {
+ self = ConnectionState::Connecting(Connecting::GotFeatures(
+ jabber_stream.get_features().await?,
+ ))
+ }
+ Connecting::GotFeatures((features, jabber_stream)) => {
+ match features.negotiate().ok_or(Error::Negotiation)? {
+ Feature::StartTls(_start_tls) => return Err(Error::AlreadyTls),
+ Feature::Sasl(mechanisms) => {
+ self = ConnectionState::Connecting(Connecting::Sasl(
+ mechanisms,
+ jabber_stream,
+ ))
+ }
+ Feature::Bind => {
+ self = ConnectionState::Connecting(Connecting::Bind(jabber_stream))
+ }
+ Feature::Unknown => return Err(Error::Unsupported),
+ }
+ }
+ Connecting::Sasl(mechanisms, jabber_stream) => {
+ self = ConnectionState::Connecting(Connecting::ConnectionEstablished(
+ jabber_stream.sasl(mechanisms, auth.clone()).await?,
+ ))
+ }
+ Connecting::Bind(jabber_stream) => {
+ self = ConnectionState::Connected(jabber_stream.bind(jid).await?)
+ }
+ },
+ connected => return Ok(connected),
+ }
+ }
+ }
+}
+
+pub enum Connecting {
+ InsecureConnectionEstablised(Unencrypted),
+ InsecureStreamStarted(JabberStream<Unencrypted>),
+ InsecureGotFeatures((Features, JabberStream<Unencrypted>)),
+ StartTls(JabberStream<Unencrypted>),
+ ConnectionEstablished(Tls),
+ StreamStarted(JabberStream<Tls>),
+ GotFeatures((Features, JabberStream<Tls>)),
+ Sasl(Mechanisms, JabberStream<Tls>),
+ Bind(JabberStream<Tls>),
+}
+
+impl Connecting {
+ pub async fn start(server: &str) -> Result<Self> {
+ match Connection::connect(server).await? {
+ Connection::Encrypted(tls_stream) => Ok(Connecting::ConnectionEstablished(tls_stream)),
+ Connection::Unencrypted(tcp_stream) => {
+ Ok(Connecting::InsecureConnectionEstablised(tcp_stream))
+ }
+ }
+ }
+}
+
+pub enum InsecureConnecting {
+ Disconnected,
+ ConnectionEstablished(Connection),
+ PreStarttls(JabberStream<Unencrypted>),
+ PreAuthenticated(JabberStream<Tls>),
+ Authenticated(Tls),
+ PreBound(JabberStream<Tls>),
+ Bound(JabberStream<Tls>),
+}
+
+#[cfg(test)]
+mod tests {
+ use std::time::Duration;
+
+ use super::JabberClient;
+ use test_log::test;
+ use tokio::time::sleep;
+
+ #[test(tokio::test)]
+ async fn login() {
+ let mut client = JabberClient::new("test@blos.sm", "slayed").unwrap();
+ client.connect().await.unwrap();
+ sleep(Duration::from_secs(5)).await
+ }
+}
diff --git a/jabber/src/connection.rs b/jabber/src/connection.rs
new file mode 100644
index 0000000..bc5a282
--- /dev/null
+++ b/jabber/src/connection.rs
@@ -0,0 +1,184 @@
+use std::net::{IpAddr, SocketAddr};
+use std::str;
+use std::str::FromStr;
+use std::sync::Arc;
+
+use rsasl::config::SASLConfig;
+use tokio::net::TcpStream;
+use tokio_native_tls::native_tls::TlsConnector;
+// TODO: use rustls
+use tokio_native_tls::TlsStream;
+use tracing::{debug, info, instrument, trace};
+
+use crate::Result;
+use crate::{Error, JID};
+
+pub type Tls = TlsStream<TcpStream>;
+pub type Unencrypted = TcpStream;
+
+#[derive(Debug)]
+pub enum Connection {
+ Encrypted(Tls),
+ Unencrypted(Unencrypted),
+}
+
+impl Connection {
+ // #[instrument]
+ /// stream not started
+ // pub async fn ensure_tls(self) -> Result<Jabber<Tls>> {
+ // match self {
+ // Connection::Encrypted(j) => Ok(j),
+ // Connection::Unencrypted(mut j) => {
+ // j.start_stream().await?;
+ // info!("upgrading connection to tls");
+ // j.get_features().await?;
+ // let j = j.starttls().await?;
+ // Ok(j)
+ // }
+ // }
+ // }
+
+ pub async fn connect_user(jid: impl AsRef<str>) -> Result<Self> {
+ let jid: JID = JID::from_str(jid.as_ref())?;
+ let server = jid.domainpart.clone();
+ Self::connect(&server).await
+ }
+
+ #[instrument]
+ pub async fn connect(server: impl AsRef<str> + std::fmt::Debug) -> Result<Self> {
+ info!("connecting to {}", server.as_ref());
+ let sockets = Self::get_sockets(server.as_ref()).await;
+ debug!("discovered sockets: {:?}", sockets);
+ for (socket_addr, tls) in sockets {
+ match tls {
+ true => {
+ if let Ok(connection) = Self::connect_tls(socket_addr, server.as_ref()).await {
+ info!("connected via encrypted stream to {}", socket_addr);
+ // let (readhalf, writehalf) = tokio::io::split(connection);
+ return Ok(Self::Encrypted(connection));
+ }
+ }
+ false => {
+ if let Ok(connection) = Self::connect_unencrypted(socket_addr).await {
+ info!("connected via unencrypted stream to {}", socket_addr);
+ // let (readhalf, writehalf) = tokio::io::split(connection);
+ return Ok(Self::Unencrypted(connection));
+ }
+ }
+ }
+ }
+ Err(Error::Connection)
+ }
+
+ #[instrument]
+ async fn get_sockets(address: &str) -> Vec<(SocketAddr, bool)> {
+ let mut socket_addrs = Vec::new();
+
+ // if it's a socket/ip then just return that
+
+ // socket
+ trace!("checking if address is a socket address");
+ if let Ok(socket_addr) = SocketAddr::from_str(address) {
+ debug!("{} is a socket address", address);
+ match socket_addr.port() {
+ 5223 => socket_addrs.push((socket_addr, true)),
+ _ => socket_addrs.push((socket_addr, false)),
+ }
+
+ return socket_addrs;
+ }
+ // ip
+ trace!("checking if address is an ip");
+ if let Ok(ip) = IpAddr::from_str(address) {
+ debug!("{} is an ip", address);
+ socket_addrs.push((SocketAddr::new(ip, 5222), false));
+ socket_addrs.push((SocketAddr::new(ip, 5223), true));
+ return socket_addrs;
+ }
+
+ // otherwise resolve
+ debug!("resolving {}", address);
+ if let Ok(resolver) = trust_dns_resolver::AsyncResolver::tokio_from_system_conf() {
+ if let Ok(lookup) = resolver
+ .srv_lookup(format!("_xmpp-client._tcp.{}", address))
+ .await
+ {
+ for srv in lookup {
+ resolver
+ .lookup_ip(srv.target().to_owned())
+ .await
+ .map(|ips| {
+ for ip in ips {
+ socket_addrs.push((SocketAddr::new(ip, srv.port()), false))
+ }
+ });
+ }
+ }
+ if let Ok(lookup) = resolver
+ .srv_lookup(format!("_xmpps-client._tcp.{}", address))
+ .await
+ {
+ for srv in lookup {
+ resolver
+ .lookup_ip(srv.target().to_owned())
+ .await
+ .map(|ips| {
+ for ip in ips {
+ socket_addrs.push((SocketAddr::new(ip, srv.port()), true))
+ }
+ });
+ }
+ }
+
+ // in case cannot connect through SRV records
+ resolver.lookup_ip(address).await.map(|ips| {
+ for ip in ips {
+ socket_addrs.push((SocketAddr::new(ip, 5222), false));
+ socket_addrs.push((SocketAddr::new(ip, 5223), true));
+ }
+ });
+ }
+ socket_addrs
+ }
+
+ /// establishes a connection to the server
+ #[instrument]
+ pub async fn connect_tls(socket_addr: SocketAddr, domain_name: &str) -> Result<Tls> {
+ let socket = TcpStream::connect(socket_addr)
+ .await
+ .map_err(|_| Error::Connection)?;
+ let connector = TlsConnector::new().map_err(|_| Error::Connection)?;
+ tokio_native_tls::TlsConnector::from(connector)
+ .connect(domain_name, socket)
+ .await
+ .map_err(|_| Error::Connection)
+ }
+
+ #[instrument]
+ pub async fn connect_unencrypted(socket_addr: SocketAddr) -> Result<Unencrypted> {
+ TcpStream::connect(socket_addr)
+ .await
+ .map_err(|_| Error::Connection)
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use test_log::test;
+
+ #[test(tokio::test)]
+ async fn connect() {
+ Connection::connect("blos.sm").await.unwrap();
+ }
+
+ // #[test(tokio::test)]
+ // async fn test_tls() {
+ // Connection::connect("blos.sm", None, None)
+ // .await
+ // .unwrap()
+ // .ensure_tls()
+ // .await
+ // .unwrap();
+ // }
+}
diff --git a/jabber/src/error.rs b/jabber/src/error.rs
new file mode 100644
index 0000000..aad033c
--- /dev/null
+++ b/jabber/src/error.rs
@@ -0,0 +1,78 @@
+use std::str::Utf8Error;
+
+use jid::ParseError;
+use rsasl::mechname::MechanismNameError;
+use stanza::client::error::Error as ClientError;
+use stanza::sasl::Failure;
+use stanza::stream::Error as StreamError;
+
+#[derive(Debug)]
+pub enum Error {
+ Connection,
+ Utf8Decode,
+ Negotiation,
+ TlsRequired,
+ AlreadyTls,
+ Unsupported,
+ NoLocalpart,
+ AlreadyConnecting,
+ UnexpectedElement(peanuts::Element),
+ XML(peanuts::Error),
+ Deserialization(peanuts::DeserializeError),
+ SASL(SASLError),
+ JID(ParseError),
+ Authentication(Failure),
+ ClientError(ClientError),
+ StreamError(StreamError),
+ MissingError,
+ Disconnected,
+ Connecting,
+}
+
+#[derive(Debug)]
+pub enum SASLError {
+ SASL(rsasl::prelude::SASLError),
+ MechanismName(MechanismNameError),
+}
+
+impl From<rsasl::prelude::SASLError> for Error {
+ fn from(e: rsasl::prelude::SASLError) -> Self {
+ Self::SASL(SASLError::SASL(e))
+ }
+}
+
+impl From<peanuts::DeserializeError> for Error {
+ fn from(e: peanuts::DeserializeError) -> Self {
+ Error::Deserialization(e)
+ }
+}
+
+impl From<MechanismNameError> for Error {
+ fn from(e: MechanismNameError) -> Self {
+ Self::SASL(SASLError::MechanismName(e))
+ }
+}
+
+impl From<SASLError> for Error {
+ fn from(e: SASLError) -> Self {
+ Self::SASL(e)
+ }
+}
+
+impl From<Utf8Error> for Error {
+ fn from(_e: Utf8Error) -> Self {
+ Self::Utf8Decode
+ }
+}
+
+impl From<peanuts::Error> for Error {
+ fn from(e: peanuts::Error) -> Self {
+ Self::XML(e)
+ }
+}
+
+impl From<ParseError> for Error {
+ fn from(e: ParseError) -> Self {
+ Self::JID(e)
+ }
+}
diff --git a/jabber/src/jabber_stream.rs b/jabber/src/jabber_stream.rs
new file mode 100644
index 0000000..dd0dcbf
--- /dev/null
+++ b/jabber/src/jabber_stream.rs
@@ -0,0 +1,393 @@
+use std::pin::pin;
+use std::str::{self, FromStr};
+use std::sync::Arc;
+
+use futures::StreamExt;
+use jid::JID;
+use peanuts::element::{FromContent, IntoElement};
+use peanuts::{Reader, Writer};
+use rsasl::prelude::{Mechname, SASLClient, SASLConfig};
+use stanza::bind::{Bind, BindType, FullJidType, ResourceType};
+use stanza::client::iq::{Iq, IqType, Query};
+use stanza::client::Stanza;
+use stanza::sasl::{Auth, Challenge, Mechanisms, Response, ServerResponse};
+use stanza::starttls::{Proceed, StartTls};
+use stanza::stream::{Features, Stream};
+use stanza::XML_VERSION;
+use tokio::io::{AsyncRead, AsyncWrite, ReadHalf, WriteHalf};
+use tokio_native_tls::native_tls::TlsConnector;
+use tracing::{debug, instrument};
+
+use crate::connection::{Tls, Unencrypted};
+use crate::error::Error;
+use crate::Result;
+
+// open stream (streams started)
+pub struct JabberStream<S> {
+ reader: Reader<ReadHalf<S>>,
+ writer: Writer<WriteHalf<S>>,
+}
+
+impl<S: AsyncRead> futures::Stream for JabberStream<S> {
+ type Item = Result<Stanza>;
+
+ fn poll_next(
+ self: std::pin::Pin<&mut Self>,
+ cx: &mut std::task::Context<'_>,
+ ) -> std::task::Poll<Option<Self::Item>> {
+ pin!(self).reader.poll_next_unpin(cx).map(|content| {
+ content.map(|content| -> Result<Stanza> {
+ let stanza = content.map(|content| Stanza::from_content(content))?;
+ Ok(stanza?)
+ })
+ })
+ }
+}
+
+impl<S> JabberStream<S>
+where
+ S: AsyncRead + AsyncWrite + Unpin + Send + std::fmt::Debug,
+ JabberStream<S>: std::fmt::Debug,
+{
+ #[instrument]
+ pub async fn sasl(mut self, mechanisms: Mechanisms, sasl_config: Arc<SASLConfig>) -> Result<S> {
+ let sasl = SASLClient::new(sasl_config);
+ let mut offered_mechs: Vec<&Mechname> = Vec::new();
+ for mechanism in &mechanisms.mechanisms {
+ offered_mechs.push(Mechname::parse(mechanism.as_bytes())?)
+ }
+ debug!("{:?}", offered_mechs);
+ let mut session = sasl.start_suggested(&offered_mechs)?;
+ let selected_mechanism = session.get_mechname().as_str().to_owned();
+ debug!("selected mech: {:?}", selected_mechanism);
+ let mut data: Option<Vec<u8>> = None;
+
+ if !session.are_we_first() {
+ // if not first mention the mechanism then get challenge data
+ // mention mechanism
+ let auth = Auth {
+ mechanism: selected_mechanism,
+ sasl_data: "=".to_string(),
+ };
+ self.writer.write_full(&auth).await?;
+ // get challenge data
+ let challenge: Challenge = self.reader.read().await?;
+ debug!("challenge: {:?}", challenge);
+ data = Some((*challenge).as_bytes().to_vec());
+ debug!("we didn't go first");
+ } else {
+ // if first, mention mechanism and send data
+ let mut sasl_data = Vec::new();
+ session.step64(None, &mut sasl_data).unwrap();
+ let auth = Auth {
+ mechanism: selected_mechanism,
+ sasl_data: str::from_utf8(&sasl_data)?.to_string(),
+ };
+ debug!("{:?}", auth);
+ self.writer.write_full(&auth).await?;
+
+ let server_response: ServerResponse = self.reader.read().await?;
+ debug!("server_response: {:#?}", server_response);
+ match server_response {
+ ServerResponse::Challenge(challenge) => {
+ data = Some((*challenge).as_bytes().to_vec())
+ }
+ ServerResponse::Success(success) => {
+ data = success.clone().map(|success| success.as_bytes().to_vec())
+ }
+ ServerResponse::Failure(failure) => return Err(Error::Authentication(failure)),
+ }
+ debug!("we went first");
+ }
+
+ // stepping the authentication exchange to completion
+ if data != None {
+ debug!("data: {:?}", data);
+ let mut sasl_data = Vec::new();
+ while {
+ // decide if need to send more data over
+ let state = session
+ .step64(data.as_deref(), &mut sasl_data)
+ .expect("step errored!");
+ state.is_running()
+ } {
+ // While we aren't finished, receive more data from the other party
+ let response = Response::new(str::from_utf8(&sasl_data)?.to_string());
+ debug!("response: {:?}", response);
+ let stdout = tokio::io::stdout();
+ let mut writer = Writer::new(stdout);
+ writer.write_full(&response).await?;
+ self.writer.write_full(&response).await?;
+ debug!("response written");
+
+ let server_response: ServerResponse = self.reader.read().await?;
+ debug!("server_response: {:#?}", server_response);
+ match server_response {
+ ServerResponse::Challenge(challenge) => {
+ data = Some((*challenge).as_bytes().to_vec())
+ }
+ ServerResponse::Success(success) => {
+ data = success.clone().map(|success| success.as_bytes().to_vec())
+ }
+ ServerResponse::Failure(failure) => return Err(Error::Authentication(failure)),
+ }
+ }
+ }
+ let writer = self.writer.into_inner();
+ let reader = self.reader.into_inner();
+ let stream = reader.unsplit(writer);
+ Ok(stream)
+ }
+
+ #[instrument]
+ pub async fn bind(mut self, jid: &mut JID) -> Result<Self> {
+ let iq_id = nanoid::nanoid!();
+ if let Some(resource) = &jid.resourcepart {
+ let iq = Iq {
+ from: None,
+ id: iq_id.clone(),
+ to: None,
+ r#type: IqType::Set,
+ lang: None,
+ query: Some(Query::Bind(Bind {
+ r#type: Some(BindType::Resource(ResourceType(resource.to_string()))),
+ })),
+ errors: Vec::new(),
+ };
+ self.writer.write_full(&iq).await?;
+ let result: Iq = self.reader.read().await?;
+ match result {
+ Iq {
+ from: _,
+ id,
+ to: _,
+ r#type: IqType::Result,
+ lang: _,
+ query:
+ Some(Query::Bind(Bind {
+ r#type: Some(BindType::Jid(FullJidType(new_jid))),
+ })),
+ errors: _,
+ } if id == iq_id => {
+ *jid = new_jid;
+ return Ok(self);
+ }
+ Iq {
+ from: _,
+ id,
+ to: _,
+ r#type: IqType::Error,
+ lang: _,
+ query: None,
+ errors,
+ } if id == iq_id => {
+ return Err(Error::ClientError(
+ errors.first().ok_or(Error::MissingError)?.clone(),
+ ))
+ }
+ _ => return Err(Error::UnexpectedElement(result.into_element())),
+ }
+ } else {
+ let iq = Iq {
+ from: None,
+ id: iq_id.clone(),
+ to: None,
+ r#type: IqType::Set,
+ lang: None,
+ query: Some(Query::Bind(Bind { r#type: None })),
+ errors: Vec::new(),
+ };
+ self.writer.write_full(&iq).await?;
+ let result: Iq = self.reader.read().await?;
+ match result {
+ Iq {
+ from: _,
+ id,
+ to: _,
+ r#type: IqType::Result,
+ lang: _,
+ query:
+ Some(Query::Bind(Bind {
+ r#type: Some(BindType::Jid(FullJidType(new_jid))),
+ })),
+ errors: _,
+ } if id == iq_id => {
+ *jid = new_jid;
+ return Ok(self);
+ }
+ Iq {
+ from: _,
+ id,
+ to: _,
+ r#type: IqType::Error,
+ lang: _,
+ query: None,
+ errors,
+ } if id == iq_id => {
+ return Err(Error::ClientError(
+ errors.first().ok_or(Error::MissingError)?.clone(),
+ ))
+ }
+ _ => return Err(Error::UnexpectedElement(result.into_element())),
+ }
+ }
+ }
+
+ #[instrument]
+ pub async fn start_stream(connection: S, server: &mut String) -> Result<Self> {
+ // client to server
+ let (reader, writer) = tokio::io::split(connection);
+ let mut reader = Reader::new(reader);
+ let mut writer = Writer::new(writer);
+
+ // declaration
+ writer.write_declaration(XML_VERSION).await?;
+
+ // opening stream element
+ let stream = Stream::new_client(
+ None,
+ JID::from_str(server.as_ref())?,
+ None,
+ "en".to_string(),
+ );
+ writer.write_start(&stream).await?;
+
+ // server to client
+
+ // may or may not send a declaration
+ let _decl = reader.read_prolog().await?;
+
+ // receive stream element and validate
+ let stream: Stream = reader.read_start().await?;
+ debug!("got stream: {:?}", stream);
+ if let Some(from) = stream.from {
+ *server = from.to_string();
+ }
+
+ Ok(Self { reader, writer })
+ }
+
+ #[instrument]
+ pub async fn get_features(mut self) -> Result<(Features, Self)> {
+ debug!("getting features");
+ let features: Features = self.reader.read().await?;
+ debug!("got features: {:?}", features);
+ Ok((features, self))
+ }
+
+ pub fn into_inner(self) -> S {
+ self.reader.into_inner().unsplit(self.writer.into_inner())
+ }
+
+ pub async fn send_stanza(&mut self, stanza: &Stanza) -> Result<()> {
+ self.writer.write(stanza).await?;
+ Ok(())
+ }
+}
+
+impl JabberStream<Unencrypted> {
+ #[instrument]
+ pub async fn starttls(mut self, domain: impl AsRef<str> + std::fmt::Debug) -> Result<Tls> {
+ self.writer
+ .write_full(&StartTls { required: false })
+ .await?;
+ let proceed: Proceed = self.reader.read().await?;
+ debug!("got proceed: {:?}", proceed);
+ let connector = TlsConnector::new().unwrap();
+ let stream = self.reader.into_inner().unsplit(self.writer.into_inner());
+ if let Ok(tls_stream) = tokio_native_tls::TlsConnector::from(connector)
+ .connect(domain.as_ref(), stream)
+ .await
+ {
+ return Ok(tls_stream);
+ } else {
+ return Err(Error::Connection);
+ }
+ }
+}
+
+impl std::fmt::Debug for JabberStream<Tls> {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ f.debug_struct("Jabber")
+ .field("connection", &"tls")
+ .finish()
+ }
+}
+
+impl std::fmt::Debug for JabberStream<Unencrypted> {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ f.debug_struct("Jabber")
+ .field("connection", &"unencrypted")
+ .finish()
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use std::time::Duration;
+
+ use super::*;
+ use crate::connection::Connection;
+ use test_log::test;
+ use tokio::time::sleep;
+
+ #[test(tokio::test)]
+ async fn start_stream() {
+ // let connection = Connection::connect("blos.sm", None, None).await.unwrap();
+ // match connection {
+ // Connection::Encrypted(mut c) => c.start_stream().await.unwrap(),
+ // Connection::Unencrypted(mut c) => c.start_stream().await.unwrap(),
+ // }
+ }
+
+ #[test(tokio::test)]
+ async fn sasl() {
+ // let mut jabber = Connection::connect_user("test@blos.sm", "slayed".to_string())
+ // .await
+ // .unwrap()
+ // .ensure_tls()
+ // .await
+ // .unwrap();
+ // let text = str::from_utf8(jabber.reader.buffer.data()).unwrap();
+ // println!("data: {}", text);
+ // jabber.start_stream().await.unwrap();
+
+ // let text = str::from_utf8(jabber.reader.buffer.data()).unwrap();
+ // println!("data: {}", text);
+ // jabber.reader.read_buf().await.unwrap();
+ // let text = str::from_utf8(jabber.reader.buffer.data()).unwrap();
+ // println!("data: {}", text);
+
+ // let features = jabber.get_features().await.unwrap();
+ // let (sasl_config, feature) = (
+ // jabber.auth.clone().unwrap(),
+ // features
+ // .features
+ // .iter()
+ // .find(|feature| matches!(feature, Feature::Sasl(_)))
+ // .unwrap(),
+ // );
+ // match feature {
+ // Feature::StartTls(_start_tls) => todo!(),
+ // Feature::Sasl(mechanisms) => {
+ // jabber.sasl(mechanisms.clone(), sasl_config).await.unwrap();
+ // }
+ // Feature::Bind => todo!(),
+ // Feature::Unknown => todo!(),
+ // }
+ }
+
+ #[tokio::test]
+ async fn negotiate() {
+ // let _jabber = Connection::connect_user("test@blos.sm", "slayed".to_string())
+ // .await
+ // .unwrap()
+ // .ensure_tls()
+ // .await
+ // .unwrap()
+ // .negotiate()
+ // .await
+ // .unwrap();
+ // sleep(Duration::from_secs(5)).await
+ }
+}
diff --git a/jabber/src/lib.rs b/jabber/src/lib.rs
new file mode 100644
index 0000000..bcd63db
--- /dev/null
+++ b/jabber/src/lib.rs
@@ -0,0 +1,34 @@
+#![allow(unused_must_use)]
+// #![feature(let_chains)]
+
+// TODO: logging (dropped errors)
+pub mod client;
+pub mod connection;
+pub mod error;
+pub mod jabber_stream;
+
+pub use connection::Connection;
+use connection::Tls;
+pub use error::Error;
+pub use jabber_stream::JabberStream;
+pub use jid::JID;
+
+pub type Result<T> = std::result::Result<T, Error>;
+
+pub async fn login<J: AsRef<str>, P: AsRef<str>>(jid: J, password: P) -> Result<JabberStream<Tls>> {
+ todo!()
+ // Ok(Connection::connect_user(jid, password.as_ref().to_string())
+ // .await?
+ // .ensure_tls()
+ // .await?
+ // .negotiate()
+ // .await?)
+}
+
+#[cfg(test)]
+mod tests {
+ // #[tokio::test]
+ // async fn test_login() {
+ // crate::login("test@blos.sm/clown", "slayed").await.unwrap();
+ // }
+}