aboutsummaryrefslogblamecommitdiffstats
path: root/jabber/src/client.rs
blob: 9d3268250488f12412dd6b2ee4aa078ab5b8fefa (plain) (tree)
1
2
3
4
5
6
7
8
9






                        
 
                                                           
                    
                              




                                
                       


                                   
                                                                        




                                                 
                                               
             
                                                                                                     



                              











                                                             
                             





                                   



                              

                                                   





                                                                                             

                      

         
 

                                                                      

     










                                                                                         











                                                                          
             

                                                                                                 
             







                                                                           
             


                                                                                             
             


                                                                                                    
             







                                                                                         
                     










                                                                                  





                     








                                                               

 





                                                                                                   
             



         
                             








                                           

            
                                         

                            







                                      
                       

                                          






                                                                              







                                                                              
                                                                         


                     
                     
                                           







                                                                     

                              
                     
                                           







                                                                     

                              

                   


                                                                      



                 
 
use std::{
    borrow::Borrow,
    future::Future,
    pin::pin,
    sync::Arc,
    task::{ready, Poll},
};

use futures::{FutureExt, Sink, SinkExt, Stream, StreamExt};
use jid::ParseError;
use rsasl::config::SASLConfig;
use stanza::{
    client::Stanza,
    sasl::Mechanisms,
    stream::{Feature, Features},
};
use tokio::sync::Mutex;

use crate::{
    connection::{Tls, Unencrypted},
    jabber_stream::bound_stream::{BoundJabberReader, BoundJabberStream},
    Connection, Error, JabberStream, Result, JID,
};

/// A client for a single XMPP account.
///
/// Feed it client stanzas, receive client stanzas. The client starts out
/// disconnected; `connect` runs connection setup (plain TCP + STARTTLS, or
/// direct TLS), SASL authentication and resource binding.
pub struct JabberClient {
    /// The logged-in stream; `None` until `connect` succeeds.
    connection: Option<BoundJabberStream<Tls>>,
    /// Account JID. It is passed mutably to resource binding during login,
    /// presumably so the bound resource can be recorded (confirm in `bind`).
    jid: JID,
    // TODO: have reconnection be handled by another part, so creds don't need to be stored in object
    /// SASL configuration holding the account credentials, built in `new`
    /// from the JID's localpart and the supplied password.
    password: Arc<SASLConfig>,
    /// Server to connect to; initialised from the JID's domainpart and passed
    /// mutably through the login sequence.
    server: String,
}

impl JabberClient {
    pub fn new(
        jid: impl TryInto<JID, Error = ParseError>,
        password: impl ToString,
    ) -> Result<JabberClient> {
        let jid = jid.try_into()?;
        let sasl_config = SASLConfig::with_credentials(
            None,
            jid.localpart.clone().ok_or(Error::NoLocalpart)?,
            password.to_string(),
        )?;
        Ok(JabberClient {
            connection: None,
            jid: jid.clone(),
            password: sasl_config,
            server: jid.domainpart,
        })
    }

    pub fn jid(&self) -> JID {
        self.jid.clone()
    }

    pub async fn connect(&mut self) -> Result<()> {
        match &self.connection {
            Some(_) => Ok(()),
            None => {
                self.connection = Some(
                    connect_and_login(&mut self.jid, self.password.clone(), &mut self.server)
                        .await?,
                );
                Ok(())
            }
        }
    }

    pub(crate) fn into_inner(self) -> Result<BoundJabberStream<Tls>> {
        self.connection.ok_or(Error::Disconnected)
    }

    // pub async fn send_stanza(&mut self, stanza: &Stanza) -> Result<()> {
    //     match &mut self.connection {
    //         ConnectionState::Disconnected => return Err(Error::Disconnected),
    //         ConnectionState::Connecting(_connecting) => return Err(Error::Connecting),
    //         ConnectionState::Connected(jabber_stream) => {
    //             Ok(jabber_stream.send_stanza(stanza).await?)
    //         }
    //     }
    // }
}

pub async fn connect_and_login(
    jid: &mut JID,
    auth: Arc<SASLConfig>,
    server: &mut String,
) -> Result<BoundJabberStream<Tls>> {
    let mut conn_state = Connecting::start(&server).await?;
    loop {
        match conn_state {
            Connecting::InsecureConnectionEstablised(tcp_stream) => {
                conn_state = Connecting::InsecureStreamStarted(
                    JabberStream::start_stream(tcp_stream, server).await?,
                )
            }
            Connecting::InsecureStreamStarted(jabber_stream) => {
                conn_state = Connecting::InsecureGotFeatures(jabber_stream.get_features().await?)
            }
            Connecting::InsecureGotFeatures((features, jabber_stream)) => {
                match features.negotiate().ok_or(Error::Negotiation)? {
                    Feature::StartTls(_start_tls) => {
                        conn_state = Connecting::StartTls(jabber_stream)
                    }
                    // TODO: better error
                    _ => return Err(Error::TlsRequired),
                }
            }
            Connecting::StartTls(jabber_stream) => {
                conn_state =
                    Connecting::ConnectionEstablished(jabber_stream.starttls(&server).await?)
            }
            Connecting::ConnectionEstablished(tls_stream) => {
                conn_state =
                    Connecting::StreamStarted(JabberStream::start_stream(tls_stream, server).await?)
            }
            Connecting::StreamStarted(jabber_stream) => {
                conn_state = Connecting::GotFeatures(jabber_stream.get_features().await?)
            }
            Connecting::GotFeatures((features, jabber_stream)) => {
                match features.negotiate().ok_or(Error::Negotiation)? {
                    Feature::StartTls(_start_tls) => return Err(Error::AlreadyTls),
                    Feature::Sasl(mechanisms) => {
                        conn_state = Connecting::Sasl(mechanisms, jabber_stream)
                    }
                    Feature::Bind => conn_state = Connecting::Bind(jabber_stream),
                    Feature::Unknown => return Err(Error::Unsupported),
                }
            }
            Connecting::Sasl(mechanisms, jabber_stream) => {
                conn_state = Connecting::ConnectionEstablished(
                    jabber_stream.sasl(mechanisms, auth.clone()).await?,
                )
            }
            Connecting::Bind(jabber_stream) => {
                return Ok(jabber_stream.bind(jid).await?.to_bound_jabber());
            }
        }
    }
}

