aboutsummaryrefslogblamecommitdiffstats
path: root/jabber/src/client.rs
blob: f5d5dc76287d85650818c6496b746b2154e8347a (plain) (tree)
1
2
3
4
5
6
7
8
9






                        
 
                                                           
                    
                              




                                
                       


                                   
                                                   




                                                 
                                




                              





















                                                             

                                                                                                                               








                                                                                      
 
                                                                 






                                                                                      


















































                                                                                         
     

 
                          
                 
                           









































































                                                                                         

























                                                                                             
                                                                               























                                                                                                    
                                                                               


















                                                                                                   


                                                                             








                                                  








                                                               

 





                                                                                                   
             



         
                             








                                           

            
                                         

                            







                                      
                       

                                          






                                                                              







                                                                              
                                                   


                     
                     









                                                                     
                     











                                                                     
                                                            




                                           
 
use std::{
    borrow::Borrow,
    future::Future,
    pin::pin,
    sync::Arc,
    task::{ready, Poll},
};

use futures::{FutureExt, Sink, SinkExt, Stream, StreamExt};
use jid::ParseError;
use rsasl::config::SASLConfig;
use stanza::{
    client::Stanza,
    sasl::Mechanisms,
    stream::{Feature, Features},
};
use tokio::sync::Mutex;

use crate::{
    connection::{Tls, Unencrypted},
    jabber_stream::bound_stream::BoundJabberStream,
    Connection, Error, JabberStream, Result, JID,
};

/// A Jabber/XMPP client: feed it client stanzas, receive client stanzas.
///
/// The `Sink<Stanza>` and `Stream` implementations in this file forward to the
/// `connection` field, so the client can only send/receive once that field
/// holds `ConnectionState::Connected`.
pub struct JabberClient {
    /// Current connection state; starts as `ConnectionState::Disconnected`.
    connection: ConnectionState,
    /// The JID this client was created with (handed mutably to the connect routine).
    jid: JID,
    /// SASL configuration built from the JID's localpart and the password.
    password: Arc<SASLConfig>,
    /// Server to connect to; initialised from the JID's domainpart.
    server: String,
}

impl JabberClient {
    pub fn new(
        jid: impl TryInto<JID, Error = ParseError>,
        password: impl ToString,
    ) -> Result<JabberClient> {
        let jid = jid.try_into()?;
        let sasl_config = SASLConfig::with_credentials(
            None,
            jid.localpart.clone().ok_or(Error::NoLocalpart)?,
            password.to_string(),
        )?;
        Ok(JabberClient {
            connection: ConnectionState::Disconnected,
            jid: jid.clone(),
            password: sasl_config,
            server: jid.domainpart,
        })
    }

    pub async fn connect(&mut self) -> Result<()> {
        match &self.connection {
            ConnectionState::Disconnected => {
                // TODO: actually set the self.connection as it is connecting, make more asynchronous (mutex while connecting?)
                // perhaps use take_mut?
                self.connection = ConnectionState::Disconnected
                    .connect(&mut self.jid, self.password.clone(), &mut self.server)
                    .await?;
                Ok(())
            }
            ConnectionState::Connecting(_connecting) => Err(Error::AlreadyConnecting),
            ConnectionState::Connected(_jabber_stream) => Ok(()),
        }
    }

    pub(crate) fn inner(self) -> Result<BoundJabberStream<Tls>> {
        match self.connection {
            ConnectionState::Disconnected => return Err(Error::Disconnected),
            ConnectionState::Connecting(_connecting) => return Err(Error::Connecting),
            ConnectionState::Connected(jabber_stream) => return Ok(jabber_stream),
        }
    }

    // pub async fn send_stanza(&mut self, stanza: &Stanza) -> Result<()> {
    //     match &mut self.connection {
    //         ConnectionState::Disconnected => return Err(Error::Disconnected),
    //         ConnectionState::Connecting(_connecting) => return Err(Error::Connecting),
    //         ConnectionState::Connected(jabber_stream) => {
    //             Ok(jabber_stream.send_stanza(stanza).await?)
    //         }
    //     }
    // }
}

/// Sending side of the client: every operation is forwarded to the inner
/// `ConnectionState`, which decides what to do based on whether the client is
/// disconnected, connecting, or connected.
impl Sink<Stanza> for JabberClient {
    type Error = Error;

