use std::str::{self, FromStr};
use std::sync::Arc;
use jid::JID;
use peanuts::element::IntoElement;
use peanuts::{Reader, Writer};
use rsasl::prelude::{Mechname, SASLClient, SASLConfig};
use stanza::bind::{Bind, BindType, FullJidType, ResourceType};
use stanza::client::iq::{Iq, IqType, Query};
use stanza::client::Stanza;
use stanza::sasl::{Auth, Challenge, Mechanisms, Response, ServerResponse};
use stanza::starttls::{Proceed, StartTls};
use stanza::stream::{Features, Stream};
use stanza::XML_VERSION;
use tokio::io::{AsyncRead, AsyncWrite, ReadHalf, WriteHalf};
use tokio_native_tls::native_tls::TlsConnector;
use tracing::{debug, instrument};
use crate::connection::{Tls, Unencrypted};
use crate::error::Error;
use crate::Result;
pub mod bound_stream;
/// An opened client-to-server XML stream (both stream headers have been exchanged),
/// held as separate reader and writer halves over the same connection `S`.
pub struct JabberStream<S> {
    reader: JabberReader<S>,
    writer: JabberWriter<S>,
}

impl<S> JabberStream<S> {
    /// Separates the stream into its reader and writer halves.
    fn split(self) -> (JabberReader<S>, JabberWriter<S>) {
        let Self { reader, writer } = self;
        (reader, writer)
    }
}
/// Read half of a [`JabberStream`]: a `peanuts` XML reader over the read half of `S`.
///
/// Dereferences to the wrapped reader so its parsing methods can be called directly.
pub struct JabberReader<S>(Reader<ReadHalf<S>>);

impl<S> JabberReader<S> {
    // TODO: consider taking a readhalf and creating peanuts::Reader here, only one inner
    /// Wraps an existing XML reader.
    fn new(reader: Reader<ReadHalf<S>>) -> Self {
        JabberReader(reader)
    }

    /// Recombines this reader with its matching writer into a [`JabberStream`].
    fn unsplit(self, writer: JabberWriter<S>) -> JabberStream<S> {
        let reader = self;
        JabberStream { reader, writer }
    }

    /// Unwraps and returns the underlying XML reader.
    fn into_inner(self) -> Reader<ReadHalf<S>> {
        let JabberReader(inner) = self;
        inner
    }
}

impl<S> JabberReader<S>
where
    S: AsyncRead + Unpin,
{
    /// Consumes the next end tag from the peer via the wrapped reader's `read_end_tag`
    /// (presumably the closing stream tag — TODO confirm), discarding its value.
    ///
    /// # Errors
    /// Any error from `read_end_tag`, converted into the crate error type.
    pub async fn try_close(&mut self) -> Result<()> {
        self.read_end_tag().await?;
        Ok(())
    }
}

impl<S> std::ops::Deref for JabberReader<S> {
    type Target = Reader<ReadHalf<S>>;

    fn deref(&self) -> &Self::Target {
        let JabberReader(inner) = self;
        inner
    }
}

impl<S> std::ops::DerefMut for JabberReader<S> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        let JabberReader(inner) = self;
        inner
    }
}
/// Write half of a [`JabberStream`]: a `peanuts` XML writer over the write half of `S`.
///
/// Dereferences to the wrapped writer so its serialisation methods can be called directly.
pub struct JabberWriter<S>(Writer<WriteHalf<S>>);

impl<S> JabberWriter<S> {
    /// Wraps an existing XML writer.
    fn new(writer: Writer<WriteHalf<S>>) -> Self {
        JabberWriter(writer)
    }

    /// Recombines this writer with its matching reader into a [`JabberStream`].
    fn unsplit(self, reader: JabberReader<S>) -> JabberStream<S> {
        let writer = self;
        JabberStream { reader, writer }
    }

    /// Unwraps and returns the underlying XML writer.
    fn into_inner(self) -> Writer<WriteHalf<S>> {
        let JabberWriter(inner) = self;
        inner
    }
}

