aboutsummaryrefslogblamecommitdiffstats
path: root/src/client.rs
blob: 5351b34d69a95787a643798fde4441b4007aa3ba (plain) (tree)
1
2
3
4
5
6
7
8
                                           
 
                                       



                                   
                    









                                                 
                                




                              

















































                                                                                           
                 


















































































                                                                                                    








                                                               

 





                                                                                                   
             
































                                                                     
                             








                                           






























                                                                















                                                                              
use std::{pin::pin, sync::Arc, task::Poll};

use futures::{Sink, Stream, StreamExt};
use rsasl::config::SASLConfig;

use crate::{
    connection::{Tls, Unencrypted},
    jid::ParseError,
    stanza::{
        client::Stanza,
        sasl::Mechanisms,
        stream::{Feature, Features},
    },
    Connection, Error, JabberStream, Result, JID,
};

/// Jabber client handle: feed it client stanzas, receive client stanzas.
///
/// Created in the disconnected state by `JabberClient::new`; `connect` must
/// succeed before any stanza can be received through the `Stream` impl.
pub struct JabberClient {
    /// Current connection state; `new` initialises it to `ConnectionState::Disconnected`.
    connection: ConnectionState,
    /// JID the client was created with; passed mutably to the connection setup.
    jid: JID,
    /// SASL configuration built from the JID localpart and the password given
    /// to `new` (despite the field name, this is not the raw password).
    password: Arc<SASLConfig>,
    /// Server name, initialised from the JID's domainpart.
    server: String,
}

impl JabberClient {
    pub fn new(
        jid: impl TryInto<JID, Error = ParseError>,
        password: impl ToString,
    ) -> Result<JabberClient> {
        let jid = jid.try_into()?;
        let sasl_config = SASLConfig::with_credentials(
            None,
            jid.localpart.clone().ok_or(Error::NoLocalpart)?,
            password.to_string(),
        )?;
        Ok(JabberClient {
            connection: ConnectionState::Disconnected,
            jid: jid.clone(),
            password: sasl_config,
            server: jid.domainpart,
        })
    }

    pub async fn connect(&mut self) -> Result<()> {
        match &self.connection {
            ConnectionState::Disconnected => {
                self.connection = ConnectionState::Disconnected
                    .connect(&mut self.jid, self.password.clone(), &mut self.server)
                    .await?;
                Ok(())
            }
            ConnectionState::Connecting(_connecting) => Err(Error::AlreadyConnecting),
            ConnectionState::Connected(_jabber_stream) => Ok(()),
        }
    }
}

impl Stream for JabberClient {
    type Item = Result<Stanza>;

    /// Polls the underlying stream for the next stanza once connected.
    ///
    /// While the client is disconnected or still connecting this returns
    /// `Poll::Pending`.
    ///
    /// NOTE(review): the not-connected branch returns `Poll::Pending` without
    /// registering `cx`'s waker, so a task polling a client that is not
    /// connected will not be woken again by this stream. Confirm whether this
    /// should end the stream (`Poll::Ready(None)`) or arrange a wake-up.
    fn poll_next(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        let mut this = pin!(self);
        match &mut this.connection {
            ConnectionState::Connected(stream) => stream.poll_next_unpin(cx),
            ConnectionState::Disconnected | ConnectionState::Connecting(_) => Poll::Pending,
        }
    }
}

/// Connection lifecycle of a `JabberClient`.
pub enum ConnectionState {
    /// No connection exists; `connect` starts from here.
    Disconnected,
    /// Connection setup is in progress, at the step held in the inner `Connecting`.
    Connecting(Connecting),
    /// Setup finished; holds the stream returned by the bind step.
    Connected(JabberStream<Tls>),
}

impl ConnectionState {
    pub async fn connect(
        mut self,
        jid: &mut JID,
        auth: Arc<SASLConfig>,
        server: &mut String,
    ) -> Result<Self> {
        loop {
            match self {
                ConnectionState::Disconnected => {
                    self = ConnectionState::Connecting(Connecting::start(&server).await?);
                }
                ConnectionState::Connecting(connecting) => match connecting {
                    Connecting::InsecureConnectionEstablised(tcp_stream) => {
                        self = ConnectionState::Connecting(Connecting::InsecureStreamStarted(
                            JabberStream::start_stream(tcp_stream, server).await?,
                        ))
                    }
                    Connecting::InsecureStreamStarted(jabber_stream) => {
                        self = ConnectionState::Connecting(Connecting::InsecureGotFeatures(
                            jabber_stream.get_features().await?,
                        ))
                    }
                    Connecting::InsecureGotFeatures((features, jabber_stream)) => {
                        match features.negotiate()? {
                            Feature::StartTls(_start_tls) => {
                                self =
                                    ConnectionState::Connecting(Connecting::StartTls(jabber_stream))
                            }
                            // TODO: better error
                            _ => return Err(Error::TlsRequired),
                        }
                    }
                    Connecting::StartTls(jabber_stream) => {
                        self = ConnectionState::Connecting(Connecting::ConnectionEstablished(
                            jabber_stream.starttls(&server).await?,
                        ))
                    }
                    Connecting::ConnectionEstablished(tls_stream) => {
                        self = ConnectionState::Connecting(Connecting::StreamStarted(
                            JabberStream::start_stream(tls_stream, server).await?,
                        ))
                    }
                    Connecting::StreamStarted(jabber_stream) => {
                        self = ConnectionState::Connecting(Connecting::GotFeatures(
                            jabber_stream.get_features().await?,
                        ))
                    }
                    Connecting::GotFeatures((features, jabber_stream)) => {
                        match features.negotiate()? {
                            Feature::StartTls(_start_tls) => return Err(Error::AlreadyTls),
                            Feature::Sasl(mechanisms) => {
                                self = ConnectionState::Connecting(Connecting::Sasl(
                                    mechanisms,
                                    jabber_stream,
                                ))
                            }
                            Feature::Bind => {
                                self = ConnectionState::Connecting(Connecting::Bind(jabber_stream))
                            }
                            Feature::Unknown => return Err(Error::Unsupported),
                        }
                    }
                    Connecting::Sasl(mechanisms, jabber_stream) => {
                        self = ConnectionState::Connecting(Connecting::ConnectionEstablished(
                            jabber_stream.sasl(mechanisms, auth.clone()).await?,
                        ))
                    }
                    Connecting::Bind(jabber_stream) => {
                        self = ConnectionState::Connected(jabber_stream.bind(jid).await?)
                    }
                },
                connected => return Ok(connected),
            }
        }
    }
}

