use std::{
borrow::Borrow,
future::Future,
pin::pin,
sync::Arc,
task::{ready, Poll},
};
use futures::{FutureExt, Sink, SinkExt, Stream, StreamExt};
use jid::ParseError;
use rsasl::config::SASLConfig;
use stanza::{
client::Stanza,
sasl::Mechanisms,
stream::{Feature, Features},
};
use tokio::sync::Mutex;
use crate::{
connection::{Tls, Unencrypted},
jabber_stream::bound_stream::BoundJabberStream,
Connection, Error, JabberStream, Result, JID,
};
/// XMPP client connection: send client [`Stanza`]s through its [`Sink`] implementation and
/// receive client stanzas through its [`Stream`] implementation.
///
/// Created disconnected by [`JabberClient::new`]; call [`JabberClient::connect`] before
/// sending or receiving.
pub struct JabberClient {
    /// Current lifecycle state of the connection (starts as `Disconnected`).
    connection: ConnectionState,
    /// The account JID this client authenticates and binds as.
    jid: JID,
    /// SASL configuration built from the JID's localpart and the password
    /// (despite the field name, this is not the raw password).
    password: Arc<SASLConfig>,
    /// Server to connect to; initialised from the JID's domainpart.
    server: String,
}
impl JabberClient {
pub fn new(
jid: impl TryInto<JID, Error = ParseError>,
password: impl ToString,
) -> Result<JabberClient> {
let jid = jid.try_into()?;
let sasl_config = SASLConfig::with_credentials(
None,
jid.localpart.clone().ok_or(Error::NoLocalpart)?,
password.to_string(),
)?;
Ok(JabberClient {
connection: ConnectionState::Disconnected,
jid: jid.clone(),
password: sasl_config,
server: jid.domainpart,
})
}
pub async fn connect(&mut self) -> Result<()> {
match &self.connection {
ConnectionState::Disconnected => {
// TODO: actually set the self.connection as it is connecting, make more asynchronous (mutex while connecting?)
// perhaps use take_mut?
self.connection = ConnectionState::Disconnected
.connect(&mut self.jid, self.password.clone(), &mut self.server)
.await?;
Ok(())
}
ConnectionState::Connecting(_connecting) => Err(Error::AlreadyConnecting),
ConnectionState::Connected(_jabber_stream) => Ok(()),
}
}
pub(crate) fn inner(self) -> Result<BoundJabberStream<Tls>> {
match self.connection {
ConnectionState::Disconnected => return Err(Error::Disconnected),
ConnectionState::Connecting(_connecting) => return Err(Error::Connecting),
ConnectionState::Connected(jabber_stream) => return Ok(jabber_stream),
}
}
// pub async fn send_stanza(&mut self, stanza: &Stanza) -> Result<()> {
// match &mut self.connection {
// ConnectionState::Disconnected => return Err(Error::Disconnected),
// ConnectionState::Connecting(_connecting) => return Err(Error::Connecting),
// ConnectionState::Connected(jabber_stream) => {
// Ok(jabber_stream.send_stanza(stanza).await?)
// }
// }
// }
}
/// Sending side of the client: every operation forwards to the underlying
/// [`ConnectionState`].
impl Sink<Stanza> for JabberClient {
    type Error = Error;