impl<S> JabberWriter<S>
where
    S: AsyncWrite + Unpin + Send,
{
    /// Calls `write_end` on the wrapped writer to emit the end of the stream.
    ///
    /// # Errors
    /// Any error from `write_end`, converted into the crate error type.
    pub async fn try_close(&mut self) -> Result<()> {
        self.write_end().await?;
        Ok(())
    }
}

impl<S> std::ops::Deref for JabberWriter<S> {
    type Target = Writer<WriteHalf<S>>;

    fn deref(&self) -> &Self::Target {
        let JabberWriter(inner) = self;
        inner
    }
}

impl<S> std::ops::DerefMut for JabberWriter<S> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        let JabberWriter(inner) = self;
        inner
    }
}
impl<S> JabberStream<S>
where
S: AsyncRead + AsyncWrite + Unpin + Send + std::fmt::Debug,
JabberStream<S>: std::fmt::Debug,
{
#[instrument]
pub async fn sasl(mut self, mechanisms: Mechanisms, sasl_config: Arc<SASLConfig>) -> Result<S> {
let sasl = SASLClient::new(sasl_config);
let mut offered_mechs: Vec<&Mechname> = Vec::new();
for mechanism in &mechanisms.mechanisms {
offered_mechs
.push(Mechname::parse(mechanism.as_bytes()).map_err(|e| Error::SASL(e.into()))?)
}
debug!("{:?}", offered_mechs);
let mut session = sasl
.start_suggested(&offered_mechs)
.map_err(|e| Error::SASL(e.into()))?;
let selected_mechanism = session.get_mechname().as_str().to_owned();
debug!("selected mech: {:?}", selected_mechanism);
let mut data: Option<Vec<u8>>;
if !session.are_we_first() {
// if not first mention the mechanism then get challenge data
// mention mechanism
let auth = Auth {
mechanism: selected_mechanism,
sasl_data: "=".to_string(),
};
self.writer.write_full(&auth).await?;
// get challenge data
let challenge: Challenge = self.reader.read().await?;
debug!("challenge: {:?}", challenge);
data = Some((*challenge).as_bytes().to_vec());
debug!("we didn't go first");
} else {
// if first, mention mechanism and send data
let mut sasl_data = Vec::new();
session.step64(None, &mut sasl_data).unwrap();
let auth = Auth {
mechanism: selected_mechanism,
sasl_data: str::from_utf8(&sasl_data)?.to_string(),
};
debug!("{:?}", auth);
self.writer.write_full(&auth).await?;
let server_response: ServerResponse = self.reader.read().await?;
debug!("server_response: {:#?}", server_response);
match server_response {
ServerResponse::Challenge(challenge) => {
data = Some((*challenge).as_bytes().to_vec())
}
ServerResponse::Success(success) => {
data = success.clone().map(|success| success.as_bytes().to_vec())
}
ServerResponse::Failure(failure) => return Err(Error::SASL(failure.into())),
}
debug!("we went first");
}
// stepping the authentication exchange to completion
if data != None {
debug!("data: {:?}", data);
let mut sasl_data = Vec::new();
while {
// decide if need to send more data over
let state = session
.step64(data.as_deref(), &mut sasl_data)
.expect("step errored!");
state.is_running()
} {
// While we aren't finished, receive more data from the other party
let response = Response::new(str::from_utf8(&sasl_data)?.to_string());
debug!("response: {:?}", response);
self.writer.write_full(&response).await?;
debug!("response written");
let server_response: ServerResponse = self.reader.read().await?;
debug!("server_response: {:#?}", server_response);
match server_response {
ServerResponse::Challenge(challenge) => {
data = Some((*challenge).as_bytes().to_vec())
}
ServerResponse::Success(success) => {
data = success.clone().map(|success| success.as_bytes().to_vec())
}
ServerResponse::Failure(failure) => return Err(Error::SASL(failure.into())),
}
}
}
let writer = self.writer.into_inner().into_inner();
let reader = self.reader.into_inner().into_inner();
let stream = reader.unsplit(writer);
Ok(stream)
}
#[instrument]
pub async fn bind(mut self, jid: &mut JID) -> Result<Self> {
let iq_id = nanoid::nanoid!();
if let Some(resource) = &jid.resourcepart {
let iq = Iq {
from: None,
id: iq_id.clone(),
to: None,
r#type: IqType::Set,
lang: None,
query: Some(Query::Bind(Bind {
r#type: Some(BindType::Resource(ResourceType(resource.to_string()))),
})),
errors: Vec::new(),
};
self.writer.write_full(&iq).await?;
let result: Iq = self.reader.read().await?;
match result {
Iq {
from: _,
id,
to: _,
r#type: IqType::Result,
lang: _,
query:
Some(Query::Bind(Bind {
r#type: Some(BindType::Jid(FullJidType(new_jid))),
})),
errors: _,
} if id == iq_id => {
*jid = new_jid;
return Ok(self);
}
Iq {
from: _,
id,
to: _,
r#type: IqType::Error,
lang: _,
query: None,
errors,
} if id == iq_id => {
return Err(Error::ClientError(
errors.first().ok_or(Error::MissingError)?.clone(),
))
}
_ => return Err(Error::UnexpectedElement(result.into_element())),
}
} else {
let iq = Iq {
from: None,
id: iq_id.clone(),
to: None,
r#type: IqType::Set,
lang: None,
query: Some(Query::Bind(Bind { r#type: None })),
errors: Vec::new(),
};
self.writer.write_full(&iq).await?;
let result: Iq = self.reader.read().await?;
match result {
Iq {
from: _,
id,
to: _,
r#type: IqType::Result,
lang: _,
query:
Some(Query::Bind(Bind {
r#type: Some(BindType::Jid(FullJidType(new_jid))),
})),
errors: _,
} if id == iq_id => {
*jid = new_jid;
return Ok(self);
}
Iq {
from: _,
id,
to: _,
r#type: IqType::Error,
lang: _,
query: None,
errors,
} if id == iq_id => {
return Err(Error::ClientError(
errors.first().ok_or(Error::MissingError)?.clone(),
))
}
_ => return Err(Error::UnexpectedElement(result.into_element())),
}
}
}
#[instrument]
pub async fn start_stream(connection: S, server: &mut String) -> Result<Self> {
// client to server
let (reader, writer) = tokio::io::split(connection);
let mut reader = JabberReader::new(Reader::new(reader));
let mut writer = JabberWriter::new(Writer::new(writer));
// declaration
writer.write_declaration(XML_VERSION).await?;
// opening stream element
let stream = Stream::new_client(
None,
JID::from_str(server.as_ref())?,
None,
"en".to_string(),
);
writer.write_start(&stream).await?;
// server to client
// may or may not send a declaration
let _decl = reader.read_prolog().await?;
// receive stream element and validate
let stream: Stream = reader.read_start().await?;
debug!("got stream: {:?}", stream);
if let Some(from) = stream.from {
*server = from.to_string();
}
Ok(Self { reader, writer })
}
#[instrument]
pub async fn get_features(mut self) -> Result<(Features, Self)> {
debug!("getting features");
let features: Features = self.reader.read().await?;
debug!("got features: {:?}", features);
Ok((features, self))
}
pub fn into_inner(self) -> S {
self.reader
.into_inner()
.into_inner()
.unsplit(self.writer.into_inner().into_inner())
}
pub async fn send_stanza(&mut self, stanza: &Stanza) -> Result<()> {
self.writer.write(stanza).await?;
Ok(())
}
}
impl JabberStream<Unencrypted> {
#[instrument]
pub async fn starttls(mut self, domain: impl AsRef<str> + std::fmt::Debug) -> Result<Tls> {
self.writer
.write_full(&StartTls { required: false })
.await?;
let proceed: Proceed = self.reader.read().await?;
debug!("got proceed: {:?}", proceed);
let connector = TlsConnector::new().unwrap();
let stream = self
.reader
.into_inner()
.into_inner()
.unsplit(self.writer.into_inner().into_inner());
if let Ok(tls_stream) = tokio_native_tls::TlsConnector::from(connector)
.connect(domain.as_ref(), stream)
.await
{
return Ok(tls_stream);
} else {
return Err(Error::Connection);
}
}
}
/// Debug output identifies only the transport kind; stream internals are not printed.
impl std::fmt::Debug for JabberStream<Tls> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("Jabber");
        builder.field("connection", &"tls");
        builder.finish()
    }
}

/// Debug output identifies only the transport kind; stream internals are not printed.
impl std::fmt::Debug for JabberStream<Unencrypted> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("Jabber");
        builder.field("connection", &"unencrypted");
        builder.finish()
    }
}
// Placeholder tests. Every body below is commented out, so each test currently has an
// empty body and always passes. The commented code refers to names not imported in this
// file (`Connection`, `JabberClient`, `Feature`, `sink`) and to a live server
// ("blos.sm"). It presumably targets an earlier API — TODO port or delete.
#[cfg(test)]
mod tests {
    use test_log::test;

