use std::pin::pin;
use std::str::{self, FromStr};
use std::sync::Arc;
use async_recursion::async_recursion;
use futures::StreamExt;
use peanuts::element::{FromContent, IntoElement};
use peanuts::{Reader, Writer};
use rsasl::prelude::{Mechname, SASLClient, SASLConfig};
use tokio::io::{AsyncRead, AsyncWrite, ReadHalf, WriteHalf};
use tokio_native_tls::native_tls::TlsConnector;
use tracing::{debug, instrument};
use crate::connection::{Tls, Unencrypted};
use crate::error::Error;
use crate::stanza::bind::{Bind, BindType, FullJidType, ResourceType};
use crate::stanza::client::iq::{Iq, IqType, Query};
use crate::stanza::client::Stanza;
use crate::stanza::sasl::{Auth, Challenge, Mechanisms, Response, ServerResponse};
use crate::stanza::starttls::{Proceed, StartTls};
use crate::stanza::stream::{Feature, Features, Stream};
use crate::stanza::XML_VERSION;
use crate::JID;
use crate::{Connection, Result};
/// An open XMPP stream: the opening stream headers have been exchanged and the
/// underlying connection `S` is split into an XML reader and an XML writer.
pub struct JabberStream<S> {
    /// Parses XML from the read half of the underlying connection.
    reader: Reader<ReadHalf<S>>,
    /// Serializes XML to the write half of the underlying connection.
    writer: Writer<WriteHalf<S>>,
}
impl<S: AsyncRead> futures::Stream for JabberStream<S> {
    type Item = Result<Stanza>;

