use rsasl::config::SASLConfig;
use stanza::{
sasl::Mechanisms,
stream::{Feature, Features},
};
use crate::{
connection::{Tls, Unencrypted},
jabber_stream::bound_stream::BoundJabberStream,
Connection, Error, JabberStream, Result, JID,
};
pub async fn connect_and_login(
jid: &mut JID,
password: impl AsRef<str>,
server: &mut String,
) -> Result<BoundJabberStream<Tls>> {
let auth = SASLConfig::with_credentials(
None,
jid.localpart.clone().ok_or(Error::NoLocalpart)?,
password.as_ref().to_string(),
)
.map_err(|e| Error::SASL(e.into()))?;
let mut conn_state = Connecting::start(&server).await?;
loop {
match conn_state {
Connecting::InsecureConnectionEstablised(tcp_stream) => {
conn_state = Connecting::InsecureStreamStarted(
JabberStream::start_stream(tcp_stream, server).await?,
)
}
Connecting::InsecureStreamStarted(jabber_stream) => {
conn_state = Connecting::InsecureGotFeatures(jabber_stream.get_features().await?)
}
Connecting::InsecureGotFeatures((features, jabber_stream)) => {
match features.negotiate().ok_or(Error::Negotiation)? {
Feature::StartTls(_start_tls) => {
conn_state = Connecting::StartTls(jabber_stream)
}
// TODO: better error
_ => return Err(Error::TlsRequired),
}
}
Connecting::StartTls(jabber_stream) => {
conn_state =
Connecting::ConnectionEstablished(jabber_stream.starttls(&server).await?)
}
Connecting::ConnectionEstablished(tls_stream) => {
conn_state =
Connecting::StreamStarted(JabberStream::start_stream(tls_stream, server).await?)
}
Connecting::StreamStarted(jabber_stream) => {
conn_state = Connecting::GotFeatures(jabber_stream.get_features().await?)
}
Connecting::GotFeatures((features, jabber_stream)) => {
match features.negotiate().ok_or(Error::Negotiation)? {
Feature::StartTls(_start_tls) => return Err(Error::AlreadyTls),
Feature::Sasl(mechanisms) => {
conn_state = Connecting::Sasl(mechanisms, jabber_stream)
}
Feature::Bind => conn_state = Connecting::Bind(jabber_stream),
Feature::Unknown => return Err(Error::Unsupported),
}
}
Connecting::Sasl(mechanisms, jabber_stream) => {
conn_state = Connecting::ConnectionEstablished(
jabber_stream.sasl(mechanisms, auth.clone()).await?,
)
}
Connecting::Bind(jabber_stream) => {
return Ok(jabber_stream.bind(jid).await?.to_bound_jabber());
}
}
}
}
/// States of the connection/login state machine driven by `connect_and_login`.
///
/// The `Insecure*` variants and `StartTls` carry a plaintext (`Unencrypted`) transport. All
/// other variants carry a `Tls` transport.
pub enum Connecting {
    /// A plaintext transport is open and no XMPP stream has been started on it yet.
    // NOTE(review): "Establised" is a typo, kept because renaming a public variant would break
    // callers.
    InsecureConnectionEstablised(Unencrypted),
    /// An XMPP stream has been started over the plaintext transport.
    InsecureStreamStarted(JabberStream<Unencrypted>),
    /// Stream features have been received over the plaintext transport.
    InsecureGotFeatures((Features, JabberStream<Unencrypted>)),
    /// STARTTLS was the negotiated feature; the plaintext stream is ready to be upgraded.
    StartTls(JabberStream<Unencrypted>),
    /// A TLS transport is available and a (new) stream must be started on it.
    ///
    /// Reached after a directly-encrypted connect, after STARTTLS, and again after SASL
    /// authentication.
    ConnectionEstablished(Tls),
    /// An XMPP stream has been started over TLS.
    StreamStarted(JabberStream<Tls>),
    /// Stream features have been received over TLS.
    GotFeatures((Features, JabberStream<Tls>)),
    /// SASL was the negotiated feature; holds the offered mechanisms and the stream.
    Sasl(Mechanisms, JabberStream<Tls>),
    /// Resource binding was the negotiated feature; the final step before the stream is bound.
    Bind(JabberStream<Tls>),
}
impl Connecting {
    /// Opens a transport connection to `server` and returns the matching initial state.
    ///
    /// A connection that [`Connection::connect`] reports as already encrypted skips STARTTLS
    /// and starts at [`Connecting::ConnectionEstablished`]. A plaintext connection starts at
    /// [`Connecting::InsecureConnectionEstablised`].
    ///
    /// # Errors
    ///
    /// Propagates any error returned by [`Connection::connect`].
    pub async fn start(server: &str) -> Result<Self> {
        let initial = match Connection::connect(server).await? {
            Connection::Unencrypted(tcp_stream) => Self::InsecureConnectionEstablised(tcp_stream),
            Connection::Encrypted(tls_stream) => Self::ConnectionEstablished(tls_stream),
        };
        Ok(initial)
    }
}
/// An alternative, coarser set of connection states.
///
/// NOTE(review): nothing in this file constructs or matches on this enum. `connect_and_login`
/// uses [`Connecting`] instead. It may be unused legacy or planned code — confirm against the
/// rest of the crate before relying on it or removing it.
pub enum InsecureConnecting {
    Disconnected,
    ConnectionEstablished(Connection),
    PreStarttls(JabberStream<Unencrypted>),
    PreAuthenticated(JabberStream<Tls>),
    Authenticated(Tls),
    PreBound(JabberStream<Tls>),
    Bound(JabberStream<Tls>),
}
#[cfg(test)]
mod tests {
    //! Integration tests against a live XMPP server.
    //!
    //! These tests log in to `blos.sm` with hard-coded credentials, so they need network access
    //! and a reachable server. They are `#[ignore]`d so that a plain `cargo test` stays green
    //! offline and in CI. Run them explicitly with `cargo test -- --ignored`.