    // Intended: connect and open a stream. Currently an empty body.
    #[test(tokio::test)]
    async fn start_stream() {
        // let connection = Connection::connect("blos.sm", None, None).await.unwrap();
        // match connection {
        //     Connection::Encrypted(mut c) => c.start_stream().await.unwrap(),
        //     Connection::Unencrypted(mut c) => c.start_stream().await.unwrap(),
        // }
    }

    // Intended: connect, upgrade to TLS, read features and run SASL. Currently an empty body.
    #[test(tokio::test)]
    async fn sasl() {
        // let mut jabber = Connection::connect_user("test@blos.sm", "slayed".to_string())
        //     .await
        //     .unwrap()
        //     .ensure_tls()
        //     .await
        //     .unwrap();
        // let text = str::from_utf8(jabber.reader.buffer.data()).unwrap();
        // println!("data: {}", text);
        // jabber.start_stream().await.unwrap();
        // let text = str::from_utf8(jabber.reader.buffer.data()).unwrap();
        // println!("data: {}", text);
        // jabber.reader.read_buf().await.unwrap();
        // let text = str::from_utf8(jabber.reader.buffer.data()).unwrap();
        // println!("data: {}", text);
        // let features = jabber.get_features().await.unwrap();
        // let (sasl_config, feature) = (
        //     jabber.auth.clone().unwrap(),
        //     features
        //         .features
        //         .iter()
        //         .find(|feature| matches!(feature, Feature::Sasl(_)))
        //         .unwrap(),
        // );
        // match feature {
        //     Feature::StartTls(_start_tls) => todo!(),
        //     Feature::Sasl(mechanisms) => {
        //         jabber.sasl(mechanisms.clone(), sasl_config).await.unwrap();
        //     }
        //     Feature::Bind => todo!(),
        //     Feature::Unknown => todo!(),
        // }
    }

    // Intended: wrap a connected stream in a stanza sink. Currently an empty body.
    #[tokio::test]
    async fn sink() {
        // let mut client = JabberClient::new("test@blos.sm", "slayed").unwrap();
        // client.connect().await.unwrap();
        // let stream = client.inner().unwrap();
        // let sink = sink::unfold(stream, |mut stream, stanza: Stanza| async move {
        //     stream.writer.write(&stanza).await?;
        //     Ok::<JabberStream<Tls>, Error>(stream)
        // });
        // todo!()
        // let _jabber = Connection::connect_user("test@blos.sm", "slayed".to_string())
        //     .await
        //     .unwrap()
        //     .ensure_tls()
        //     .await
        //     .unwrap()
        //     .negotiate()
        //     .await
        //     .unwrap();
        // sleep(Duration::from_secs(5)).await
    }
}