aboutsummaryrefslogblamecommitdiffstats
path: root/src/client.rs
blob: 290834631c8f7227d935d9672441c09b96c32779 (plain) (tree)



















































































































































































                                                                                                
use std::sync::Arc;

use futures::{Sink, Stream};
use rsasl::config::SASLConfig;

use crate::{
    connection::{Tls, Unencrypted},
    stanza::{
        client::Stanza,
        sasl::Mechanisms,
        stream::{Feature, Features},
    },
    Connection, Error, JabberStream, Result, JID,
};

/// A client connection to a Jabber/XMPP server.
///
/// Intended usage: feed it client stanzas through its [`Sink`] implementation and
/// receive client stanzas through its [`Stream`] implementation.
pub struct JabberClient {
    /// Current position in the connection/negotiation state machine.
    connection: JabberState,
    /// The client's JID.
    jid: JID,
    /// SASL configuration used for authentication.
    // NOTE(review): despite its name this field holds the whole SASL configuration,
    // not just a password.
    password: Arc<SASLConfig>,
    /// The server to connect to.
    // presumably a hostname/domain — TODO confirm against `Connection::connect`.
    server: String,
}

/// Connection state machine for a Jabber client, advanced one step at a time by
/// [`JabberState::advance_state`].
///
/// The `Insecure*` variants and `StartTls` operate over an unencrypted connection that
/// must be upgraded via STARTTLS; every other connected variant operates over TLS.
pub enum JabberState {
    /// No connection has been made yet.
    Disconnected,
    /// A plain (unencrypted) connection is open, but no stream has been started on it.
    // NOTE: "Establised" is a typo, kept because renaming a public variant breaks callers.
    InsecureConnectionEstablised(Unencrypted),
    /// A stream has been started over the unencrypted connection.
    InsecureStreamStarted(JabberStream<Unencrypted>),
    /// The server's stream features have been read over the unencrypted connection.
    InsecureGotFeatures((Features, JabberStream<Unencrypted>)),
    /// STARTTLS was negotiated; the next step upgrades the stream to TLS.
    StartTls(JabberStream<Unencrypted>),
    /// A TLS connection is open but no stream has been started on it yet. It is reached
    /// on a direct TLS connect, after STARTTLS, and again after SASL authentication.
    ConnectionEstablished(Tls),
    /// A stream has been started over TLS.
    StreamStarted(JabberStream<Tls>),
    /// The server's stream features have been read over TLS.
    GotFeatures((Features, JabberStream<Tls>)),
    /// SASL was negotiated with the given mechanisms; the next step authenticates.
    Sasl(Mechanisms, JabberStream<Tls>),
    /// Bind was negotiated; the next step performs the bind.
    Bind(JabberStream<Tls>),
    /// Fully negotiated and bound: once bound, the client can stream stanzas and sink
    /// stanzas.
    Bound(JabberStream<Tls>),
}

