aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorLibravatar cel 🌸 <cel@bunny.garden>2024-12-04 02:09:07 +0000
committerLibravatar cel 🌸 <cel@bunny.garden>2024-12-04 02:09:07 +0000
commit4886396044356d2676a77c3900af796fe7641f42 (patch)
tree685c67b7db0f22a7262fc6431d7849d63c510e66
parente0373c0520e7fae792bc907e9c500ab846d34e31 (diff)
downloadluz-4886396044356d2676a77c3900af796fe7641f42.tar.gz
luz-4886396044356d2676a77c3900af796fe7641f42.tar.bz2
luz-4886396044356d2676a77c3900af796fe7641f42.zip
implement client
-rw-r--r--src/client.rs234
-rw-r--r--src/error.rs9
-rw-r--r--src/jabber.rs21
-rw-r--r--src/lib.rs8
-rw-r--r--src/stanza/client/mod.rs27
5 files changed, 213 insertions, 86 deletions
diff --git a/src/client.rs b/src/client.rs
index 2908346..5351b34 100644
--- a/src/client.rs
+++ b/src/client.rs
@@ -1,10 +1,11 @@
-use std::sync::Arc;
+use std::{pin::pin, sync::Arc, task::Poll};
-use futures::{Sink, Stream};
+use futures::{Sink, Stream, StreamExt};
use rsasl::config::SASLConfig;
use crate::{
connection::{Tls, Unencrypted},
+ jid::ParseError,
stanza::{
client::Stanza,
sasl::Mechanisms,
@@ -15,14 +16,146 @@ use crate::{
// feed it client stanzas, receive client stanzas
pub struct JabberClient {
- connection: JabberState,
+ connection: ConnectionState,
jid: JID,
password: Arc<SASLConfig>,
server: String,
}
-pub enum JabberState {
+impl JabberClient {
+ pub fn new(
+ jid: impl TryInto<JID, Error = ParseError>,
+ password: impl ToString,
+ ) -> Result<JabberClient> {
+ let jid = jid.try_into()?;
+ let sasl_config = SASLConfig::with_credentials(
+ None,
+ jid.localpart.clone().ok_or(Error::NoLocalpart)?,
+ password.to_string(),
+ )?;
+ Ok(JabberClient {
+ connection: ConnectionState::Disconnected,
+ jid: jid.clone(),
+ password: sasl_config,
+ server: jid.domainpart,
+ })
+ }
+
+ pub async fn connect(&mut self) -> Result<()> {
+ match &self.connection {
+ ConnectionState::Disconnected => {
+ self.connection = ConnectionState::Disconnected
+ .connect(&mut self.jid, self.password.clone(), &mut self.server)
+ .await?;
+ Ok(())
+ }
+ ConnectionState::Connecting(_connecting) => Err(Error::AlreadyConnecting),
+ ConnectionState::Connected(_jabber_stream) => Ok(()),
+ }
+ }
+}
+
+impl Stream for JabberClient {
+ type Item = Result<Stanza>;
+
+ fn poll_next(
+ self: std::pin::Pin<&mut Self>,
+ cx: &mut std::task::Context<'_>,
+ ) -> std::task::Poll<Option<Self::Item>> {
+ let mut client = pin!(self);
+ match &mut client.connection {
+ ConnectionState::Disconnected => Poll::Pending,
+ ConnectionState::Connecting(_connecting) => Poll::Pending,
+ ConnectionState::Connected(jabber_stream) => jabber_stream.poll_next_unpin(cx),
+ }
+ }
+}
+
+pub enum ConnectionState {
Disconnected,
+ Connecting(Connecting),
+ Connected(JabberStream<Tls>),
+}
+
+impl ConnectionState {
+ pub async fn connect(
+ mut self,
+ jid: &mut JID,
+ auth: Arc<SASLConfig>,
+ server: &mut String,
+ ) -> Result<Self> {
+ loop {
+ match self {
+ ConnectionState::Disconnected => {
+ self = ConnectionState::Connecting(Connecting::start(&server).await?);
+ }
+ ConnectionState::Connecting(connecting) => match connecting {
+ Connecting::InsecureConnectionEstablised(tcp_stream) => {
+ self = ConnectionState::Connecting(Connecting::InsecureStreamStarted(
+ JabberStream::start_stream(tcp_stream, server).await?,
+ ))
+ }
+ Connecting::InsecureStreamStarted(jabber_stream) => {
+ self = ConnectionState::Connecting(Connecting::InsecureGotFeatures(
+ jabber_stream.get_features().await?,
+ ))
+ }
+ Connecting::InsecureGotFeatures((features, jabber_stream)) => {
+ match features.negotiate()? {
+ Feature::StartTls(_start_tls) => {
+ self =
+ ConnectionState::Connecting(Connecting::StartTls(jabber_stream))
+ }
+ // TODO: better error
+ _ => return Err(Error::TlsRequired),
+ }
+ }
+ Connecting::StartTls(jabber_stream) => {
+ self = ConnectionState::Connecting(Connecting::ConnectionEstablished(
+ jabber_stream.starttls(&server).await?,
+ ))
+ }
+ Connecting::ConnectionEstablished(tls_stream) => {
+ self = ConnectionState::Connecting(Connecting::StreamStarted(
+ JabberStream::start_stream(tls_stream, server).await?,
+ ))
+ }
+ Connecting::StreamStarted(jabber_stream) => {
+ self = ConnectionState::Connecting(Connecting::GotFeatures(
+ jabber_stream.get_features().await?,
+ ))
+ }
+ Connecting::GotFeatures((features, jabber_stream)) => {
+ match features.negotiate()? {
+ Feature::StartTls(_start_tls) => return Err(Error::AlreadyTls),
+ Feature::Sasl(mechanisms) => {
+ self = ConnectionState::Connecting(Connecting::Sasl(
+ mechanisms,
+ jabber_stream,
+ ))
+ }
+ Feature::Bind => {
+ self = ConnectionState::Connecting(Connecting::Bind(jabber_stream))
+ }
+ Feature::Unknown => return Err(Error::Unsupported),
+ }
+ }
+ Connecting::Sasl(mechanisms, jabber_stream) => {
+ self = ConnectionState::Connecting(Connecting::ConnectionEstablished(
+ jabber_stream.sasl(mechanisms, auth.clone()).await?,
+ ))
+ }
+ Connecting::Bind(jabber_stream) => {
+ self = ConnectionState::Connected(jabber_stream.bind(jid).await?)
+ }
+ },
+ connected => return Ok(connected),
+ }
+ }
+ }
+}
+
+pub enum Connecting {
InsecureConnectionEstablised(Unencrypted),
InsecureStreamStarted(JabberStream<Unencrypted>),
InsecureGotFeatures((Features, JabberStream<Unencrypted>)),
@@ -32,67 +165,15 @@ pub enum JabberState {
GotFeatures((Features, JabberStream<Tls>)),
Sasl(Mechanisms, JabberStream<Tls>),
Bind(JabberStream<Tls>),
- // when it's bound, can stream stanzas and sink stanzas
- Bound(JabberStream<Tls>),
}
-impl JabberState {
- pub async fn advance_state(
- self,
- jid: &mut JID,
- auth: Arc<SASLConfig>,
- server: &mut String,
- ) -> Result<JabberState> {
- match self {
- JabberState::Disconnected => match Connection::connect(server).await? {
- Connection::Encrypted(tls_stream) => {
- Ok(JabberState::ConnectionEstablished(tls_stream))
- }
- Connection::Unencrypted(tcp_stream) => {
- Ok(JabberState::InsecureConnectionEstablised(tcp_stream))
- }
- },
- JabberState::InsecureConnectionEstablised(tcp_stream) => Ok({
- JabberState::InsecureStreamStarted(
- JabberStream::start_stream(tcp_stream, server).await?,
- )
- }),
- JabberState::InsecureStreamStarted(jabber_stream) => Ok(
- JabberState::InsecureGotFeatures(jabber_stream.get_features().await?),
- ),
- JabberState::InsecureGotFeatures((features, jabber_stream)) => {
- match features.negotiate()? {
- Feature::StartTls(_start_tls) => Ok(JabberState::StartTls(jabber_stream)),
- // TODO: better error
- _ => return Err(Error::TlsRequired),
- }
- }
- JabberState::StartTls(jabber_stream) => Ok(JabberState::ConnectionEstablished(
- jabber_stream.starttls(server).await?,
- )),
- JabberState::ConnectionEstablished(tls_stream) => Ok(JabberState::StreamStarted(
- JabberStream::start_stream(tls_stream, server).await?,
- )),
- JabberState::StreamStarted(jabber_stream) => Ok(JabberState::GotFeatures(
- jabber_stream.get_features().await?,
- )),
- JabberState::GotFeatures((features, jabber_stream)) => match features.negotiate()? {
- Feature::StartTls(_start_tls) => return Err(Error::AlreadyTls),
- Feature::Sasl(mechanisms) => {
- return Ok(JabberState::Sasl(mechanisms, jabber_stream))
- }
- Feature::Bind => return Ok(JabberState::Bind(jabber_stream)),
- Feature::Unknown => return Err(Error::Unsupported),
- },
- JabberState::Sasl(mechanisms, jabber_stream) => {
- return Ok(JabberState::ConnectionEstablished(
- jabber_stream.sasl(mechanisms, auth).await?,
- ))
+impl Connecting {
+ pub async fn start(server: &str) -> Result<Self> {
+ match Connection::connect(server).await? {
+ Connection::Encrypted(tls_stream) => Ok(Connecting::ConnectionEstablished(tls_stream)),
+ Connection::Unencrypted(tcp_stream) => {
+ Ok(Connecting::InsecureConnectionEstablised(tcp_stream))
}
- JabberState::Bind(jabber_stream) => {
- Ok(JabberState::Bound(jabber_stream.bind(jid).await?))
- }
- JabberState::Bound(jabber_stream) => Ok(JabberState::Bound(jabber_stream)),
}
}
}
@@ -126,7 +207,7 @@ impl Features {
}
}
-pub enum InsecureJabberConnection {
+pub enum InsecureConnecting {
Disconnected,
ConnectionEstablished(Connection),
PreStarttls(JabberStream<Unencrypted>),
@@ -136,17 +217,6 @@ pub enum InsecureJabberConnection {
Bound(JabberStream<Tls>),
}
-impl Stream for JabberClient {
- type Item = Stanza;
-
- fn poll_next(
- self: std::pin::Pin<&mut Self>,
- cx: &mut std::task::Context<'_>,
- ) -> std::task::Poll<Option<Self::Item>> {
- todo!()
- }
-}
-
impl Sink<Stanza> for JabberClient {
type Error = Error;
@@ -178,3 +248,19 @@ impl Sink<Stanza> for JabberClient {
todo!()
}
}
+
+#[cfg(test)]
+mod tests {
+ use std::time::Duration;
+
+ use super::JabberClient;
+ use test_log::test;
+ use tokio::time::sleep;
+
+ #[test(tokio::test)]
+ async fn login() {
+ let mut client = JabberClient::new("test@blos.sm", "slayed").unwrap();
+ client.connect().await.unwrap();
+ sleep(Duration::from_secs(5)).await
+ }
+}
diff --git a/src/error.rs b/src/error.rs
index 8cb6496..f117e82 100644
--- a/src/error.rs
+++ b/src/error.rs
@@ -13,8 +13,11 @@ pub enum Error {
TlsRequired,
AlreadyTls,
Unsupported,
+ NoLocalpart,
+ AlreadyConnecting,
UnexpectedElement(peanuts::Element),
XML(peanuts::Error),
+ Deserialization(peanuts::DeserializeError),
SASL(SASLError),
JID(ParseError),
Authentication(Failure),
@@ -34,6 +37,12 @@ impl From<rsasl::prelude::SASLError> for Error {
}
}
+impl From<peanuts::DeserializeError> for Error {
+ fn from(e: peanuts::DeserializeError) -> Self {
+ Error::Deserialization(e)
+ }
+}
+
impl From<MechanismNameError> for Error {
fn from(e: MechanismNameError) -> Self {
Self::SASL(SASLError::MechanismName(e))
diff --git a/src/jabber.rs b/src/jabber.rs
index cf90f73..30dc15d 100644
--- a/src/jabber.rs
+++ b/src/jabber.rs
@@ -1,8 +1,10 @@
+use std::pin::pin;
use std::str::{self, FromStr};
use std::sync::Arc;
use async_recursion::async_recursion;
-use peanuts::element::IntoElement;
+use futures::StreamExt;
+use peanuts::element::{FromContent, IntoElement};
use peanuts::{Reader, Writer};
use rsasl::prelude::{Mechname, SASLClient, SASLConfig};
use tokio::io::{AsyncRead, AsyncWrite, ReadHalf, WriteHalf};
@@ -13,6 +15,7 @@ use crate::connection::{Tls, Unencrypted};
use crate::error::Error;
use crate::stanza::bind::{Bind, BindType, FullJidType, ResourceType};
use crate::stanza::client::iq::{Iq, IqType, Query};
+use crate::stanza::client::Stanza;
use crate::stanza::sasl::{Auth, Challenge, Mechanisms, Response, ServerResponse};
use crate::stanza::starttls::{Proceed, StartTls};
use crate::stanza::stream::{Feature, Features, Stream};
@@ -26,6 +29,22 @@ pub struct JabberStream<S> {
writer: Writer<WriteHalf<S>>,
}
+impl<S: AsyncRead> futures::Stream for JabberStream<S> {
+ type Item = Result<Stanza>;
+
+ fn poll_next(
+ self: std::pin::Pin<&mut Self>,
+ cx: &mut std::task::Context<'_>,
+ ) -> std::task::Poll<Option<Self::Item>> {
+ pin!(self).reader.poll_next_unpin(cx).map(|content| {
+ content.map(|content| -> Result<Stanza> {
+ let stanza = content.map(|content| Stanza::from_content(content))?;
+ Ok(stanza?)
+ })
+ })
+ }
+}
+
impl<S> JabberStream<S>
where
S: AsyncRead + AsyncWrite + Unpin + Send + std::fmt::Debug,
diff --git a/src/lib.rs b/src/lib.rs
index 9c8d968..e55d3f5 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -29,8 +29,8 @@ pub async fn login<J: AsRef<str>, P: AsRef<str>>(jid: J, password: P) -> Result<
#[cfg(test)]
mod tests {
- #[tokio::test]
- async fn test_login() {
- crate::login("test@blos.sm/clown", "slayed").await.unwrap();
- }
+ // #[tokio::test]
+ // async fn test_login() {
+ // crate::login("test@blos.sm/clown", "slayed").await.unwrap();
+ // }
}
diff --git a/src/stanza/client/mod.rs b/src/stanza/client/mod.rs
index 25d7b56..2b063d6 100644
--- a/src/stanza/client/mod.rs
+++ b/src/stanza/client/mod.rs
@@ -1,7 +1,7 @@
use iq::Iq;
use message::Message;
use peanuts::{
- element::{FromElement, IntoElement},
+ element::{Content, ContentBuilder, FromContent, FromElement, IntoContent, IntoElement},
DeserializeError,
};
use presence::Presence;
@@ -20,6 +20,18 @@ pub enum Stanza {
Presence(Presence),
Iq(Iq),
Error(StreamError),
+ OtherContent(Content),
+}
+
+impl FromContent for Stanza {
+ fn from_content(content: Content) -> peanuts::element::DeserializeResult<Self> {
+ match content {
+ Content::Element(element) => Ok(Stanza::from_element(element)?),
+ Content::Text(_) => Ok(Stanza::OtherContent(content)),
+ Content::PI => Ok(Stanza::OtherContent(content)),
+ Content::Comment(_) => Ok(Stanza::OtherContent(content)),
+ }
+ }
}
impl FromElement for Stanza {
@@ -36,13 +48,14 @@ impl FromElement for Stanza {
}
}
-impl IntoElement for Stanza {
- fn builder(&self) -> peanuts::element::ElementBuilder {
+impl IntoContent for Stanza {
+ fn builder(&self) -> peanuts::element::ContentBuilder {
match self {
- Stanza::Message(message) => message.builder(),
- Stanza::Presence(presence) => presence.builder(),
- Stanza::Iq(iq) => iq.builder(),
- Stanza::Error(error) => error.builder(),
+ Stanza::Message(message) => <Message as IntoContent>::builder(message),
+ Stanza::Presence(presence) => <Presence as IntoContent>::builder(presence),
+ Stanza::Iq(iq) => <Iq as IntoContent>::builder(iq),
+ Stanza::Error(error) => <StreamError as IntoContent>::builder(error),
+ Stanza::OtherContent(_content) => ContentBuilder::Comment("other-content".to_string()),
}
}
}