    /// Checks whether the underlying connection can accept another stanza.
    fn poll_ready(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<std::result::Result<(), Self::Error>> {
        self.get_mut().connection.poll_ready_unpin(cx)
    }

    /// Queues `item` for sending on the underlying connection.
    fn start_send(
        self: std::pin::Pin<&mut Self>,
        item: Stanza,
    ) -> std::result::Result<(), Self::Error> {
        self.get_mut().connection.start_send_unpin(item)
    }

    /// Flushes any stanzas buffered by the underlying connection.
    fn poll_flush(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<std::result::Result<(), Self::Error>> {
        self.get_mut().connection.poll_flush_unpin(cx)
    }

    /// Closes the underlying connection's sink.
    ///
    /// This must forward to `poll_close`, not `poll_flush`: flushing alone would leave the
    /// connection open even though the caller asked for it to be closed.
    fn poll_close(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<std::result::Result<(), Self::Error>> {
        self.get_mut().connection.poll_close_unpin(cx)
    }
}
/// Receiving side of the client: yields stanzas from the underlying [`ConnectionState`].
impl Stream for JabberClient {
    type Item = Result<Stanza>;

    /// Polls the wrapped connection for the next incoming stanza.
    fn poll_next(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        // `ConnectionState` is `Unpin`, so it can be re-pinned in place without any
        // structural-pinning concerns.
        let this = self.get_mut();
        let connection = std::pin::Pin::new(&mut this.connection);
        connection.poll_next(cx)
    }
}
/// Lifecycle of a client connection.
pub enum ConnectionState {
    /// No connection exists.
    Disconnected,
    /// Partway through the connection handshake; the inner value records the step reached.
    Connecting(Connecting),
    /// Handshake finished: a TLS stream on which resource binding has completed.
    Connected(BoundJabberStream<Tls>),
}
impl Sink<Stanza> for ConnectionState {
type Error = Error;
fn poll_ready(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<std::result::Result<(), Self::Error>> {
match self.get_mut() {
ConnectionState::Disconnected => Poll::Ready(Err(Error::Disconnected)),
ConnectionState::Connecting(_connecting) => Poll::Pending,
ConnectionState::Connected(bound_jabber_stream) => {
bound_jabber_stream.poll_ready_unpin(cx)
}
}
}
fn start_send(
self: std::pin::Pin<&mut Self>,
item: Stanza,
) -> std::result::Result<(), Self::Error> {
match self.get_mut() {
ConnectionState::Disconnected => Err(Error::Disconnected),
ConnectionState::Connecting(_connecting) => Err(Error::Connecting),
ConnectionState::Connected(bound_jabber_stream) => {
bound_jabber_stream.start_send_unpin(item)
}
}
}
fn poll_flush(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<std::result::Result<(), Self::Error>> {
match self.get_mut() {
ConnectionState::Disconnected => Poll::Ready(Err(Error::Disconnected)),
ConnectionState::Connecting(_connecting) => Poll::Pending,
ConnectionState::Connected(bound_jabber_stream) => {
bound_jabber_stream.poll_flush_unpin(cx)
}
}
}
fn poll_close(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<std::result::Result<(), Self::Error>> {
match self.get_mut() {
ConnectionState::Disconnected => Poll::Ready(Err(Error::Disconnected)),
ConnectionState::Connecting(_connecting) => Poll::Pending,
ConnectionState::Connected(bound_jabber_stream) => {
bound_jabber_stream.poll_close_unpin(cx)
}
}
}
}
impl Stream for ConnectionState {
type Item = Result<Stanza>;
fn poll_next(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Option<Self::Item>> {
match self.get_mut() {
ConnectionState::Disconnected => Poll::Ready(Some(Err(Error::Disconnected))),
ConnectionState::Connecting(_connecting) => Poll::Pending,
ConnectionState::Connected(bound_jabber_stream) => {
bound_jabber_stream.poll_next_unpin(cx)
}
}
}
}
impl ConnectionState {
    /// Drives the connection handshake from the current state until it reaches
    /// [`ConnectionState::Connected`], and returns that final state.
    ///
    /// Each loop iteration performs one step:
    /// - `Disconnected`: open a transport to `server` via [`Connecting::start`]. This may
    ///   yield either an unencrypted or an already-encrypted connection.
    /// - Unencrypted path: start a stream and read its features. The only feature accepted
    ///   is STARTTLS; anything else fails with [`Error::TlsRequired`]. STARTTLS then
    ///   produces a TLS connection (`ConnectionEstablished`).
    /// - TLS path: start a stream, read its features, then perform SASL or resource binding
    ///   as negotiated. A successful SASL step returns to `ConnectionEstablished`, so a
    ///   fresh stream is started on the authenticated connection.
    /// - `Bind`: bind `jid` and convert the stream into a [`BoundJabberStream`].
    ///
    /// `auth` is passed to the SASL step as an `Arc` clone. `jid` and `server` are handed
    /// to the underlying calls by mutable reference. NOTE(review): whether those calls
    /// modify them (e.g. binding filling in a resource) is not visible here; confirm.
    ///
    /// # Errors
    ///
    /// - Propagates any error from the transport, stream start, feature, STARTTLS, SASL
    ///   or bind steps.
    /// - Returns [`Error::Negotiation`] when no feature can be negotiated.
    /// - Returns [`Error::TlsRequired`] when an unencrypted stream offers something other
    ///   than STARTTLS.
    /// - Returns [`Error::AlreadyTls`] when an encrypted stream offers STARTTLS.
    /// - Returns [`Error::Unsupported`] for an unknown feature.
    pub async fn connect(
        mut self,
        jid: &mut JID,
        auth: Arc<SASLConfig>,
        server: &mut String,
    ) -> Result<Self> {
        // Each iteration consumes `self` in the match and reassigns it with the next state.
        loop {
            match self {
                ConnectionState::Disconnected => {
                    self = ConnectionState::Connecting(Connecting::start(&server).await?);
                }
                ConnectionState::Connecting(connecting) => match connecting {
                    Connecting::InsecureConnectionEstablised(tcp_stream) => {
                        self = ConnectionState::Connecting(Connecting::InsecureStreamStarted(
                            JabberStream::start_stream(tcp_stream, server).await?,
                        ))
                    }
                    Connecting::InsecureStreamStarted(jabber_stream) => {
                        self = ConnectionState::Connecting(Connecting::InsecureGotFeatures(
                            jabber_stream.get_features().await?,
                        ))
                    }
                    Connecting::InsecureGotFeatures((features, jabber_stream)) => {
                        // On an unencrypted stream, only STARTTLS is acceptable.
                        match features.negotiate().ok_or(Error::Negotiation)? {
                            Feature::StartTls(_start_tls) => {
                                self =
                                    ConnectionState::Connecting(Connecting::StartTls(jabber_stream))
                            }
                            // TODO: better error
                            _ => return Err(Error::TlsRequired),
                        }
                    }
                    Connecting::StartTls(jabber_stream) => {
                        self = ConnectionState::Connecting(Connecting::ConnectionEstablished(
                            jabber_stream.starttls(&server).await?,
                        ))
                    }
                    Connecting::ConnectionEstablished(tls_stream) => {
                        self = ConnectionState::Connecting(Connecting::StreamStarted(
                            JabberStream::start_stream(tls_stream, server).await?,
                        ))
                    }
                    Connecting::StreamStarted(jabber_stream) => {
                        self = ConnectionState::Connecting(Connecting::GotFeatures(
                            jabber_stream.get_features().await?,
                        ))
                    }
                    Connecting::GotFeatures((features, jabber_stream)) => {
                        match features.negotiate().ok_or(Error::Negotiation)? {
                            // The transport is already TLS, so a second STARTTLS offer is
                            // rejected.
                            Feature::StartTls(_start_tls) => return Err(Error::AlreadyTls),
                            Feature::Sasl(mechanisms) => {
                                self = ConnectionState::Connecting(Connecting::Sasl(
                                    mechanisms,
                                    jabber_stream,
                                ))
                            }
                            Feature::Bind => {
                                self = ConnectionState::Connecting(Connecting::Bind(jabber_stream))
                            }
                            Feature::Unknown => return Err(Error::Unsupported),
                        }
                    }
                    Connecting::Sasl(mechanisms, jabber_stream) => {
                        // Authentication yields the underlying TLS connection, and the state
                        // returns to `ConnectionEstablished` so that a new stream is started
                        // on it.
                        self = ConnectionState::Connecting(Connecting::ConnectionEstablished(
                            jabber_stream.sasl(mechanisms, auth.clone()).await?,
                        ))
                    }
                    Connecting::Bind(jabber_stream) => {
                        self = ConnectionState::Connected(
                            jabber_stream.bind(jid).await?.to_bound_jabber(),
                        )
                    }
                },
                // Only `Connected` reaches this arm: the handshake is complete.
                connected => return Ok(connected),
            }
        }
    }
}
/// Intermediate steps of [`ConnectionState::connect`].
///
/// Each variant holds what the previous step produced. `Insecure*` variants belong to the
/// unencrypted path that precedes STARTTLS.
pub enum Connecting {
    /// An unencrypted transport is open; no XML stream has been started yet.
    /// (Variant name spelling is kept as-is; it is part of the public API.)
    InsecureConnectionEstablised(Unencrypted),
    /// A stream has been started over the unencrypted transport.
    InsecureStreamStarted(JabberStream<Unencrypted>),
    /// Features advertised on the unencrypted stream, together with that stream.
    InsecureGotFeatures((Features, JabberStream<Unencrypted>)),
    /// STARTTLS was negotiated; the unencrypted stream is about to be upgraded.
    StartTls(JabberStream<Unencrypted>),
    /// A TLS connection is available and a (new) stream must be started on it. Reached
    /// from a directly encrypted connect, after STARTTLS, or after SASL.
    ConnectionEstablished(Tls),
    /// A stream has been started over TLS.
    StreamStarted(JabberStream<Tls>),
    /// Features advertised on the TLS stream, together with that stream.
    GotFeatures((Features, JabberStream<Tls>)),
    /// SASL was negotiated with the offered mechanisms.
    Sasl(Mechanisms, JabberStream<Tls>),
    /// Resource binding was negotiated.
    Bind(JabberStream<Tls>),
}
impl Connecting {
    /// Opens a transport connection to `server` and returns the first handshake step.
    ///
    /// An unencrypted transport maps to `InsecureConnectionEstablised`, which must still go
    /// through STARTTLS. An already-encrypted transport maps directly to
    /// `ConnectionEstablished`.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`Connection::connect`].
    pub async fn start(server: &str) -> Result<Self> {
        let connection = Connection::connect(server).await?;
        let first_step = match connection {
            Connection::Unencrypted(plain) => Connecting::InsecureConnectionEstablised(plain),
            Connection::Encrypted(secured) => Connecting::ConnectionEstablished(secured),
        };
        Ok(first_step)
    }
}
/// Alternative description of the connection stages.
///
/// NOTE(review): this type is not referenced by any code in this module; the handshake
/// uses [`ConnectionState`] and [`Connecting`] instead. Confirm whether it is still used
/// elsewhere, and remove it if not.
pub enum InsecureConnecting {
    /// No connection.
    Disconnected,
    /// A transport connection (encrypted or not) is open.
    ConnectionEstablished(Connection),
    /// Unencrypted stream awaiting STARTTLS.
    PreStarttls(JabberStream<Unencrypted>),
    /// TLS stream awaiting authentication.
    PreAuthenticated(JabberStream<Tls>),
    /// Authenticated TLS connection.
    Authenticated(Tls),
    /// TLS stream awaiting resource binding.
    PreBound(JabberStream<Tls>),
    /// TLS stream with a bound resource.
    Bound(JabberStream<Tls>),
}
#[cfg(test)]
mod tests {
    use std::{sync::Arc, time::Duration};

    use super::JabberClient;
    use futures::{SinkExt, StreamExt};
    use stanza::{
        client::{
            iq::{Iq, IqType, Query},
            Stanza,
        },
        xep_0199::Ping,
    };
    use test_log::test;
    use tokio::{sync::Mutex, time::sleep};
    use tracing::info;

    // These tests talk to a live server (blos.sm) and need network access.

    #[test(tokio::test)]
    async fn login() {
        let mut client = JabberClient::new("test@blos.sm", "slayed").unwrap();
        client.connect().await.unwrap();
        sleep(Duration::from_secs(5)).await
    }

    /// Sends two pings from separate futures while a third logs every incoming stanza.
    ///
    /// The receiving loop runs until the server closes the stream.
    #[test(tokio::test)]
    async fn ping_parallel() {
        let mut client = JabberClient::new("test@blos.sm", "slayed").unwrap();
        client.connect().await.unwrap();
        sleep(Duration::from_secs(5)).await;
        let jid = client.jid.clone();
        let server = client.server.clone();

        // Builds an XEP-0199 ping from our JID to the server with the given stanza id.
        let ping = |id: &str| {
            Stanza::Iq(Iq {
                from: Some(jid.clone()),
                id: id.to_string(),
                to: Some(server.clone().try_into().unwrap()),
                r#type: IqType::Get,
                lang: None,
                query: Some(Query::Ping(Ping)),
                errors: Vec::new(),
            })
        };

        let client = Arc::new(Mutex::new(client));
        tokio::join!(
            async {
                let mut client = client.lock().await;
                // Surface send failures instead of silently discarding the `Result`.
                client.send(ping("c2s1")).await.unwrap();
            },
            async {
                let mut client = client.lock().await;
                client.send(ping("c2s2")).await.unwrap();
            },
            async {
                // The lock guard lives for the whole `while let` iteration, so the client stays
                // locked while waiting for the next stanza.
                while let Some(stanza) = client.lock().await.next().await {
                    info!("{:#?}", stanza);
                }
            }
        );
    }
}