use std::str;
use std::sync::Arc;
use async_recursion::async_recursion;
use peanuts::element::{FromElement, IntoElement};
use peanuts::{Reader, Writer};
use rsasl::prelude::{Mechname, SASLClient, SASLConfig};
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufReader, BufWriter, ReadHalf, WriteHalf};
use tokio::time::timeout;
use tokio_native_tls::native_tls::TlsConnector;
use tracing::{debug, info, instrument, trace};
use trust_dns_resolver::proto::rr::domain::IntoLabel;
use crate::connection::{Tls, Unencrypted};
use crate::error::Error;
use crate::stanza::sasl::{Auth, Challenge, Mechanisms, Response, ServerResponse};
use crate::stanza::starttls::{Proceed, StartTls};
use crate::stanza::stream::{Feature, Features, Stream};
use crate::stanza::XML_VERSION;
use crate::JID;
use crate::{Connection, Result};
/// One XMPP stream running over a transport `S`.
///
/// The transport is held as separate read and write halves, each wrapped in
/// the XML `Reader`/`Writer` used to exchange stream elements.
pub struct Jabber<S> {
/// XML reader over the read half of the transport.
reader: Reader<ReadHalf<S>>,
/// XML writer over the write half of the transport.
writer: Writer<WriteHalf<S>>,
/// JID associated with this stream, if any; carried over unchanged when
/// the stream is upgraded by `starttls`.
jid: Option<JID>,
/// SASL configuration; when `None`, the negotiate methods skip the SASL
/// feature even if the server offers it.
auth: Option<Arc<SASLConfig>>,
/// Server name: used to build the opening stream header and as the domain
/// for the TLS handshake, and overwritten by `start_stream` with the `from`
/// value of the server's stream header when one is present.
server: String,
}
impl<S> Jabber<S>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    /// Builds a `Jabber` from the two halves of an already split transport.
    ///
    /// `reader` and `writer` are wrapped in the XML `Reader` and `Writer`;
    /// `jid`, `auth` and `server` are stored as given.
    pub fn new(
        reader: ReadHalf<S>,
        writer: WriteHalf<S>,
        jid: Option<JID>,
        auth: Option<Arc<SASLConfig>>,
        server: String,
    ) -> Self {
        Self {
            reader: Reader::new(reader),
            writer: Writer::new(writer),
            jid,
            auth,
            server,
        }
    }
}
impl<S> Jabber<S>
where
S: AsyncRead + AsyncWrite + Unpin + Send,
Jabber<S>: std::fmt::Debug,
{
pub async fn sasl(
&mut self,
mechanisms: Mechanisms,
sasl_config: Arc<SASLConfig>,
) -> Result<()> {
let sasl = SASLClient::new(sasl_config);
let mut offered_mechs: Vec<&Mechname> = Vec::new();
for mechanism in &mechanisms.mechanisms {
offered_mechs.push(Mechname::parse(mechanism.as_bytes())?)
}
debug!("{:?}", offered_mechs);
let mut session = sasl.start_suggested(&offered_mechs)?;
let selected_mechanism = session.get_mechname().as_str().to_owned();
debug!("selected mech: {:?}", selected_mechanism);
let mut data: Option<Vec<u8>> = None;
if !session.are_we_first() {
// if not first mention the mechanism then get challenge data
// mention mechanism
let auth = Auth {
mechanism: selected_mechanism,
sasl_data: "=".to_string(),
};
self.writer.write_full(&auth).await?;
// get challenge data
let challenge: Challenge = self.reader.read().await?;
debug!("challenge: {:?}", challenge);
data = Some((*challenge).as_bytes().to_vec());
debug!("we didn't go first");
} else {
// if first, mention mechanism and send data
let mut sasl_data = Vec::new();
session.step64(None, &mut sasl_data).unwrap();
let auth = Auth {
mechanism: selected_mechanism,
sasl_data: str::from_utf8(&sasl_data)?.to_string(),
};
debug!("{:?}", auth);
self.writer.write_full(&auth).await?;
let server_response: ServerResponse = self.reader.read().await?;
debug!("server_response: {:#?}", server_response);
match server_response {
ServerResponse::Challenge(challenge) => {
data = Some((*challenge).as_bytes().to_vec())
}
ServerResponse::Success(success) => {
data = success.clone().map(|success| success.as_bytes().to_vec())
}
ServerResponse::Failure(failure) => return Err(Error::Authentication(failure)),
}
debug!("we went first");
}
// stepping the authentication exchange to completion
if data != None {
debug!("data: {:?}", data);
let mut sasl_data = Vec::new();
while {
// decide if need to send more data over
let state = session
.step64(data.as_deref(), &mut sasl_data)
.expect("step errored!");
state.is_running()
} {
// While we aren't finished, receive more data from the other party
let response = Response::new(str::from_utf8(&sasl_data)?.to_string());
debug!("response: {:?}", response);
let stdout = tokio::io::stdout();
let mut writer = Writer::new(stdout);
writer.write_full(&response).await?;
self.writer.write_full(&response).await?;
debug!("response written");
let server_response: ServerResponse = self.reader.read().await?;
debug!("server_response: {:#?}", server_response);
match server_response {
ServerResponse::Challenge(challenge) => {
data = Some((*challenge).as_bytes().to_vec())
}
ServerResponse::Success(success) => {
data = success.clone().map(|success| success.as_bytes().to_vec())
}
ServerResponse::Failure(failure) => return Err(Error::Authentication(failure)),
}
}
}
Ok(())
}
pub async fn bind(&mut self) -> Result<()> {
todo!()
}
#[instrument]
pub async fn start_stream(&mut self) -> Result<()> {
// client to server
// declaration
self.writer.write_declaration(XML_VERSION).await?;
// opening stream element
let server = self.server.clone().try_into()?;
let stream = Stream::new_client(None, server, None, "en".to_string());
self.writer.write_start(&stream).await?;
// server to client
// may or may not send a declaration
let decl = self.reader.read_prolog().await?;
// receive stream element and validate
let text = str::from_utf8(self.reader.buffer.data()).unwrap();
debug!("data: {}", text);
let stream: Stream = self.reader.read_start().await?;
debug!("got stream: {:?}", stream);
if let Some(from) = stream.from {
self.server = from.to_string()
}
Ok(())
}
pub async fn get_features(&mut self) -> Result<Features> {
debug!("getting features");
let features: Features = self.reader.read().await?;
debug!("got features: {:?}", features);
Ok(features)
}
pub fn into_inner(self) -> S {
self.reader.into_inner().unsplit(self.writer.into_inner())
}
}
impl Jabber<Unencrypted> {
pub async fn negotiate<S: AsyncRead + AsyncWrite + Unpin>(mut self) -> Result<Jabber<Tls>> {
self.start_stream().await?;
// TODO: timeout
let features = self.get_features().await?.features;
if let Some(Feature::StartTls(_)) = features
.iter()
.find(|feature| matches!(feature, Feature::StartTls(_s)))
{
let jabber = self.starttls().await?;
let jabber = jabber.negotiate().await?;
return Ok(jabber);
} else {
// TODO: better error
return Err(Error::TlsRequired);
}
}
#[async_recursion]
pub async fn negotiate_tls_optional(mut self) -> Result<Connection> {
self.start_stream().await?;
// TODO: timeout
let features = self.get_features().await?.features;
if let Some(Feature::StartTls(_)) = features
.iter()
.find(|feature| matches!(feature, Feature::StartTls(_s)))
{
let jabber = self.starttls().await?;
let jabber = jabber.negotiate().await?;
return Ok(Connection::Encrypted(jabber));
} else if let (Some(sasl_config), Some(Feature::Sasl(mechanisms))) = (
self.auth.clone(),
features
.iter()
.find(|feature| matches!(feature, Feature::Sasl(_))),
) {
self.sasl(mechanisms.clone(), sasl_config).await?;
let jabber = self.negotiate_tls_optional().await?;
Ok(jabber)
} else if let Some(Feature::Bind) = features
.iter()
.find(|feature| matches!(feature, Feature::Bind))
{
self.bind().await?;
Ok(Connection::Unencrypted(self))
} else {
// TODO: better error
return Err(Error::Negotiation);
}
}
}
impl Jabber<Tls> {
#[async_recursion]
pub async fn negotiate(mut self) -> Result<Jabber<Tls>> {
self.start_stream().await?;
let features = self.get_features().await?.features;
if let (Some(sasl_config), Some(Feature::Sasl(mechanisms))) = (
self.auth.clone(),
features
.iter()
.find(|feature| matches!(feature, Feature::Sasl(_))),
) {
// TODO: avoid clone
self.sasl(mechanisms.clone(), sasl_config).await?;
let jabber = self.negotiate().await?;
Ok(jabber)
} else if let Some(Feature::Bind) = features
.iter()
.find(|feature| matches!(feature, Feature::Bind))
{
self.bind().await?;
Ok(self)
} else {
// TODO: better error
return Err(Error::Negotiation);
}
}
}
impl Jabber<Unencrypted> {
pub async fn starttls(mut self) -> Result<Jabber<Tls>> {
self.writer
.write_full(&StartTls { required: false })
.await?;
let proceed: Proceed = self.reader.read().await?;
debug!("got proceed: {:?}", proceed);
let connector = TlsConnector::new().unwrap();
let stream = self.reader.into_inner().unsplit(self.writer.into_inner());
if let Ok(tlsstream) = tokio_native_tls::TlsConnector::from(connector)
.connect(&self.server, stream)
.await
{
let (read, write) = tokio::io::split(tlsstream);
let client = Jabber::new(
read,
write,
self.jid.to_owned(),
self.auth.to_owned(),
self.server.to_owned(),
);
return Ok(client);
} else {
return Err(Error::Connection);
}
}
}
impl std::fmt::Debug for Jabber<Tls> {
    /// Formats the stream as a `Jabber` struct tagged with the `"tls"`
    /// connection kind; the reader and writer are left out.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            jid, auth, server, ..
        } = self;
        let mut repr = f.debug_struct("Jabber");
        repr.field("connection", &"tls");
        repr.field("jid", jid);
        repr.field("auth", auth);
        repr.field("server", server);
        repr.finish()
    }
}
impl std::fmt::Debug for Jabber<Unencrypted> {
    /// Formats the stream as a `Jabber` struct tagged with the
    /// `"unencrypted"` connection kind; the reader and writer are left out.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            jid, auth, server, ..
        } = self;
        let mut repr = f.debug_struct("Jabber");
        repr.field("connection", &"unencrypted");
        repr.field("jid", jid);
        repr.field("auth", auth);
        repr.field("server", server);
        repr.finish()
    }
}
// Tests for stream setup and SASL.
// NOTE(review): both tests go through `Connection::connect*` against the
// hard-coded host "blos.sm", so they presumably need network access and a
// reachable server - confirm before relying on them in CI.
#[cfg(test)]
mod tests {
use super::*;
use crate::connection::Connection;
use test_log::test;
// Connects without credentials and checks that a stream can be opened on
// whichever connection kind comes back.
#[test(tokio::test)]
async fn start_stream() {
let connection = Connection::connect("blos.sm", None, None).await.unwrap();
match connection {
Connection::Encrypted(mut c) => c.start_stream().await.unwrap(),
Connection::Unencrypted(mut c) => c.start_stream().await.unwrap(),
}
}
// Connects with credentials, ensures TLS, opens a stream and runs the SASL
// exchange with the mechanisms the server advertises. The reader buffer is
// printed between steps as a debugging aid.
#[test(tokio::test)]
async fn sasl() {
let mut jabber = Connection::connect_user("test@blos.sm", "slayed".to_string())
.await
.unwrap()
.ensure_tls()
.await
.unwrap();
// buffer contents before the stream is opened
let text = str::from_utf8(jabber.reader.buffer.data()).unwrap();
println!("data: {}", text);
jabber.start_stream().await.unwrap();
// buffer contents after the stream header exchange
let text = str::from_utf8(jabber.reader.buffer.data()).unwrap();
println!("data: {}", text);
// pull more bytes into the reader buffer before parsing the features
jabber.reader.read_buf().await.unwrap();
let text = str::from_utf8(jabber.reader.buffer.data()).unwrap();
println!("data: {}", text);
let features = jabber.get_features().await.unwrap();
// the test requires both configured credentials and a SASL feature
let (sasl_config, feature) = (
jabber.auth.clone().unwrap(),
features
.features
.iter()
.find(|feature| matches!(feature, Feature::Sasl(_)))
.unwrap(),
);
// `find` above only yields the SASL feature; the other arms exist to keep
// the match exhaustive.
match feature {
Feature::StartTls(_start_tls) => todo!(),
Feature::Sasl(mechanisms) => {
jabber.sasl(mechanisms.clone(), sasl_config).await.unwrap();
}
Feature::Bind => todo!(),
Feature::Unknown => todo!(),
}
}
}