use std::{
borrow::Borrow,
future::Future,
pin::pin,
sync::Arc,
task::{ready, Poll},
};
use futures::{FutureExt, Sink, SinkExt, Stream, StreamExt};
use jid::ParseError;
use rsasl::config::SASLConfig;
use stanza::{
client::Stanza,
sasl::Mechanisms,
stream::{Feature, Features},
};
use tokio::sync::Mutex;
use crate::{
connection::{Tls, Unencrypted},
jabber_stream::bound_stream::{BoundJabberReader, BoundJabberStream},
Connection, Error, JabberStream, Result, JID,
};
/// XMPP client handle: feed it client stanzas, receive client stanzas.
///
/// Created disconnected by [`JabberClient::new`]; [`JabberClient::connect`] performs the
/// connection and login and keeps the resulting bound stream.
pub struct JabberClient {
    /// Bound, TLS-protected stream; `None` until [`JabberClient::connect`] stores one.
    connection: Option<BoundJabberStream<Tls>>,
    /// The JID this client logs in as (passed to the bind step during login).
    jid: JID,
    /// SASL configuration built in `new` from the JID's localpart and the supplied password.
    // TODO: have reconnection be handled by another part, so creds don't need to be stored in object
    password: Arc<SASLConfig>,
    /// Server to connect to; initialised from the JID's domainpart.
    // NOTE(review): passed as `&mut String` to `connect_and_login`, so the login code may
    // be able to modify it — confirm whether that is intended.
    server: String,
}
impl JabberClient {
pub fn new(
jid: impl TryInto<JID, Error = ParseError>,
password: impl ToString,
) -> Result<JabberClient> {
let jid = jid.try_into()?;
let sasl_config = SASLConfig::with_credentials(
None,
jid.localpart.clone().ok_or(Error::NoLocalpart)?,
password.to_string(),
)?;
Ok(JabberClient {
connection: None,
jid: jid.clone(),
password: sasl_config,
server: jid.domainpart,
})
}
pub fn jid(&self) -> JID {
self.jid.clone()
}
pub async fn connect(&mut self) -> Result<()> {
match &self.connection {
Some(_) => Ok(()),
None => {
self.connection = Some(
connect_and_login(&mut self.jid, self.password.clone(), &mut self.server)
.await?,
);
Ok(())
}
}
}
pub(crate) fn into_inner(self) -> Result<BoundJabberStream<Tls>> {
self.connection.ok_or(Error::Disconnected)
}
// pub async fn send_stanza(&mut self, stanza: &Stanza) -> Result<()> {
// match &mut self.connection {
// ConnectionState::Disconnected => return Err(Error::Disconnected),
// ConnectionState::Connecting(_connecting) => return Err(Error::Connecting),
// ConnectionState::Connected(jabber_stream) => {
// Ok(jabber_stream.send_stanza(stanza).await?)
// }
// }
// }
}
pub async fn connect_and_login(
jid: &mut JID,
auth: Arc<SASLConfig>,
server: &mut String,
) -> Result<BoundJabberStream<Tls>> {
let mut conn_state = Connecting::start(&server).await?;
loop {
match conn_state {
Connecting::InsecureConnectionEstablised(tcp_stream) => {
conn_state = Connecting::InsecureStreamStarted(
JabberStream::start_stream(tcp_stream, server).await?,
)
}
Connecting::InsecureStreamStarted(jabber_stream) => {
conn_state = Connecting::InsecureGotFeatures(jabber_stream.get_features().await?)
}
Connecting::InsecureGotFeatures((features, jabber_stream)) => {
match features.negotiate().ok_or(Error::Negotiation)? {
Feature::StartTls(_start_tls) => {
conn_state = Connecting::StartTls(jabber_stream)
}
// TODO: better error
_ => return Err(Error::TlsRequired),
}
}
Connecting::StartTls(jabber_stream) => {
conn_state =
Connecting::ConnectionEstablished(jabber_stream.starttls(&server).await?)
}
Connecting::ConnectionEstablished(tls_stream) => {
conn_state =
Connecting::StreamStarted(JabberStream::start_stream(tls_stream, server).await?)
}
Connecting::StreamStarted(jabber_stream) => {
conn_state = Connecting::GotFeatures(jabber_stream.get_features().await?)
}
Connecting::GotFeatures((features, jabber_stream)) => {
match features.negotiate().ok_or(Error::Negotiation)? {
Feature::StartTls(_start_tls) => return Err(Error::AlreadyTls),
Feature::Sasl(mechanisms) => {
conn_state = Connecting::Sasl(mechanisms, jabber_stream)
}
Feature::Bind => conn_state = Connecting::Bind(jabber_stream),
Feature::Unknown => return Err(Error::Unsupported),
}
}
Connecting::Sasl(mechanisms, jabber_stream) => {
conn_state = Connecting::ConnectionEstablished(
jabber_stream.sasl(mechanisms, auth.clone()).await?,
)
}
Connecting::Bind(jabber_stream) => {
return Ok(jabber_stream.bind(jid).await?.to_bound_jabber());
}
}
}
}
/// States of the login state machine driven by [`connect_and_login`].
///
/// Variants prefixed `Insecure` (and `StartTls`) carry a plaintext `Unencrypted`
/// transport; the remaining variants carry a `Tls` transport.
pub enum Connecting {
    /// Plaintext transport opened; no XML stream started on it yet.
    // NOTE(review): "Establised" is misspelled, but renaming this public variant would break
    // callers; consider adding a correctly spelled alias in a breaking release.
    InsecureConnectionEstablised(Unencrypted),
    /// Stream started over the plaintext transport.
    InsecureStreamStarted(JabberStream<Unencrypted>),
    /// Features received on the plaintext stream, together with that stream.
    InsecureGotFeatures((Features, JabberStream<Unencrypted>)),
    /// STARTTLS negotiated; holds the plaintext stream to upgrade.
    StartTls(JabberStream<Unencrypted>),
    /// TLS transport ready for a stream start. Reached from an already-encrypted connection,
    /// after STARTTLS, or after SASL authentication.
    ConnectionEstablished(Tls),
    /// Stream started over the TLS transport.
    StreamStarted(JabberStream<Tls>),
    /// Features received on the TLS stream, together with that stream.
    GotFeatures((Features, JabberStream<Tls>)),
    /// SASL negotiated; holds the offered mechanisms and the stream to authenticate.
    Sasl(Mechanisms, JabberStream<Tls>),
    /// Resource binding negotiated; the final step before the stream is returned as bound.
    Bind(JabberStream<Tls>),
}
impl Connecting {
    /// Opens a transport to `server` and returns the matching initial state.
    ///
    /// A plaintext connection starts at [`Connecting::InsecureConnectionEstablised`]; an
    /// already-encrypted one starts at [`Connecting::ConnectionEstablished`].
    ///
    /// # Errors
    ///
    /// Any error from `Connection::connect`.
    pub async fn start(server: &str) -> Result<Self> {
        let connection = Connection::connect(server).await?;
        let initial = match connection {
            Connection::Unencrypted(tcp_stream) => Self::InsecureConnectionEstablised(tcp_stream),
            Connection::Encrypted(tls_stream) => Self::ConnectionEstablished(tls_stream),
        };
        Ok(initial)
    }
}
/// An alternative set of connection states.
///
/// NOTE(review): this enum is not referenced anywhere in the visible code (login uses
/// [`Connecting`]); it may be a leftover — confirm no other module uses it before removing.
pub enum InsecureConnecting {
    Disconnected,
    ConnectionEstablished(Connection),
    PreStarttls(JabberStream<Unencrypted>),
    PreAuthenticated(JabberStream<Tls>),
    Authenticated(Tls),
    PreBound(JabberStream<Tls>),
    Bound(JabberStream<Tls>),
}
#[cfg(test)]
mod tests {
    // These tests log in to a live XMPP server (`blos.sm`) with a fixed test account, so
    // they need network access and that account to exist.
    use std::time::Duration;