    use std::time::Duration;
    use jid::JID;
    use stanza::{
        client::{
            iq::{Iq, IqType, Query},
            Stanza,
        },
        xep_0199::Ping,
    };
    use test_log::test;
    use tokio::time::sleep;
    use tracing::info;
    use super::connect_and_login;

    /// Builds an IQ `get` stanza with id `id`, sent from `from` to `to`, carrying an XEP-0199
    /// ping.
    fn ping_iq(id: &str, from: &JID, to: &str) -> Stanza {
        Stanza::Iq(Iq {
            from: Some(from.clone()),
            id: id.to_string(),
            to: Some(to.to_string().try_into().unwrap()),
            r#type: IqType::Get,
            lang: None,
            query: Some(Query::Ping(Ping)),
            errors: Vec::new(),
        })
    }

    #[test(tokio::test)]
    #[ignore = "requires network access to the live XMPP server blos.sm"]
    async fn login() {
        let mut jid: JID = "test@blos.sm".try_into().unwrap();
        let _client = connect_and_login(&mut jid, "slayed", &mut "blos.sm".to_string())
            .await
            .unwrap();
        // Keep the bound session alive for a while before it is dropped (presumably so the
        // session can be observed server-side or in logs — confirm the intent).
        sleep(Duration::from_secs(5)).await
    }

    #[test(tokio::test)]
    #[ignore = "requires network access to the live XMPP server blos.sm"]
    async fn ping_parallel() {
        let mut jid: JID = "test@blos.sm".try_into().unwrap();
        let mut server = "blos.sm".to_string();
        let client = connect_and_login(&mut jid, "slayed", &mut server)
            .await
            .unwrap();
        let (mut read, mut write) = client.split();
        // Writer and reader halves run concurrently: two pings go out while two replies are read.
        tokio::join!(
            async {
                write.write(&ping_iq("c2s1", &jid, &server)).await.unwrap();
                write.write(&ping_iq("c2s2", &jid, &server)).await.unwrap();
            },
            async {
                for _ in 0..2 {
                    let stanza = read.read::<Stanza>().await.unwrap();
                    info!("ping reply: {:#?}", stanza);
                }
            }
        );
    }
}