    /// Polls the XML reader for its next element and deserializes it into a
    /// [`Stanza`].
    ///
    /// Yields `None` when the reader's stream ends. Reader errors and
    /// deserialization errors are both converted into the crate's error type.
    fn poll_next(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        let mut this = pin!(self);
        let polled = this.reader.poll_next_unpin(cx);
        polled.map(|next| {
            next.map(|read| -> Result<Stanza> {
                // First surface any reader error, then any deserialization error.
                let element = read?;
                Ok(Stanza::from_content(element)?)
            })
        })
    }
}
impl<S> JabberStream<S>
where
S: AsyncRead + AsyncWrite + Unpin + Send + std::fmt::Debug,
JabberStream<S>: std::fmt::Debug,
{
#[instrument]
pub async fn sasl(mut self, mechanisms: Mechanisms, sasl_config: Arc<SASLConfig>) -> Result<S> {
let sasl = SASLClient::new(sasl_config);
let mut offered_mechs: Vec<&Mechname> = Vec::new();
for mechanism in &mechanisms.mechanisms {
offered_mechs.push(Mechname::parse(mechanism.as_bytes())?)
}
debug!("{:?}", offered_mechs);
let mut session = sasl.start_suggested(&offered_mechs)?;
let selected_mechanism = session.get_mechname().as_str().to_owned();
debug!("selected mech: {:?}", selected_mechanism);
let mut data: Option<Vec<u8>> = None;
if !session.are_we_first() {
// if not first mention the mechanism then get challenge data
// mention mechanism
let auth = Auth {
mechanism: selected_mechanism,
sasl_data: "=".to_string(),
};
self.writer.write_full(&auth).await?;
// get challenge data
let challenge: Challenge = self.reader.read().await?;
debug!("challenge: {:?}", challenge);
data = Some((*challenge).as_bytes().to_vec());
debug!("we didn't go first");
} else {
// if first, mention mechanism and send data
let mut sasl_data = Vec::new();
session.step64(None, &mut sasl_data).unwrap();
let auth = Auth {
mechanism: selected_mechanism,
sasl_data: str::from_utf8(&sasl_data)?.to_string(),
};
debug!("{:?}", auth);
self.writer.write_full(&auth).await?;
let server_response: ServerResponse = self.reader.read().await?;
debug!("server_response: {:#?}", server_response);
match server_response {
ServerResponse::Challenge(challenge) => {
data = Some((*challenge).as_bytes().to_vec())
}
ServerResponse::Success(success) => {
data = success.clone().map(|success| success.as_bytes().to_vec())
}
ServerResponse::Failure(failure) => return Err(Error::Authentication(failure)),
}
debug!("we went first");
}
// stepping the authentication exchange to completion
if data != None {
debug!("data: {:?}", data);
let mut sasl_data = Vec::new();
while {
// decide if need to send more data over
let state = session
.step64(data.as_deref(), &mut sasl_data)
.expect("step errored!");
state.is_running()
} {
// While we aren't finished, receive more data from the other party
let response = Response::new(str::from_utf8(&sasl_data)?.to_string());
debug!("response: {:?}", response);
let stdout = tokio::io::stdout();
let mut writer = Writer::new(stdout);
writer.write_full(&response).await?;
self.writer.write_full(&response).await?;
debug!("response written");
let server_response: ServerResponse = self.reader.read().await?;
debug!("server_response: {:#?}", server_response);
match server_response {
ServerResponse::Challenge(challenge) => {
data = Some((*challenge).as_bytes().to_vec())
}
ServerResponse::Success(success) => {
data = success.clone().map(|success| success.as_bytes().to_vec())
}
ServerResponse::Failure(failure) => return Err(Error::Authentication(failure)),
}
}
}
let writer = self.writer.into_inner();
let reader = self.reader.into_inner();
let stream = reader.unsplit(writer);
Ok(stream)
}
#[instrument]
pub async fn bind(mut self, jid: &mut JID) -> Result<Self> {
let iq_id = nanoid::nanoid!();
if let Some(resource) = &jid.resourcepart {
let iq = Iq {
from: None,
id: iq_id.clone(),
to: None,
r#type: IqType::Set,
lang: None,
query: Some(Query::Bind(Bind {
r#type: Some(BindType::Resource(ResourceType(resource.to_string()))),
})),
errors: Vec::new(),
};
self.writer.write_full(&iq).await?;
let result: Iq = self.reader.read().await?;
match result {
Iq {
from: _,
id,
to: _,
r#type: IqType::Result,
lang: _,
query:
Some(Query::Bind(Bind {
r#type: Some(BindType::Jid(FullJidType(new_jid))),
})),
errors: _,
} if id == iq_id => {
*jid = new_jid;
return Ok(self);
}
Iq {
from: _,
id,
to: _,
r#type: IqType::Error,
lang: _,
query: None,
errors,
} if id == iq_id => {
return Err(Error::ClientError(
errors.first().ok_or(Error::MissingError)?.clone(),
))
}
_ => return Err(Error::UnexpectedElement(result.into_element())),
}
} else {
let iq = Iq {
from: None,
id: iq_id.clone(),
to: None,
r#type: IqType::Set,
lang: None,
query: Some(Query::Bind(Bind { r#type: None })),
errors: Vec::new(),
};
self.writer.write_full(&iq).await?;
let result: Iq = self.reader.read().await?;
match result {
Iq {
from: _,
id,
to: _,
r#type: IqType::Result,
lang: _,
query:
Some(Query::Bind(Bind {
r#type: Some(BindType::Jid(FullJidType(new_jid))),
})),
errors: _,
} if id == iq_id => {
*jid = new_jid;
return Ok(self);
}
Iq {
from: _,
id,
to: _,
r#type: IqType::Error,
lang: _,
query: None,
errors,
} if id == iq_id => {
return Err(Error::ClientError(
errors.first().ok_or(Error::MissingError)?.clone(),
))
}
_ => return Err(Error::UnexpectedElement(result.into_element())),
}
}
}
#[instrument]
pub async fn start_stream(connection: S, server: &mut String) -> Result<Self> {
// client to server
let (reader, writer) = tokio::io::split(connection);
let mut reader = Reader::new(reader);
let mut writer = Writer::new(writer);
// declaration
writer.write_declaration(XML_VERSION).await?;
// opening stream element
let stream = Stream::new_client(
None,
JID::from_str(server.as_ref())?,
None,
"en".to_string(),
);
writer.write_start(&stream).await?;
// server to client
// may or may not send a declaration
let _decl = reader.read_prolog().await?;
// receive stream element and validate
let stream: Stream = reader.read_start().await?;
debug!("got stream: {:?}", stream);
if let Some(from) = stream.from {
*server = from.to_string();
}
Ok(Self { reader, writer })
}
#[instrument]
pub async fn get_features(mut self) -> Result<(Features, Self)> {
debug!("getting features");
let features: Features = self.reader.read().await?;
debug!("got features: {:?}", features);
Ok((features, self))
}
pub fn into_inner(self) -> S {
self.reader.into_inner().unsplit(self.writer.into_inner())
}
pub async fn send_stanza(&mut self, stanza: &Stanza) -> Result<()> {
self.writer.write(stanza).await?;
Ok(())
}
}
impl JabberStream<Unencrypted> {
#[instrument]
pub async fn starttls(mut self, domain: impl AsRef<str> + std::fmt::Debug) -> Result<Tls> {
self.writer
.write_full(&StartTls { required: false })
.await?;
let proceed: Proceed = self.reader.read().await?;
debug!("got proceed: {:?}", proceed);
let connector = TlsConnector::new().unwrap();
let stream = self.reader.into_inner().unsplit(self.writer.into_inner());
if let Ok(tls_stream) = tokio_native_tls::TlsConnector::from(connector)
.connect(domain.as_ref(), stream)
.await
{
return Ok(tls_stream);
} else {
return Err(Error::Connection);
}
}
}
impl std::fmt::Debug for JabberStream<Tls> {
    /// Renders as `Jabber { connection: "tls" }`; the reader and writer are
    /// not shown.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut jabber = f.debug_struct("Jabber");
        jabber.field("connection", &"tls");
        jabber.finish()
    }
}
impl std::fmt::Debug for JabberStream<Unencrypted> {
    /// Renders as `Jabber { connection: "unencrypted" }`; the reader and
    /// writer are not shown.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut jabber = f.debug_struct("Jabber");
        jabber.field("connection", &"unencrypted");
        jabber.finish()
    }
}
// Tests for `JabberStream`. Every body below is currently commented out, so
// each test is a no-op that always passes. The disabled code connects to a
// live server ("blos.sm") with hard-coded credentials.
// NOTE(review): the disabled code appears to target an older API. For example,
// `c.start_stream()` takes no arguments, while `JabberStream::start_stream`
// takes a connection and a server. Verify it before re-enabling.
#[cfg(test)]
mod tests {
    use std::time::Duration;
    use super::*;
    use crate::connection::Connection;
    use test_log::test;
    use tokio::time::sleep;
    // Disabled: intended to open a stream over an encrypted or unencrypted connection.
    #[test(tokio::test)]
    async fn start_stream() {
        // let connection = Connection::connect("blos.sm", None, None).await.unwrap();
        // match connection {
        //     Connection::Encrypted(mut c) => c.start_stream().await.unwrap(),
        //     Connection::Unencrypted(mut c) => c.start_stream().await.unwrap(),
        // }
    }
    // Disabled: intended to run TLS, stream start, feature discovery and SASL.
    #[test(tokio::test)]
    async fn sasl() {
        // let mut jabber = Connection::connect_user("test@blos.sm", "slayed".to_string())
        //     .await
        //     .unwrap()
        //     .ensure_tls()
        //     .await
        //     .unwrap();
        // let text = str::from_utf8(jabber.reader.buffer.data()).unwrap();
        // println!("data: {}", text);
        // jabber.start_stream().await.unwrap();
        // let text = str::from_utf8(jabber.reader.buffer.data()).unwrap();
        // println!("data: {}", text);
        // jabber.reader.read_buf().await.unwrap();
        // let text = str::from_utf8(jabber.reader.buffer.data()).unwrap();
        // println!("data: {}", text);
        // let features = jabber.get_features().await.unwrap();
        // let (sasl_config, feature) = (
        //     jabber.auth.clone().unwrap(),
        //     features
        //         .features
        //         .iter()
        //         .find(|feature| matches!(feature, Feature::Sasl(_)))
        //         .unwrap(),
        // );
        // match feature {
        //     Feature::StartTls(_start_tls) => todo!(),
        //     Feature::Sasl(mechanisms) => {
        //         jabber.sasl(mechanisms.clone(), sasl_config).await.unwrap();
        //     }
        //     Feature::Bind => todo!(),
        //     Feature::Unknown => todo!(),
        // }
    }
    // Disabled: intended to run full negotiation. Unlike the tests above, it
    // uses plain `#[tokio::test]` rather than `test_log`'s `#[test(...)]`.
    #[tokio::test]
    async fn negotiate() {
        // let _jabber = Connection::connect_user("test@blos.sm", "slayed".to_string())
        //     .await
        //     .unwrap()
        //     .ensure_tls()
        //     .await
        //     .unwrap()
        //     .negotiate()
        //     .await
        //     .unwrap();
        // sleep(Duration::from_secs(5)).await
    }
}