diff options
| -rw-r--r-- | Cargo.toml | 5 | ||||
| -rw-r--r-- | src/client/encrypted.rs | 59 | ||||
| -rw-r--r-- | src/client/mod.rs | 40 | ||||
| -rw-r--r-- | src/client/unencrypted.rs | 135 | ||||
| -rw-r--r-- | src/error.rs | 7 | ||||
| -rw-r--r-- | src/jabber.rs | 131 | ||||
| -rw-r--r-- | src/lib.rs | 187 | ||||
| -rw-r--r-- | src/stanza/mod.rs | 1 | ||||
| -rw-r--r-- | src/stanza/stream.rs | 36 | 
9 files changed, 441 insertions, 160 deletions
| @@ -7,6 +7,9 @@ edition = "2021"  # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html  [dependencies] -quick-xml = { version = "0.29.0", features = ["async-tokio"] } +async-trait = "0.1.68" +quick-xml = { version = "0.29.0", features = ["async-tokio", "serialize"] } +serde = { version = "1.0.164", features = ["derive"] }  tokio = { version = "1.28", features = ["full"] } +tokio-native-tls = "0.3.1"  trust-dns-resolver = "0.22.0" diff --git a/src/client/encrypted.rs b/src/client/encrypted.rs new file mode 100644 index 0000000..08439b2 --- /dev/null +++ b/src/client/encrypted.rs @@ -0,0 +1,59 @@ +use quick_xml::{ +    events::{BytesDecl, BytesStart, Event}, +    Reader, Writer, +}; +use tokio::io::{BufReader, ReadHalf, WriteHalf}; +use tokio::net::TcpStream; +use tokio_native_tls::TlsStream; + +use crate::Jabber; +use crate::Result; + +pub struct JabberClient<'j> { +    reader: Reader<BufReader<ReadHalf<TlsStream<TcpStream>>>>, +    writer: Writer<WriteHalf<TlsStream<TcpStream>>>, +    jabber: &'j mut Jabber<'j>, +} + +impl<'j> JabberClient<'j> { +    pub fn new( +        reader: Reader<BufReader<ReadHalf<TlsStream<TcpStream>>>>, +        writer: Writer<WriteHalf<TlsStream<TcpStream>>>, +        jabber: &'j mut Jabber<'j>, +    ) -> Self { +        Self { +            reader, +            writer, +            jabber, +        } +    } + +    pub async fn start_stream(&mut self) -> Result<()> { +        let declaration = BytesDecl::new("1.0", None, None); +        let mut stream_element = BytesStart::new("stream:stream"); +        stream_element.push_attribute(("from".as_bytes(), self.jabber.jid.to_string().as_bytes())); +        stream_element.push_attribute(("to".as_bytes(), self.jabber.server.as_bytes())); +        stream_element.push_attribute(("version", "1.0")); +        stream_element.push_attribute(("xml:lang", "en")); +        stream_element.push_attribute(("xmlns", "jabber:client")); +        stream_element.push_attribute(("xmlns:stream", "http://etherx.jabber.org/streams")); +        self.writer +            .write_event_async(Event::Decl(declaration)) +            .await; +        self.writer +            .write_event_async(Event::Start(stream_element)) +            .await +            .unwrap(); +        let mut buf = Vec::new(); +        loop { +            match self.reader.read_event_into_async(&mut buf).await.unwrap() { +                Event::Start(e) => { +                    println!("{:?}", e); +                    break; +                } +                e => println!("decl: {:?}", e), +            }; +        } +        Ok(()) +    } +} diff --git a/src/client/mod.rs b/src/client/mod.rs new file mode 100644 index 0000000..fe3dd34 --- /dev/null +++ b/src/client/mod.rs @@ -0,0 +1,40 @@ +pub mod encrypted; +pub mod unencrypted; + +// use async_trait::async_trait; + +use crate::stanza::stream::StreamFeature; +use crate::JabberError; +use crate::Result; + +pub enum JabberClientType<'j> { +    Encrypted(encrypted::JabberClient<'j>), +    Unencrypted(unencrypted::JabberClient<'j>), +} + +impl<'j> JabberClientType<'j> { +    pub async fn ensure_tls(self) -> Result<encrypted::JabberClient<'j>> { +        match self { +            Self::Encrypted(mut c) => { +                c.start_stream(); +                Ok(c) +            } +            Self::Unencrypted(mut c) => { +                c.start_stream().await?; +                let features = c.get_features().await?; +                if features.contains(&StreamFeature::StartTls) { +                    Ok(c.starttls().await?) +                } else { +                    Err(JabberError::StartTlsUnavailable) +                } +            } +        } +    } +} + +// TODO: jabber client trait over both client types +// #[async_trait] +// pub trait JabberTrait { +//     async fn start_stream(&mut self) -> Result<()>; +//     async fn get_features(&self) -> Result<Vec<StreamFeatures>>; +// } diff --git a/src/client/unencrypted.rs b/src/client/unencrypted.rs new file mode 100644 index 0000000..7528b14 --- /dev/null +++ b/src/client/unencrypted.rs @@ -0,0 +1,135 @@ +use std::str; + +use quick_xml::{ +    de::Deserializer, +    events::{BytesDecl, BytesStart, Event}, +    name::QName, +    Reader, Writer, +}; +use serde::Deserialize; +use tokio::io::{BufReader, ReadHalf, WriteHalf}; +use tokio::net::TcpStream; +use tokio_native_tls::native_tls::TlsConnector; + +use crate::Result; +use crate::{error::JabberError, stanza::stream::StreamFeature}; +use crate::{stanza::stream::StreamFeatures, Jabber}; + +pub struct JabberClient<'j> { +    reader: Reader<BufReader<ReadHalf<TcpStream>>>, +    writer: Writer<WriteHalf<TcpStream>>, +    jabber: &'j mut Jabber<'j>, +} + +impl<'j> JabberClient<'j> { +    pub fn new( +        reader: Reader<BufReader<ReadHalf<TcpStream>>>, +        writer: Writer<WriteHalf<TcpStream>>, +        jabber: &'j mut Jabber<'j>, +    ) -> Self { +        Self { +            reader, +            writer, +            jabber, +        } +    } + +    pub async fn start_stream(&mut self) -> Result<()> { +        let declaration = BytesDecl::new("1.0", None, None); +        let mut stream_element = BytesStart::new("stream:stream"); +        stream_element.push_attribute(("from".as_bytes(), self.jabber.jid.to_string().as_bytes())); +        stream_element.push_attribute(("to".as_bytes(), self.jabber.server.as_bytes())); +        stream_element.push_attribute(("version", "1.0")); +        stream_element.push_attribute(("xml:lang", "en")); +        stream_element.push_attribute(("xmlns", "jabber:client")); +        stream_element.push_attribute(("xmlns:stream", "http://etherx.jabber.org/streams")); +        self.writer +            .write_event_async(Event::Decl(declaration)) +            .await; +        self.writer +            .write_event_async(Event::Start(stream_element)) +            .await +            .unwrap(); +        let mut buf = Vec::new(); +        loop { +            match self.reader.read_event_into_async(&mut buf).await.unwrap() { +                Event::Start(e) => { +                    println!("{:?}", e); +                    break; +                } +                Event::Decl(e) => println!("decl: {:?}", e), +                _ => return Err(JabberError::BadStream), +            } +        } +        Ok(()) +    } + +    pub async fn get_features(&mut self) -> Result<Vec<StreamFeature>> { +        let mut buf = Vec::new(); +        let mut txt = Vec::new(); +        let mut loop_end = false; +        while !loop_end { +            match self.reader.read_event_into_async(&mut buf).await.unwrap() { +                Event::End(e) => { +                    if e.name() == QName(b"stream:features") { +                        loop_end = true; +                    } +                } +                _ => (), +            } +            txt.push(b'<'); +            txt = txt +                .into_iter() +                .chain(buf.to_owned()) +                .chain(vec![b'>']) +                .collect(); +            buf.clear(); +        } +        println!("{:?}", txt); +        let decoded = str::from_utf8(&txt).unwrap(); +        println!("decoded: {:?}", decoded); +        let mut deserializer = Deserializer::from_str(decoded); +        // let mut deserializer = Deserializer::from_str(txt); +        let features = StreamFeatures::deserialize(&mut deserializer).unwrap(); +        println!("{:?}", features); +        Ok(features.features) +    } + +    pub async fn starttls(mut self) -> Result<super::encrypted::JabberClient<'j>> { +        let mut starttls_element = BytesStart::new("starttls"); +        starttls_element.push_attribute(("xmlns", "urn:ietf:params:xml:ns:xmpp-tls")); +        self.writer +            .write_event_async(Event::Empty(starttls_element)) +            .await +            .unwrap(); +        let mut buf = Vec::new(); +        match self.reader.read_event_into_async(&mut buf).await.unwrap() { +            Event::Empty(e) => match e.name() { +                QName(b"proceed") => { +                    let connector = TlsConnector::new().unwrap(); +                    let stream = self +                        .reader +                        .into_inner() +                        .into_inner() +                        .unsplit(self.writer.into_inner()); +                    if let Ok(tlsstream) = tokio_native_tls::TlsConnector::from(connector) +                        .connect(&self.jabber.server, stream) +                        .await +                    { +                        let (read, write) = tokio::io::split(tlsstream); +                        let reader = Reader::from_reader(BufReader::new(read)); +                        let writer = Writer::new(write); +                        return Ok(super::encrypted::JabberClient::new( +                            reader, +                            writer, +                            self.jabber, +                        )); +                    } +                } +                QName(_) => return Err(JabberError::TlsNegotiation), +            }, +            _ => return Err(JabberError::TlsNegotiation), +        } +        Err(JabberError::TlsNegotiation) +    } +} diff --git a/src/error.rs b/src/error.rs new file mode 100644 index 0000000..a632537 --- /dev/null +++ b/src/error.rs @@ -0,0 +1,7 @@ +#[derive(Debug)] +pub enum JabberError { +    ConnectionError, +    BadStream, +    StartTlsUnavailable, +    TlsNegotiation, +} diff --git a/src/jabber.rs b/src/jabber.rs new file mode 100644 index 0000000..a1f6272 --- /dev/null +++ b/src/jabber.rs @@ -0,0 +1,131 @@ +use std::marker::PhantomData; +use std::net::{IpAddr, SocketAddr}; +use std::str::FromStr; + +use quick_xml::{Reader, Writer}; +use tokio::io::BufReader; +use tokio::net::TcpStream; +use tokio_native_tls::native_tls::TlsConnector; + +use crate::client; +use crate::client::JabberClientType; +use crate::jid::JID; +use crate::{JabberError, Result}; + +pub struct Jabber<'j> { +    pub jid: JID, +    pub password: String, +    pub server: String, +    _marker: PhantomData<&'j ()>, +} + +impl<'j> Jabber<'j> { +    pub fn new(jid: JID, password: String) -> Self { +        let server = jid.domainpart.clone(); +        Self { +            jid, +            password, +            server, +            _marker: PhantomData, +        } +    } + +    async fn get_sockets(&self) -> Vec<(SocketAddr, bool)> { +        let mut socket_addrs = Vec::new(); + +        // if it's a socket/ip then just return that + +        // socket +        if let Ok(socket_addr) = SocketAddr::from_str(&self.jid.domainpart) { +            match socket_addr.port() { +                5223 => socket_addrs.push((socket_addr, true)), +                _ => socket_addrs.push((socket_addr, false)), +            } + +            return socket_addrs; +        } +        // ip +        if let Ok(ip) = IpAddr::from_str(&self.jid.domainpart) { +            socket_addrs.push((SocketAddr::new(ip, 5222), false)); +            socket_addrs.push((SocketAddr::new(ip, 5223), true)); +            return socket_addrs; +        } + +        // otherwise resolve +        if let Ok(resolver) = trust_dns_resolver::AsyncResolver::tokio_from_system_conf() { +            if let Ok(lookup) = resolver +                .srv_lookup(format!("_xmpp-client._tcp.{}", self.jid.domainpart)) +                .await +            { +                for srv in lookup { +                    resolver +                        .lookup_ip(srv.target().to_owned()) +                        .await +                        .map(|ips| { +                            for ip in ips { +                                socket_addrs.push((SocketAddr::new(ip, srv.port()), false)) +                            } +                        }); +                } +            } +            if let Ok(lookup) = resolver +                .srv_lookup(format!("_xmpps-client._tcp.{}", self.jid.domainpart)) +                .await +            { +                for srv in lookup { +                    resolver +                        .lookup_ip(srv.target().to_owned()) +                        .await +                        .map(|ips| { +                            for ip in ips { +                                socket_addrs.push((SocketAddr::new(ip, srv.port()), true)) +                            } +                        }); +                } +            } + +            // in case cannot connect through SRV records +            resolver.lookup_ip(&self.jid.domainpart).await.map(|ips| { +                for ip in ips { +                    socket_addrs.push((SocketAddr::new(ip, 5222), false)); +                    socket_addrs.push((SocketAddr::new(ip, 5223), true)); +                } +            }); +        } +        socket_addrs +    } + +    pub async fn connect(&'j mut self) -> Result<JabberClientType> { +        for (socket_addr, is_tls) in self.get_sockets().await { +            println!("trying {}", socket_addr); +            match is_tls { +                true => { +                    let socket = TcpStream::connect(socket_addr).await.unwrap(); +                    let connector = TlsConnector::new().unwrap(); +                    if let Ok(stream) = tokio_native_tls::TlsConnector::from(connector) +                        .connect(&self.server, socket) +                        .await +                    { +                        let (read, write) = tokio::io::split(stream); +                        let reader = Reader::from_reader(BufReader::new(read)); +                        let writer = Writer::new(write); +                        return Ok(JabberClientType::Encrypted( +                            client::encrypted::JabberClient::new(reader, writer, self), +                        )); +                    } +                } +                false => { +                    if let Ok(stream) = TcpStream::connect(socket_addr).await { +                        let (read, write) = tokio::io::split(stream); +                        let reader = Reader::from_reader(BufReader::new(read)); +                        let writer = Writer::new(write); +                        return Ok(JabberClientType::Unencrypted( +                            client::unencrypted::JabberClient::new(reader, writer, self), +                        )); +                    } +                } +            } +        } +        Err(JabberError::ConnectionError) +    } +} @@ -1,174 +1,43 @@ -// TODO: logging (dropped errors)  #![allow(unused_must_use)] -use std::{ -    net::{IpAddr, SocketAddr}, -    str::FromStr, -}; - -use jid::JID; -use quick_xml::{Reader, Writer}; -use tokio::net::{ -    tcp::{OwnedReadHalf, OwnedWriteHalf}, -    TcpStream, -}; +// TODO: logging (dropped errors) +pub mod client; +pub mod error; +pub mod jabber;  pub mod jid; +pub mod stanza; -pub struct JabberData { -    jid: jid::JID, -    password: String, -} - -impl JabberData { -    pub fn new(jid: JID, password: String) -> Self { -        Self { jid, password } -    } - -    async fn get_sockets(&self) -> Vec<SocketAddr> { -        let mut socket_addrs = Vec::new(); - -        // if it's a socket/ip then just return that - -        // socket -        if let Ok(socket_addr) = SocketAddr::from_str(&self.jid.domainpart) { -            socket_addrs.push(socket_addr); -            return socket_addrs; -        } -        // ip -        if let Ok(ip) = IpAddr::from_str(&self.jid.domainpart) { -            socket_addrs.push(SocketAddr::new(ip, 5222)); -            socket_addrs.push(SocketAddr::new(ip, 5223)); -            return socket_addrs; -        } - -        // if port specified return name resolutions with specified port - -        // otherwise resolve -        if let Ok(resolver) = trust_dns_resolver::AsyncResolver::tokio_from_system_conf() { -            if let Ok(lookup) = resolver -                .srv_lookup(format!("_xmpp-client._tcp.{}", self.jid.domainpart)) -                .await -            { -                for srv in lookup { -                    resolver -                        .lookup_ip(srv.target().to_owned()) -                        .await -                        .map(|ips| { -                            for ip in ips { -                                socket_addrs.push(SocketAddr::new(ip, srv.port())) -                            } -                        }); -                } -            } -            if let Ok(lookup) = resolver -                .srv_lookup(format!("_xmpps-client._tcp.{}", self.jid.domainpart)) -                .await -            { -                for srv in lookup { -                    resolver -                        .lookup_ip(srv.target().to_owned()) -                        .await -                        .map(|ips| { -                            for ip in ips { -                                socket_addrs.push(SocketAddr::new(ip, srv.port())) -                            } -                        }); -                } -            } - -            // in case cannot connect through SRV records -            resolver.lookup_ip(&self.jid.domainpart).await.map(|ips| { -                for ip in ips { -                    socket_addrs.push(SocketAddr::new(ip, 5222)); -                    socket_addrs.push(SocketAddr::new(ip, 5223)); -                } -            }); -        } - -        socket_addrs -    } -} - -pub struct Jabber { -    reader: Reader<OwnedReadHalf>, -    writer: Writer<OwnedWriteHalf>, -    data: JabberData, -} - -#[derive(Debug)] -pub enum JabberError { -    NotConnected, -} +pub use client::encrypted::JabberClient; +pub use error::JabberError; +pub use jabber::Jabber; +pub use jid::JID; -impl Jabber { -    pub async fn connect(data: JabberData) -> Result<Self, JabberError> { -        for socket_addr in data.get_sockets().await { -            println!("trying {}", socket_addr); -            if let Ok(stream) = TcpStream::connect(socket_addr).await { -                println!("connected to {}", socket_addr); -                let (read, write) = stream.into_split(); -                return Ok(Self { -                    reader: Reader::from_reader(read), -                    writer: Writer::new(write), -                    data, -                }); -            } -        } -        Err(JabberError::NotConnected) -    } - -    async fn reconnect(&mut self) { -        for socket_addr in self.data.get_sockets().await { -            println!("trying {}", socket_addr); -            if let Ok(stream) = TcpStream::connect(socket_addr).await { -                println!("connected to {}", socket_addr); -                let (read, write) = stream.into_split(); -                self.reader = Reader::from_reader(read); -                self.writer = Writer::new(write); -                return; -            } -        } -        println!("could not connect") -    } - -    async fn begin_stream(&mut self) -> Result<(), JabberError> { -        todo!() -    } - -    async fn starttls() -> Result<(), JabberError> { -        todo!() -    } - -    async fn directtls() -> Result<(), JabberError> { -        todo!() -    } - -    async fn auth(&mut self) -> Result<(), JabberError> { -        todo!() -    } - -    async fn close(&mut self) {} -} +pub type Result<T> = std::result::Result<T, JabberError>;  #[cfg(test)]  mod tests { -    use crate::jid::JID; +    use std::str::FromStr; -    use super::*; +    use crate::Jabber; +    use crate::JID; -    #[tokio::test] -    async fn get_sockets() { -        let data = JabberData::new(JID::from_str("cel@blos.sm").unwrap(), "password".to_owned()); -        println!("{:?}", data.get_sockets().await) -    } +    // #[tokio::test] +    // async fn get_sockets() { +    //     let jabber = Jabber::new(JID::from_str("cel@blos.sm").unwrap(), "password".to_owned()); +    //     println!("{:?}", jabber.get_sockets().await) +    // }      #[tokio::test]      async fn connect() { -        Jabber::connect(JabberData::new( -            JID::from_str("cel@blos.sm").unwrap(), -            "password".to_owned(), -        )) -        .await -        .unwrap(); +        Jabber::new(JID::from_str("cel@blos.sm").unwrap(), "password".to_owned()) +            .connect() +            .await +            .unwrap() +            .ensure_tls() +            .await +            .unwrap() +            .start_stream() +            .await +            .unwrap();      }  } diff --git a/src/stanza/mod.rs b/src/stanza/mod.rs new file mode 100644 index 0000000..baf29e0 --- /dev/null +++ b/src/stanza/mod.rs @@ -0,0 +1 @@ +pub mod stream; diff --git a/src/stanza/stream.rs b/src/stanza/stream.rs new file mode 100644 index 0000000..dde741d --- /dev/null +++ b/src/stanza/stream.rs @@ -0,0 +1,36 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Serialize, Deserialize)] +#[serde(rename = "stream:stream")] +struct Stream { +    #[serde(rename = "@from")] +    from: Option<String>, +    #[serde(rename = "@id")] +    id: Option<String>, +    #[serde(rename = "@to")] +    to: Option<String>, +    #[serde(rename = "@version")] +    version: Option<f32>, +    #[serde(rename = "@xml:lang")] +    lang: Option<String>, +    #[serde(rename = "@xmlns")] +    namespace: Option<String>, +    #[serde(rename = "@xmlns:stream")] +    stream_namespace: Option<String>, +} + +#[derive(Deserialize, Debug)] +#[serde(rename = "stream:features")] +pub struct StreamFeatures { +    #[serde(rename = "$value")] +    pub features: Vec<StreamFeature>, +} + +#[derive(Deserialize, PartialEq, Debug)] +pub enum StreamFeature { +    #[serde(rename = "starttls")] +    StartTls, +    // TODO: other stream features +    Sasl, +    Bind, +} | 