    /// Polls whether the connection is ready to accept another stanza.
    fn poll_ready(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<std::result::Result<(), Self::Error>> {
        self.get_mut().connection.poll_ready_unpin(cx)
    }

    /// Queues `item` for sending on the connection.
    fn start_send(
        self: std::pin::Pin<&mut Self>,
        item: Stanza,
    ) -> std::result::Result<(), Self::Error> {
        self.get_mut().connection.start_send_unpin(item)
    }

    /// Flushes any stanzas queued on the connection.
    fn poll_flush(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<std::result::Result<(), Self::Error>> {
        self.get_mut().connection.poll_flush_unpin(cx)
    }

    /// Closes the sending side of the connection.
    ///
    /// This must forward to the connection's `poll_close`; forwarding to
    /// `poll_flush` would report the sink as closed without ever closing the
    /// underlying stream.
    fn poll_close(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<std::result::Result<(), Self::Error>> {
        self.get_mut().connection.poll_close_unpin(cx)
    }
}

/// Receiving side of the client: incoming stanzas are pulled from the inner
/// `ConnectionState`.
impl Stream for JabberClient {
    type Item = Result<Stanza>;

    /// Polls the inner connection state for the next incoming stanza.
    fn poll_next(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        let client = self.get_mut();
        std::pin::Pin::new(&mut client.connection).poll_next(cx)
    }
}

/// Lifecycle of a client connection.
pub enum ConnectionState {
    /// No connection exists; sink and stream operations report `Error::Disconnected`.
    Disconnected,
    /// A connection attempt is part-way through; the payload records which step was reached.
    Connecting(Connecting),
    /// Fully negotiated, resource-bound stream over TLS; stanzas can be sent and received.
    Connected(BoundJabberStream<Tls>),
}

impl Sink<Stanza> for ConnectionState {
    type Error = Error;

    fn poll_ready(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<std::result::Result<(), Self::Error>> {
        match self.get_mut() {
            ConnectionState::Disconnected => Poll::Ready(Err(Error::Disconnected)),
            ConnectionState::Connecting(_connecting) => Poll::Pending,
            ConnectionState::Connected(bound_jabber_stream) => {
                bound_jabber_stream.poll_ready_unpin(cx)
            }
        }
    }

    fn start_send(
        self: std::pin::Pin<&mut Self>,
        item: Stanza,
    ) -> std::result::Result<(), Self::Error> {
        match self.get_mut() {
            ConnectionState::Disconnected => Err(Error::Disconnected),
            ConnectionState::Connecting(_connecting) => Err(Error::Connecting),
            ConnectionState::Connected(bound_jabber_stream) => {
                bound_jabber_stream.start_send_unpin(item)
            }
        }
    }

    fn poll_flush(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<std::result::Result<(), Self::Error>> {
        match self.get_mut() {
            ConnectionState::Disconnected => Poll::Ready(Err(Error::Disconnected)),
            ConnectionState::Connecting(_connecting) => Poll::Pending,
            ConnectionState::Connected(bound_jabber_stream) => {
                bound_jabber_stream.poll_flush_unpin(cx)
            }
        }
    }

    fn poll_close(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<std::result::Result<(), Self::Error>> {
        match self.get_mut() {
            ConnectionState::Disconnected => Poll::Ready(Err(Error::Disconnected)),
            ConnectionState::Connecting(_connecting) => Poll::Pending,
            ConnectionState::Connected(bound_jabber_stream) => {
                bound_jabber_stream.poll_close_unpin(cx)
            }
        }
    }
}

/// Receiving side of the connection state.
///
/// NOTE(review): the `Connecting` arm returns `Poll::Pending` without
/// registering the task's waker — confirm this is intended.
impl Stream for ConnectionState {
    type Item = Result<Stanza>;

    /// Yields the next stanza from the bound stream.
    ///
    /// While disconnected this yields `Some(Err(Error::Disconnected))` on every
    /// poll rather than ending the stream.
    fn poll_next(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        match self.get_mut() {
            ConnectionState::Connected(stream) => std::pin::Pin::new(stream).poll_next(cx),
            ConnectionState::Connecting(_) => Poll::Pending,
            ConnectionState::Disconnected => Poll::Ready(Some(Err(Error::Disconnected))),
        }
    }
}

impl ConnectionState {
    /// Drives the connection state machine until it reaches `Connected`.
    ///
    /// The state is taken by value; each loop iteration consumes the current
    /// state, performs one negotiation step and stores the resulting state back
    /// into `self`. The loop ends when the state is no longer `Disconnected` or
    /// `Connecting`, at which point that state is returned.
    ///
    /// * `jid` - handed mutably to the bind step (presumably so the bound JID
    ///   can be written back — TODO confirm against `bind`).
    /// * `auth` - SASL configuration used for the authentication step.
    /// * `server` - server name used to open the connection, start streams and
    ///   perform STARTTLS.
    ///
    /// # Errors
    ///
    /// Any error from a step is propagated with `?`, which drops the
    /// in-progress state. Additionally returns `Error::Negotiation` when no
    /// feature can be negotiated, `Error::TlsRequired` when an unencrypted
    /// stream offers anything other than STARTTLS, `Error::AlreadyTls` when an
    /// encrypted stream offers STARTTLS again, and `Error::Unsupported` for an
    /// unknown feature.
    pub async fn connect(
        mut self,
        jid: &mut JID,
        auth: Arc<SASLConfig>,
        server: &mut String,
    ) -> Result<Self> {
        loop {
            match self {
                // Step 0: open the transport (encrypted or not, as decided by `Connecting::start`).
                ConnectionState::Disconnected => {
                    self = ConnectionState::Connecting(Connecting::start(&server).await?);
                }
                ConnectionState::Connecting(connecting) => match connecting {
                    // --- Unencrypted path: start a stream, read features, upgrade via STARTTLS ---
                    Connecting::InsecureConnectionEstablised(tcp_stream) => {
                        self = ConnectionState::Connecting(Connecting::InsecureStreamStarted(
                            JabberStream::start_stream(tcp_stream, server).await?,
                        ))
                    }
                    Connecting::InsecureStreamStarted(jabber_stream) => {
                        self = ConnectionState::Connecting(Connecting::InsecureGotFeatures(
                            jabber_stream.get_features().await?,
                        ))
                    }
                    Connecting::InsecureGotFeatures((features, jabber_stream)) => {
                        match features.negotiate().ok_or(Error::Negotiation)? {
                            Feature::StartTls(_start_tls) => {
                                self =
                                    ConnectionState::Connecting(Connecting::StartTls(jabber_stream))
                            }
                            // Anything other than STARTTLS on an unencrypted stream is refused.
                            // TODO: better error
                            _ => return Err(Error::TlsRequired),
                        }
                    }
                    // The TLS upgrade yields a fresh encrypted transport, which joins the
                    // encrypted path below at `ConnectionEstablished`.
                    Connecting::StartTls(jabber_stream) => {
                        self = ConnectionState::Connecting(Connecting::ConnectionEstablished(
                            jabber_stream.starttls(&server).await?,
                        ))
                    }
                    // --- Encrypted path: start a stream, read features, then SASL or bind ---
                    Connecting::ConnectionEstablished(tls_stream) => {
                        self = ConnectionState::Connecting(Connecting::StreamStarted(
                            JabberStream::start_stream(tls_stream, server).await?,
                        ))
                    }
                    Connecting::StreamStarted(jabber_stream) => {
                        self = ConnectionState::Connecting(Connecting::GotFeatures(
                            jabber_stream.get_features().await?,
                        ))
                    }
                    Connecting::GotFeatures((features, jabber_stream)) => {
                        match features.negotiate().ok_or(Error::Negotiation)? {
                            // Already on TLS, so a second STARTTLS offer is an error.
                            Feature::StartTls(_start_tls) => return Err(Error::AlreadyTls),
                            Feature::Sasl(mechanisms) => {
                                self = ConnectionState::Connecting(Connecting::Sasl(
                                    mechanisms,
                                    jabber_stream,
                                ))
                            }
                            Feature::Bind => {
                                self = ConnectionState::Connecting(Connecting::Bind(jabber_stream))
                            }
                            Feature::Unknown => return Err(Error::Unsupported),
                        }
                    }
                    // After SASL the state goes back to `ConnectionEstablished`, so the
                    // stream is started again and features are re-read before binding.
                    Connecting::Sasl(mechanisms, jabber_stream) => {
                        self = ConnectionState::Connecting(Connecting::ConnectionEstablished(
                            jabber_stream.sasl(mechanisms, auth.clone()).await?,
                        ))
                    }
                    // Final step: bind and convert into the bound stream.
                    Connecting::Bind(jabber_stream) => {
                        self = ConnectionState::Connected(
                            jabber_stream.bind(jid).await?.to_bound_jabber(),
                        )
                    }
                },
                // Only `Connected` reaches this arm; hand the finished state back.
                connected => return Ok(connected),
            }
        }
    }
}

/// Intermediate steps of a connection attempt, in the order
/// `ConnectionState::connect` walks through them.
///
/// The `Insecure*` and `StartTls` variants belong to the unencrypted path that
/// ends in a TLS upgrade; the remaining variants operate on a TLS transport.
pub enum Connecting {
    /// Unencrypted transport opened; no stream started yet.
    /// (The name contains a typo, "Establised", which is part of the public API.)
    InsecureConnectionEstablised(Unencrypted),
    /// Stream started over the unencrypted transport.
    InsecureStreamStarted(JabberStream<Unencrypted>),
    /// Stream features received over the unencrypted stream.
    InsecureGotFeatures((Features, JabberStream<Unencrypted>)),
    /// STARTTLS was negotiated; the stream is about to be upgraded.
    StartTls(JabberStream<Unencrypted>),
    /// TLS transport available; a stream is started on it next.
    /// Also re-entered after SASL completes.
    ConnectionEstablished(Tls),
    /// Stream started over TLS.
    StreamStarted(JabberStream<Tls>),
    /// Stream features received over the TLS stream.
    GotFeatures((Features, JabberStream<Tls>)),
    /// SASL was negotiated with the offered mechanisms; authentication is next.
    Sasl(Mechanisms, JabberStream<Tls>),
    /// Resource binding was negotiated; binding is the last step before `Connected`.
    Bind(JabberStream<Tls>),
}

impl Connecting {
    /// Opens a transport to `server` and returns the matching first step:
    /// `ConnectionEstablished` for an encrypted transport,
    /// `InsecureConnectionEstablised` for an unencrypted one.
    ///
    /// # Errors
    ///
    /// Propagates any error from `Connection::connect`.
    pub async fn start(server: &str) -> Result<Self> {
        let transport = Connection::connect(server).await?;
        let first_step = match transport {
            Connection::Unencrypted(tcp_stream) => {
                Connecting::InsecureConnectionEstablised(tcp_stream)
            }
            Connection::Encrypted(tls_stream) => Connecting::ConnectionEstablished(tls_stream),
        };
        Ok(first_step)
    }
}

/// Alternative connection-step enumeration.
///
/// NOTE(review): nothing in the visible part of this file constructs or matches
/// on this type; it looks like a leftover from an earlier design — confirm
/// before relying on it or removing it.
pub enum InsecureConnecting {
    /// No connection.
    Disconnected,
    /// A transport (encrypted or not) has been opened.
    ConnectionEstablished(Connection),
    /// Stream over an unencrypted transport, presumably before STARTTLS.
    PreStarttls(JabberStream<Unencrypted>),
    /// Stream over TLS, presumably before authentication.
    PreAuthenticated(JabberStream<Tls>),
    /// TLS transport, presumably after authentication.
    Authenticated(Tls),
    /// Stream over TLS, presumably before resource binding.
    PreBound(JabberStream<Tls>),
    /// Stream over TLS, presumably after resource binding.
    Bound(JabberStream<Tls>),
}

// Integration-style tests: they call `JabberClient::connect` with the account
// `test@blos.sm`, so they presumably need network access to that server.
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::JabberClient;
    use futures::{SinkExt, StreamExt};
    use stanza::{
        client::{
            iq::{Iq, IqType, Query},
            Stanza,
        },
        xep_0199::Ping,
    };
    use test_log::test;
    use tokio::time::sleep;
    use tracing::info;

    /// Connects and keeps the session alive for a few seconds.
    #[test(tokio::test)]
    async fn login() {
        let mut client = JabberClient::new("test@blos.sm", "slayed").unwrap();
        client.connect().await.unwrap();
        sleep(Duration::from_secs(5)).await
    }

    /// Sends two pings on the write half while logging everything that arrives
    /// on the read half.
    ///
    /// NOTE(review): the read loop only ends when the stream yields `None`, so
    /// this test presumably runs until the connection is closed.
    #[test(tokio::test)]
    async fn ping_parallel() {
        let mut client = JabberClient::new("test@blos.sm", "slayed").unwrap();
        client.connect().await.unwrap();
        sleep(Duration::from_secs(5)).await;
        let jid = client.jid.clone();
        let server = client.server.clone();
        let (mut write, mut read) = client.split();

        // Builds a client-to-server ping iq with the given stanza id.
        let ping = |id: &str| {
            Stanza::Iq(Iq {
                from: Some(jid.clone()),
                id: id.to_string(),
                to: Some(server.clone().try_into().unwrap()),
                r#type: IqType::Get,
                lang: None,
                query: Some(Query::Ping(Ping)),
                errors: Vec::new(),
            })
        };

        tokio::join!(
            async {
                // A failed send must fail the test instead of being silently dropped.
                write
                    .send(ping("c2s1"))
                    .await
                    .expect("sending first ping failed");
                write
                    .send(ping("c2s2"))
                    .await
                    .expect("sending second ping failed");
            },
            async {
                while let Some(stanza) = read.next().await {
                    info!("{:#?}", stanza);
                }
            }
        );
    }
}