impl JabberState {
    pub async fn advance_state(
        self,
        jid: &mut JID,
        auth: Arc<SASLConfig>,
        server: &mut String,
    ) -> Result<JabberState> {
        match self {
            JabberState::Disconnected => match Connection::connect(server).await? {
                Connection::Encrypted(tls_stream) => {
                    Ok(JabberState::ConnectionEstablished(tls_stream))
                }
                Connection::Unencrypted(tcp_stream) => {
                    Ok(JabberState::InsecureConnectionEstablised(tcp_stream))
                }
            },
            JabberState::InsecureConnectionEstablised(tcp_stream) => Ok({
                JabberState::InsecureStreamStarted(
                    JabberStream::start_stream(tcp_stream, server).await?,
                )
            }),
            JabberState::InsecureStreamStarted(jabber_stream) => Ok(
                JabberState::InsecureGotFeatures(jabber_stream.get_features().await?),
            ),
            JabberState::InsecureGotFeatures((features, jabber_stream)) => {
                match features.negotiate()? {
                    Feature::StartTls(_start_tls) => Ok(JabberState::StartTls(jabber_stream)),
                    // TODO: better error
                    _ => return Err(Error::TlsRequired),
                }
            }
            JabberState::StartTls(jabber_stream) => Ok(JabberState::ConnectionEstablished(
                jabber_stream.starttls(server).await?,
            )),
            JabberState::ConnectionEstablished(tls_stream) => Ok(JabberState::StreamStarted(
                JabberStream::start_stream(tls_stream, server).await?,
            )),
            JabberState::StreamStarted(jabber_stream) => Ok(JabberState::GotFeatures(
                jabber_stream.get_features().await?,
            )),
            JabberState::GotFeatures((features, jabber_stream)) => match features.negotiate()? {
                Feature::StartTls(_start_tls) => return Err(Error::AlreadyTls),
                Feature::Sasl(mechanisms) => {
                    return Ok(JabberState::Sasl(mechanisms, jabber_stream))
                }
                Feature::Bind => return Ok(JabberState::Bind(jabber_stream)),
                Feature::Unknown => return Err(Error::Unsupported),
            },
            JabberState::Sasl(mechanisms, jabber_stream) => {
                return Ok(JabberState::ConnectionEstablished(
                    jabber_stream.sasl(mechanisms, auth).await?,
                ))
            }
            JabberState::Bind(jabber_stream) => {
                Ok(JabberState::Bound(jabber_stream.bind(jid).await?))
            }
            JabberState::Bound(jabber_stream) => Ok(JabberState::Bound(jabber_stream)),
        }
    }
}

impl Features {
    pub fn negotiate(self) -> Result<Feature> {
        if let Some(Feature::StartTls(s)) = self
            .features
            .iter()
            .find(|feature| matches!(feature, Feature::StartTls(_s)))
        {
            // TODO: avoid clone
            return Ok(Feature::StartTls(s.clone()));
        } else if let Some(Feature::Sasl(mechanisms)) = self
            .features
            .iter()
            .find(|feature| matches!(feature, Feature::Sasl(_)))
        {
            // TODO: avoid clone
            return Ok(Feature::Sasl(mechanisms.clone()));
        } else if let Some(Feature::Bind) = self
            .features
            .into_iter()
            .find(|feature| matches!(feature, Feature::Bind))
        {
            Ok(Feature::Bind)
        } else {
            // TODO: better error
            return Err(Error::Negotiation);
        }
    }
}

/// Connection-lifecycle states, each carrying the connection or stream available at
/// that point.
///
/// NOTE(review): this enum is not referenced anywhere in this file and appears to
/// overlap with [`JabberState`]; confirm whether it is still needed.
pub enum InsecureJabberConnection {
    Disconnected,
    ConnectionEstablished(Connection),
    PreStarttls(JabberStream<Unencrypted>),
    PreAuthenticated(JabberStream<Tls>),
    Authenticated(Tls),
    PreBound(JabberStream<Tls>),
    Bound(JabberStream<Tls>),
}

/// Receiving side of [`JabberClient`]: intended to yield incoming client [`Stanza`]s.
///
/// NOTE(review): not implemented yet — polling this stream panics via `todo!()`.
impl Stream for JabberClient {
    type Item = Stanza;

    fn poll_next(
        self: std::pin::Pin<&mut Self>,
        _cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        todo!()
    }
}

/// Sending side of [`JabberClient`]: intended to accept outgoing client [`Stanza`]s.
///
/// NOTE(review): not implemented yet — every method panics via `todo!()`.
impl Sink<Stanza> for JabberClient {
    type Error = Error;

    fn poll_ready(
        self: std::pin::Pin<&mut Self>,
        _cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<std::result::Result<(), Self::Error>> {
        todo!()
    }

    fn start_send(
        self: std::pin::Pin<&mut Self>,
        _item: Stanza,
    ) -> std::result::Result<(), Self::Error> {
        todo!()
    }

    fn poll_flush(
        self: std::pin::Pin<&mut Self>,
        _cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<std::result::Result<(), Self::Error>> {
        todo!()
    }

    fn poll_close(
        self: std::pin::Pin<&mut Self>,
        _cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<std::result::Result<(), Self::Error>> {
        todo!()
    }
}