/// States of the login state machine driven by `connect_and_login`.
///
/// The `Insecure*` variants and `StartTls` describe a plain-TCP connection,
/// which is only allowed to proceed by upgrading with STARTTLS. The remaining
/// variants run over TLS.
pub enum Connecting {
    /// A plain TCP connection is open; no XMPP stream has been started yet.
    /// (Variant name is misspelled; kept as-is because it is public.)
    InsecureConnectionEstablised(Unencrypted),
    /// An XMPP stream is open over plain TCP; features not yet read.
    InsecureStreamStarted(JabberStream<Unencrypted>),
    /// Stream features received over plain TCP, paired with their stream.
    InsecureGotFeatures((Features, JabberStream<Unencrypted>)),
    /// STARTTLS was negotiated; the stream is ready to be upgraded to TLS.
    StartTls(JabberStream<Unencrypted>),
    /// A TLS connection with no open XMPP stream. Reached directly from
    /// connecting, after STARTTLS, or after SASL, which restarts the stream.
    ConnectionEstablished(Tls),
    /// An XMPP stream is open over TLS; features not yet read.
    StreamStarted(JabberStream<Tls>),
    /// Stream features received over TLS, paired with their stream.
    GotFeatures((Features, JabberStream<Tls>)),
    /// SASL was negotiated with these mechanisms; authentication pending.
    Sasl(Mechanisms, JabberStream<Tls>),
    /// Resource binding was negotiated; the final step of login.
    Bind(JabberStream<Tls>),
}

impl Connecting {
    /// Opens a transport connection to `server` and returns the matching
    /// initial state.
    ///
    /// An encrypted connection starts at `ConnectionEstablished`. An
    /// unencrypted one starts at `InsecureConnectionEstablised`.
    ///
    /// # Errors
    ///
    /// Any error from `Connection::connect`.
    pub async fn start(server: &str) -> Result<Self> {
        let connection = Connection::connect(server).await?;
        let initial = match connection {
            Connection::Encrypted(tls) => Self::ConnectionEstablished(tls),
            Connection::Unencrypted(tcp) => Self::InsecureConnectionEstablised(tcp),
        };
        Ok(initial)
    }
}

/// An alternative set of connection states.
///
/// NOTE(review): this enum is not referenced anywhere in this file; login is
/// driven by `Connecting` instead. Possibly a leftover from an earlier design.
/// Confirm no other module uses it before relying on it or removing it.
pub enum InsecureConnecting {
    /// No connection.
    Disconnected,
    /// A transport connection (encrypted or not) is open.
    ConnectionEstablished(Connection),
    /// A plain-TCP stream that has not yet been upgraded with STARTTLS.
    PreStarttls(JabberStream<Unencrypted>),
    /// A TLS stream that has not yet authenticated.
    PreAuthenticated(JabberStream<Tls>),
    /// An authenticated TLS connection with no open stream.
    Authenticated(Tls),
    /// A TLS stream that has not yet bound a resource.
    PreBound(JabberStream<Tls>),
    /// A TLS stream with a bound resource.
    Bound(JabberStream<Tls>),
}

/// Live integration tests.
///
/// NOTE: these connect to the real server for `test@blos.sm` with a fixed
/// password. They need network access and a reachable server, and will fail
/// otherwise.
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::JabberClient;
    use futures::{SinkExt, StreamExt};
    use stanza::{
        client::{
            iq::{Iq, IqType, Query},
            Stanza,
        },
        xep_0199::Ping,
    };
    use test_log::test;
    use tokio::time::sleep;
    use tracing::info;

    #[test(tokio::test)]
    async fn login() {
        let mut client = JabberClient::new("test@blos.sm", "slayed").unwrap();
        client.connect().await.unwrap();
        sleep(Duration::from_secs(5)).await
    }

    /// Sends two XEP-0199 pings while concurrently reading two replies.
    #[test(tokio::test)]
    async fn ping_parallel() {
        let mut client = JabberClient::new("test@blos.sm", "slayed").unwrap();
        client.connect().await.unwrap();
        sleep(Duration::from_secs(5)).await;
        let jid = client.jid.clone();
        let server = client.server.clone();
        let (mut read, mut write) = client.into_inner().unwrap().split();

        // Builds a ping IQ from our JID to the server with the given stanza id.
        let ping = |id: &str| {
            Stanza::Iq(Iq {
                from: Some(jid.clone()),
                id: id.to_string(),
                to: Some(server.clone().try_into().unwrap()),
                r#type: IqType::Get,
                lang: None,
                query: Some(Query::Ping(Ping)),
                errors: Vec::new(),
            })
        };

        tokio::join!(
            async {
                for id in ["c2s1", "c2s2"] {
                    write.write(&ping(id)).await.unwrap();
                }
            },
            async {
                for _ in 0..2 {
                    let stanza = read.read::<Stanza>().await.unwrap();
                    info!("ping reply: {:#?}", stanza);
                }
            }
        );
    }
}