    use super::JabberClient;
    use futures::{SinkExt, StreamExt};
    use stanza::{
        client::{
            iq::{Iq, IqType, Query},
            Stanza,
        },
        xep_0199::Ping,
    };
    use test_log::test;
    use tokio::time::sleep;
    use tracing::info;

    /// Connects and logs in, then keeps the session open briefly so it can be observed.
    #[test(tokio::test)]
    async fn login() {
        let mut client = JabberClient::new("test@blos.sm", "slayed").unwrap();
        client.connect().await.unwrap();
        sleep(Duration::from_secs(5)).await
    }

    /// Sends two XEP-0199 pings to the server while concurrently reading two replies.
    #[test(tokio::test)]
    async fn ping_parallel() {
        let mut client = JabberClient::new("test@blos.sm", "slayed").unwrap();
        client.connect().await.unwrap();
        sleep(Duration::from_secs(5)).await;
        let jid = client.jid.clone();
        let server = client.server.clone();
        let (mut read, mut write) = client.into_inner().unwrap().split();
        // Builds an IQ `get` ping from this client to its server with the given id.
        let ping = |id: &str| {
            Stanza::Iq(Iq {
                from: Some(jid.clone()),
                id: id.to_string(),
                to: Some(server.clone().try_into().unwrap()),
                r#type: IqType::Get,
                lang: None,
                query: Some(Query::Ping(Ping)),
                errors: Vec::new(),
            })
        };
        tokio::join!(
            async {
                write.write(&ping("c2s1")).await.unwrap();
                write.write(&ping("c2s2")).await.unwrap();
            },
            async {
                for _ in 0..2 {
                    let stanza = read.read::<Stanza>().await.unwrap();
                    info!("ping reply: {:#?}", stanza);
                }
            }
        );
    }
}