/// Intermediate steps of the connection setup driven by `ConnectionState::connect`.
///
/// Variants prefixed with `Insecure` carry an unencrypted connection; the
/// remaining variants carry a TLS connection.
pub enum Connecting {
    /// An unencrypted connection was opened; next step: start the stream.
    InsecureConnectionEstablised(Unencrypted),
    /// The stream was started over the unencrypted connection; next step: get features.
    InsecureStreamStarted(JabberStream<Unencrypted>),
    /// Features were received on the unencrypted stream; next step: negotiate
    /// them (only StartTls is accepted here).
    InsecureGotFeatures((Features, JabberStream<Unencrypted>)),
    /// StartTls was negotiated; next step: upgrade the stream to TLS.
    StartTls(JabberStream<Unencrypted>),
    /// A TLS connection is available (opened directly, after StartTls, or
    /// after the SASL step); next step: start the stream.
    ConnectionEstablished(Tls),
    /// The stream was started over TLS; next step: get features.
    StreamStarted(JabberStream<Tls>),
    /// Features were received on the TLS stream; next step: negotiate them.
    GotFeatures((Features, JabberStream<Tls>)),
    /// SASL was negotiated with the given mechanisms; next step: authenticate.
    Sasl(Mechanisms, JabberStream<Tls>),
    /// Bind was negotiated; next step: bind, which completes the connection.
    Bind(JabberStream<Tls>),
}

impl Connecting {
    /// Opens a connection to `server` and returns the matching first step.
    ///
    /// An encrypted connection yields `ConnectionEstablished`; an unencrypted
    /// one yields `InsecureConnectionEstablised`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `Connection::connect`.
    pub async fn start(server: &str) -> Result<Self> {
        let connection = Connection::connect(server).await?;
        Ok(match connection {
            Connection::Unencrypted(tcp_stream) => Self::InsecureConnectionEstablised(tcp_stream),
            Connection::Encrypted(tls_stream) => Self::ConnectionEstablished(tls_stream),
        })
    }
}

impl Features {
    pub fn negotiate(self) -> Result<Feature> {
        if let Some(Feature::StartTls(s)) = self
            .features
            .iter()
            .find(|feature| matches!(feature, Feature::StartTls(_s)))
        {
            // TODO: avoid clone
            return Ok(Feature::StartTls(s.clone()));
        } else if let Some(Feature::Sasl(mechanisms)) = self
            .features
            .iter()
            .find(|feature| matches!(feature, Feature::Sasl(_)))
        {
            // TODO: avoid clone
            return Ok(Feature::Sasl(mechanisms.clone()));
        } else if let Some(Feature::Bind) = self
            .features
            .into_iter()
            .find(|feature| matches!(feature, Feature::Bind))
        {
            Ok(Feature::Bind)
        } else {
            // TODO: better error
            return Err(Error::Negotiation);
        }
    }
}

/// Alternative set of connection-setup states.
///
/// NOTE(review): nothing in the visible part of this file constructs or
/// matches on this enum; it looks like a leftover or work-in-progress
/// counterpart to `Connecting` — confirm whether it is still needed.
pub enum InsecureConnecting {
    /// No connection.
    Disconnected,
    /// Holds a `Connection` (encrypted or unencrypted).
    ConnectionEstablished(Connection),
    /// Holds a stream over an unencrypted connection.
    PreStarttls(JabberStream<Unencrypted>),
    /// Holds a stream over a TLS connection.
    PreAuthenticated(JabberStream<Tls>),
    /// Holds a bare TLS connection.
    Authenticated(Tls),
    /// Holds a stream over a TLS connection.
    PreBound(JabberStream<Tls>),
    /// Holds a stream over a TLS connection.
    Bound(JabberStream<Tls>),
}

impl Sink<Stanza> for JabberClient {
    type Error = Error;

    fn poll_ready(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<std::result::Result<(), Self::Error>> {
        todo!()
    }

    fn start_send(
        self: std::pin::Pin<&mut Self>,
        item: Stanza,
    ) -> std::result::Result<(), Self::Error> {
        todo!()
    }

    fn poll_flush(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<std::result::Result<(), Self::Error>> {
        todo!()
    }

    fn poll_close(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<std::result::Result<(), Self::Error>> {
        todo!()
    }
}

#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::JabberClient;
    use test_log::test;
    use tokio::time::sleep;

    // Creates a client for a hard-coded account, runs the full `connect`
    // sequence and then keeps the client alive for five more seconds.
    //
    // NOTE(review): this appears to be a live integration test — it presumably
    // needs network access and a working account on the server named in the
    // JID, and the credentials are hard-coded in the source. Confirm whether
    // it should be `#[ignore]`d by default or read its credentials from the
    // environment.
    #[test(tokio::test)]
    async fn login() {
        let mut client = JabberClient::new("test@blos.sm", "slayed").unwrap();
        client.connect().await.unwrap();
        sleep(Duration::from_secs(5)).await
    }
}