diff options
Diffstat (limited to '')
-rw-r--r-- | .helix/languages.toml | 6 | ||||
-rw-r--r-- | luz/Cargo.toml | 10 | ||||
-rw-r--r-- | luz/src/client.rs | 134 | ||||
-rw-r--r-- | luz/src/client/tcp.rs | 110 | ||||
-rw-r--r-- | luz/src/client/ws.rs | 68 | ||||
-rw-r--r-- | luz/src/connection.rs | 170 | ||||
-rw-r--r-- | luz/src/connection/tcp.rs | 220 | ||||
-rw-r--r-- | luz/src/connection/ws.rs | 35 | ||||
-rw-r--r-- | luz/src/error.rs | 1 | ||||
-rw-r--r-- | luz/src/jabber_stream.rs | 250 | ||||
-rw-r--r-- | luz/src/jabber_stream/bound_stream.rs | 49 | ||||
-rw-r--r-- | luz/src/jabber_stream/tcp.rs | 103 | ||||
-rw-r--r-- | luz/src/jabber_stream/ws.rs | 105 |
13 files changed, 813 insertions, 448 deletions
diff --git a/.helix/languages.toml b/.helix/languages.toml index 8be248f..8e3e532 100644 --- a/.helix/languages.toml +++ b/.helix/languages.toml @@ -1,4 +1,8 @@ [language-server.rust-analyzer] command = "rust-analyzer" environment = { "DATABASE_URL" = "sqlite://filamento/filamento.db" } -config = { cargo.features = ["stanza/rfc_6121", "stanza/xep_0203", "stanza/xep_0030", "stanza/xep_0060", "stanza/xep_0172", "stanza/xep_0390", "stanza/xep_0128", "stanza/xep_0115", "stanza/xep_0084", "sqlx/sqlite", "sqlx/runtime-tokio", "sqlx/uuid", "sqlx/chrono", "jid/sqlx", "uuid/v4", "tokio/full", "rsasl/provider_base64", "rsasl/plain", "rsasl/config_builder", "rsasl/scram-sha-1"] } + # checkOnSave.overrideCommand = "cargo check --message-format=json -p luz", + # check.overrideCommand="cargo check --message-format=json -p luz", + # check.workspace = false, + # cargo.target = "wasm32-unknown-unknown", +config = { cargo.features = ["stanza/rfc_6121", "stanza/xep_0203", "stanza/rfc_7395", "stanza/xep_0030", "stanza/xep_0060", "stanza/xep_0172", "stanza/xep_0390", "stanza/xep_0128", "stanza/xep_0115", "stanza/xep_0084", "sqlx/sqlite", "sqlx/runtime-tokio", "sqlx/uuid", "sqlx/chrono", "jid/sqlx", "uuid/v4", "tokio/full", "rsasl/provider_base64", "rsasl/plain", "rsasl/config_builder", "rsasl/scram-sha-1"] } diff --git a/luz/Cargo.toml b/luz/Cargo.toml index d7207ca..709bc00 100644 --- a/luz/Cargo.toml +++ b/luz/Cargo.toml @@ -4,6 +4,9 @@ authors = ["cel <cel@bunny.garden>"] version = "0.1.0" edition = "2021" +[lib] +crate-type = ["cdylib", "rlib"] + # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] @@ -23,6 +26,7 @@ tracing = "0.1.40" try_map = "0.3.1" stanza = { version = "0.1.0", path = "../stanza" } peanuts = { version = "0.1.0", git = "https://bunny.garden/peanuts" } +# peanuts = { version = "0.1.0", path = "../../peanuts" } jid = { version = "0.1.0", path = "../jid" } futures = "0.3.31" take_mut = "0.2.2" @@ -33,12 +37,18 @@ thiserror = "2.0.11" [target.'cfg(target_arch = "wasm32")'.dependencies] uuid = { version = "1.13.1", features = ["js", "v4"] } getrandom = { version = "0.2.15", features = ["js"] } +stanza = { version = "0.1.0", path = "../stanza", features = ["rfc_7395"] } +web-sys = { version = "0.3", features = ["Request", "WebSocket"] } +wasm-bindgen = "0.2" [target.'cfg(not(target_arch = "wasm32"))'.dependencies] tokio-native-tls = "0.3.1" trust-dns-resolver = "0.22.0" [dev-dependencies] +tracing-wasm = "0.2.1" +wasm-bindgen-test = "0.3.0" +tokio = { version = "1.28", features = ["macros", "rt", "time"] } test-log = { version = "0.2", features = ["trace"] } env_logger = "*" tracing-subscriber = { version = "0.3", default-features = false, features = [ diff --git a/luz/src/client.rs b/luz/src/client.rs index de2be08..ab0aa6d 100644 --- a/luz/src/client.rs +++ b/luz/src/client.rs @@ -1,111 +1,12 @@ -use rsasl::config::SASLConfig; -use stanza::{ - sasl::Mechanisms, - stream::{Feature, Features}, -}; +#[cfg(not(target_arch = "wasm32"))] +mod tcp; +#[cfg(target_arch = "wasm32")] +mod ws; -use crate::{ - connection::{Tls, Unencrypted}, - jabber_stream::bound_stream::BoundJabberStream, - Connection, Error, JabberStream, Result, JID, -}; - -pub async fn connect_and_login( - jid: &mut JID, - password: impl AsRef<str>, - server: &mut String, -) -> Result<BoundJabberStream<Tls>> { - let auth = SASLConfig::with_credentials( - None, - jid.localpart.clone().ok_or(Error::NoLocalpart)?, - password.as_ref().to_string(), - ) - .map_err(|e| Error::SASL(e.into()))?; - let mut conn_state = Connecting::start(&server).await?; - loop { - match conn_state { - Connecting::InsecureConnectionEstablised(tcp_stream) => { - conn_state = Connecting::InsecureStreamStarted( - JabberStream::start_stream(tcp_stream, server).await?, - ) - } - Connecting::InsecureStreamStarted(jabber_stream) => { - conn_state = Connecting::InsecureGotFeatures(jabber_stream.get_features().await?) - } - Connecting::InsecureGotFeatures((features, jabber_stream)) => { - match features.negotiate().ok_or(Error::Negotiation)? { - Feature::StartTls(_start_tls) => { - conn_state = Connecting::StartTls(jabber_stream) - } - // TODO: better error - _ => return Err(Error::TlsRequired), - } - } - Connecting::StartTls(jabber_stream) => { - conn_state = - Connecting::ConnectionEstablished(jabber_stream.starttls(&server).await?) - } - Connecting::ConnectionEstablished(tls_stream) => { - conn_state = - Connecting::StreamStarted(JabberStream::start_stream(tls_stream, server).await?) - } - Connecting::StreamStarted(jabber_stream) => { - conn_state = Connecting::GotFeatures(jabber_stream.get_features().await?) - } - Connecting::GotFeatures((features, jabber_stream)) => { - match features.negotiate().ok_or(Error::Negotiation)? { - Feature::StartTls(_start_tls) => return Err(Error::AlreadyTls), - Feature::Sasl(mechanisms) => { - conn_state = Connecting::Sasl(mechanisms, jabber_stream) - } - Feature::Bind => conn_state = Connecting::Bind(jabber_stream), - Feature::Unknown => return Err(Error::Unsupported), - } - } - Connecting::Sasl(mechanisms, jabber_stream) => { - conn_state = Connecting::ConnectionEstablished( - jabber_stream.sasl(mechanisms, auth.clone()).await?, - ) - } - Connecting::Bind(jabber_stream) => { - return Ok(jabber_stream.bind(jid).await?.to_bound_jabber()); - } - } - } -} - -pub enum Connecting { - InsecureConnectionEstablised(Unencrypted), - InsecureStreamStarted(JabberStream<Unencrypted>), - InsecureGotFeatures((Features, JabberStream<Unencrypted>)), - StartTls(JabberStream<Unencrypted>), - ConnectionEstablished(Tls), - StreamStarted(JabberStream<Tls>), - GotFeatures((Features, JabberStream<Tls>)), - Sasl(Mechanisms, JabberStream<Tls>), - Bind(JabberStream<Tls>), -} - -impl Connecting { - pub async fn start(server: &str) -> Result<Self> { - match Connection::connect(server).await? { - Connection::Encrypted(tls_stream) => Ok(Connecting::ConnectionEstablished(tls_stream)), - Connection::Unencrypted(tcp_stream) => { - Ok(Connecting::InsecureConnectionEstablised(tcp_stream)) - } - } - } -} - -pub enum InsecureConnecting { - Disconnected, - ConnectionEstablished(Connection), - PreStarttls(JabberStream<Unencrypted>), - PreAuthenticated(JabberStream<Tls>), - Authenticated(Tls), - PreBound(JabberStream<Tls>), - Bound(JabberStream<Tls>), -} +#[cfg(not(target_arch = "wasm32"))] +pub use tcp::connect_and_login; +#[cfg(target_arch = "wasm32")] +pub use ws::connect_and_login; #[cfg(test)] mod tests { @@ -122,20 +23,29 @@ mod tests { use test_log::test; use tokio::time::sleep; use tracing::info; + use wasm_bindgen_test::*; use super::connect_and_login; - #[test(tokio::test)] - async fn login() { + wasm_bindgen_test_configure!(run_in_browser); + + // #[test(tokio::test)] + #[wasm_bindgen_test] + async fn login() -> () { + tracing_wasm::set_as_global_default(); + let mut jid: JID = "test@blos.sm".try_into().unwrap(); - let _client = connect_and_login(&mut jid, "slayed", &mut "blos.sm".to_string()) + let _client = connect_and_login(&mut jid, "slayedaaa", &mut "blos.sm".to_string()) .await .unwrap(); - sleep(Duration::from_secs(5)).await + sleep(Duration::from_secs(5)).await; } - #[test(tokio::test)] + #[wasm_bindgen_test] + // #[test(tokio::test)] async fn ping_parallel() { + tracing_wasm::set_as_global_default(); + let mut jid: JID = "test@blos.sm".try_into().unwrap(); let mut server = "blos.sm".to_string(); let client = connect_and_login(&mut jid, "slayed", &mut server) diff --git a/luz/src/client/tcp.rs b/luz/src/client/tcp.rs new file mode 100644 index 0000000..4e35ef0 --- /dev/null +++ b/luz/src/client/tcp.rs @@ -0,0 +1,110 @@ +use rsasl::config::SASLConfig; +use stanza::{ + sasl::Mechanisms, + stream::{Feature, Features}, +}; + +use crate::{ + connection::Unencrypted, jabber_stream::bound_stream::BoundJabberStream, Connection, Error, + JabberStream, Result, JID, +}; + +pub async fn connect_and_login( + jid: &mut JID, + password: impl AsRef<str>, + server: &mut String, +) -> Result<BoundJabberStream> { + let auth = SASLConfig::with_credentials( + None, + jid.localpart.clone().ok_or(Error::NoLocalpart)?, + password.as_ref().to_string(), + ) + .map_err(|e| Error::SASL(e.into()))?; + let mut conn_state = Connecting::start(&server).await?; + loop { + match conn_state { + Connecting::InsecureConnectionEstablised(tcp_stream) => { + conn_state = Connecting::InsecureStreamStarted( + JabberStream::start_stream(Connection::Unencrypted(tcp_stream), server).await?, + ) + } + Connecting::InsecureStreamStarted(jabber_stream) => { + conn_state = Connecting::InsecureGotFeatures(jabber_stream.get_features().await?) + } + Connecting::InsecureGotFeatures((features, jabber_stream)) => { + match features.negotiate().ok_or(Error::Negotiation)? { + Feature::StartTls(_start_tls) => { + conn_state = Connecting::StartTls(jabber_stream) + } + // TODO: better error + _ => return Err(Error::TlsRequired), + } + } + Connecting::StartTls(jabber_stream) => { + conn_state = Connecting::ConnectionEstablished(Connection::Encrypted( + jabber_stream.starttls(&server).await?, + )) + } + Connecting::ConnectionEstablished(connection) => { + conn_state = + Connecting::StreamStarted(JabberStream::start_stream(connection, server).await?) + } + Connecting::StreamStarted(jabber_stream) => { + conn_state = Connecting::GotFeatures(jabber_stream.get_features().await?) + } + Connecting::GotFeatures((features, jabber_stream)) => { + match features.negotiate().ok_or(Error::Negotiation)? { + Feature::StartTls(_start_tls) => return Err(Error::AlreadyTls), + Feature::Sasl(mechanisms) => { + conn_state = Connecting::Sasl(mechanisms, jabber_stream) + } + Feature::Bind => conn_state = Connecting::Bind(jabber_stream), + Feature::Unknown => return Err(Error::Unsupported), + } + } + Connecting::Sasl(mechanisms, jabber_stream) => { + conn_state = Connecting::ConnectionEstablished( + jabber_stream.sasl(mechanisms, auth.clone()).await?, + ) + } + Connecting::Bind(jabber_stream) => { + return Ok(jabber_stream.bind(jid).await?.to_bound_jabber()); + } + } + } +} + +pub enum Connecting { + InsecureConnectionEstablised(Unencrypted), + InsecureStreamStarted(JabberStream), + InsecureGotFeatures((Features, JabberStream)), + StartTls(JabberStream), + ConnectionEstablished(Connection), + StreamStarted(JabberStream), + GotFeatures((Features, JabberStream)), + Sasl(Mechanisms, JabberStream), + Bind(JabberStream), +} + +impl Connecting { + pub async fn start(server: &str) -> Result<Self> { + match Connection::connect(server).await? { + Connection::Encrypted(tls_stream) => Ok(Connecting::ConnectionEstablished( + Connection::Encrypted(tls_stream), + )), + Connection::Unencrypted(tcp_stream) => { + Ok(Connecting::InsecureConnectionEstablised(tcp_stream)) + } + } + } +} + +// pub enum InsecureConnecting { +// Disconnected, +// ConnectionEstablished(Connection), +// PreStarttls(JabberStream<Unencrypted>), +// PreAuthenticated(JabberStream<Tls>), +// Authenticated(Tls), +// PreBound(JabberStream<Tls>), +// Bound(JabberStream<Tls>), +// } diff --git a/luz/src/client/ws.rs b/luz/src/client/ws.rs new file mode 100644 index 0000000..ecb64cb --- /dev/null +++ b/luz/src/client/ws.rs @@ -0,0 +1,68 @@ +use rsasl::config::SASLConfig; +use stanza::{ + sasl::Mechanisms, + stream::{Feature, Features}, +}; + +use crate::{ + connection::Ws, jabber_stream::bound_stream::BoundJabberStream, Connection, Error, + JabberStream, Result, JID, +}; + +pub async fn connect_and_login( + jid: &mut JID, + password: impl AsRef<str>, + server: &mut String, +) -> Result<BoundJabberStream> { + let auth = SASLConfig::with_credentials( + None, + jid.localpart.clone().ok_or(Error::NoLocalpart)?, + password.as_ref().to_string(), + ) + .map_err(|e| Error::SASL(e.into()))?; + let mut conn_state = Connecting::start(&server).await?; + loop { + match conn_state { + Connecting::ConnectionEstablished(ws) => { + conn_state = + Connecting::StreamStarted(JabberStream::start_stream(ws, server).await?) + } + Connecting::StreamStarted(jabber_stream) => { + conn_state = Connecting::GotFeatures(jabber_stream.get_features().await?) + } + Connecting::GotFeatures((features, jabber_stream)) => { + match features.negotiate().ok_or(Error::Negotiation)? { + Feature::StartTls(_start_tls) => return Err(Error::AlreadyTls), + Feature::Sasl(mechanisms) => { + conn_state = Connecting::Sasl(mechanisms, jabber_stream) + } + Feature::Bind => conn_state = Connecting::Bind(jabber_stream), + Feature::Unknown => return Err(Error::Unsupported), + } + } + Connecting::Sasl(mechanisms, jabber_stream) => { + conn_state = Connecting::ConnectionEstablished( + jabber_stream.sasl(mechanisms, auth.clone()).await?, + ) + } + Connecting::Bind(jabber_stream) => { + return Ok(jabber_stream.bind(jid).await?.to_bound_jabber()); + } + } + } +} + +pub enum Connecting { + ConnectionEstablished(Connection), + StreamStarted(JabberStream), + GotFeatures((Features, JabberStream)), + Sasl(Mechanisms, JabberStream), + Bind(JabberStream), +} + +impl Connecting { + pub async fn start(server: &str) -> Result<Self> { + let connection = Connection::connect(server).await?; + Ok(Connecting::ConnectionEstablished(Connection(connection.0))) + } +} diff --git a/luz/src/connection.rs b/luz/src/connection.rs index b185eca..d0bdadc 100644 --- a/luz/src/connection.rs +++ b/luz/src/connection.rs @@ -1,164 +1,12 @@ -use std::net::{IpAddr, SocketAddr}; -use std::str; -use std::str::FromStr; - -use tokio::net::TcpStream; -use tokio_native_tls::native_tls::TlsConnector; -// TODO: use rustls -use tokio_native_tls::TlsStream; -use tracing::{debug, info, instrument, trace}; - -use crate::Result; -use crate::{Error, JID}; - -pub type Tls = TlsStream<TcpStream>; -pub type Unencrypted = TcpStream; - -#[derive(Debug)] -pub enum Connection { - Encrypted(Tls), - Unencrypted(Unencrypted), -} - -impl Connection { - // #[instrument] - /// stream not started - // pub async fn ensure_tls(self) -> Result<Jabber<Tls>> { - // match self { - // Connection::Encrypted(j) => Ok(j), - // Connection::Unencrypted(mut j) => { - // j.start_stream().await?; - // info!("upgrading connection to tls"); - // j.get_features().await?; - // let j = j.starttls().await?; - // Ok(j) - // } - // } - // } - - pub async fn connect_user(jid: impl AsRef<str>) -> Result<Self> { - let jid: JID = JID::from_str(jid.as_ref())?; - let server = jid.domainpart.clone(); - Self::connect(&server).await - } - - #[instrument] - pub async fn connect(server: impl AsRef<str> + std::fmt::Debug) -> Result<Self> { - info!("connecting to {}", server.as_ref()); - let sockets = Self::get_sockets(server.as_ref()).await; - debug!("discovered sockets: {:?}", sockets); - for (socket_addr, tls) in sockets { - match tls { - true => { - if let Ok(connection) = Self::connect_tls(socket_addr, server.as_ref()).await { - info!("connected via encrypted stream to {}", socket_addr); - // let (readhalf, writehalf) = tokio::io::split(connection); - return Ok(Self::Encrypted(connection)); - } - } - false => { - if let Ok(connection) = Self::connect_unencrypted(socket_addr).await { - info!("connected via unencrypted stream to {}", socket_addr); - // let (readhalf, writehalf) = tokio::io::split(connection); - return Ok(Self::Unencrypted(connection)); - } - } - } - } - Err(Error::Connection) - } - - #[instrument] - async fn get_sockets(address: &str) -> Vec<(SocketAddr, bool)> { - let mut socket_addrs = Vec::new(); - - // if it's a socket/ip then just return that - - // socket - trace!("checking if address is a socket address"); - if let Ok(socket_addr) = SocketAddr::from_str(address) { - debug!("{} is a socket address", address); - match socket_addr.port() { - 5223 => socket_addrs.push((socket_addr, true)), - _ => socket_addrs.push((socket_addr, false)), - } - - return socket_addrs; - } - // ip - trace!("checking if address is an ip"); - if let Ok(ip) = IpAddr::from_str(address) { - debug!("{} is an ip", address); - socket_addrs.push((SocketAddr::new(ip, 5222), false)); - socket_addrs.push((SocketAddr::new(ip, 5223), true)); - return socket_addrs; - } - - // otherwise resolve - debug!("resolving {}", address); - if let Ok(resolver) = trust_dns_resolver::AsyncResolver::tokio_from_system_conf() { - if let Ok(lookup) = resolver - .srv_lookup(format!("_xmpp-client._tcp.{}", address)) - .await - { - for srv in lookup { - resolver - .lookup_ip(srv.target().to_owned()) - .await - .map(|ips| { - for ip in ips { - socket_addrs.push((SocketAddr::new(ip, srv.port()), false)) - } - }); - } - } - if let Ok(lookup) = resolver - .srv_lookup(format!("_xmpps-client._tcp.{}", address)) - .await - { - for srv in lookup { - resolver - .lookup_ip(srv.target().to_owned()) - .await - .map(|ips| { - for ip in ips { - socket_addrs.push((SocketAddr::new(ip, srv.port()), true)) - } - }); - } - } - - // in case cannot connect through SRV records - resolver.lookup_ip(address).await.map(|ips| { - for ip in ips { - socket_addrs.push((SocketAddr::new(ip, 5222), false)); - socket_addrs.push((SocketAddr::new(ip, 5223), true)); - } - }); - } - socket_addrs - } - - /// establishes a connection to the server - #[instrument] - pub async fn connect_tls(socket_addr: SocketAddr, domain_name: &str) -> Result<Tls> { - let socket = TcpStream::connect(socket_addr) - .await - .map_err(|_| Error::Connection)?; - let connector = TlsConnector::new().map_err(|_| Error::Connection)?; - tokio_native_tls::TlsConnector::from(connector) - .connect(domain_name, socket) - .await - .map_err(|_| Error::Connection) - } - - #[instrument] - pub async fn connect_unencrypted(socket_addr: SocketAddr) -> Result<Unencrypted> { - TcpStream::connect(socket_addr) - .await - .map_err(|_| Error::Connection) - } -} +#[cfg(not(target_arch = "wasm32"))] +mod tcp; +#[cfg(target_arch = "wasm32")] +mod ws; + +#[cfg(not(target_arch = "wasm32"))] +pub use tcp::{Connection, Tls, Unencrypted}; +#[cfg(target_arch = "wasm32")] +pub use ws::{Connection, Ws}; #[cfg(test)] mod tests { diff --git a/luz/src/connection/tcp.rs b/luz/src/connection/tcp.rs new file mode 100644 index 0000000..a9e81c3 --- /dev/null +++ b/luz/src/connection/tcp.rs @@ -0,0 +1,220 @@ +use std::net::{IpAddr, SocketAddr}; +use std::pin::pin; +use std::str; +use std::str::FromStr; +use std::task::ready; + +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio::net::TcpStream; +use tokio_native_tls::native_tls::TlsConnector; +// TODO: use rustls +use tokio_native_tls::TlsStream; +use tracing::{debug, info, instrument, trace}; + +use crate::Result; +use crate::{Error, JID}; + +pub type Tls = TlsStream<TcpStream>; +pub type Unencrypted = TcpStream; + +#[derive(Debug)] +pub enum Connection { + Encrypted(Tls), + Unencrypted(Unencrypted), +} + +impl AsyncRead for Connection { + fn poll_read( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> std::task::Poll<std::io::Result<()>> { + match self.get_mut() { + Connection::Encrypted(tls_stream) => ready!(pin!(tls_stream).poll_read(cx, buf)).into(), + Connection::Unencrypted(tcp_stream) => { + ready!(pin!(tcp_stream).poll_read(cx, buf)).into() + } + } + } +} + +impl Unpin for Connection {} + +impl AsyncWrite for Connection { + fn poll_write( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &[u8], + ) -> std::task::Poll<std::result::Result<usize, std::io::Error>> { + match self.get_mut() { + Connection::Encrypted(tls_stream) => { + ready!(pin!(tls_stream).poll_write(cx, buf)).into() + } + Connection::Unencrypted(tcp_stream) => { + ready!(pin!(tcp_stream).poll_write(cx, buf)).into() + } + } + } + + fn poll_flush( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll<std::result::Result<(), std::io::Error>> { + match self.get_mut() { + Connection::Encrypted(tls_stream) => ready!(pin!(tls_stream).poll_flush(cx)).into(), + Connection::Unencrypted(tcp_stream) => ready!(pin!(tcp_stream).poll_flush(cx)).into(), + } + } + + fn poll_shutdown( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll<std::result::Result<(), std::io::Error>> { + match self.get_mut() { + Connection::Encrypted(tls_stream) => ready!(pin!(tls_stream).poll_shutdown(cx)).into(), + Connection::Unencrypted(tcp_stream) => { + ready!(pin!(tcp_stream).poll_shutdown(cx)).into() + } + } + } +} + +impl Connection { + // #[instrument] + /// stream not started + // pub async fn ensure_tls(self) -> Result<Jabber<Tls>> { + // match self { + // Connection::Encrypted(j) => Ok(j), + // Connection::Unencrypted(mut j) => { + // j.start_stream().await?; + // info!("upgrading connection to tls"); + // j.get_features().await?; + // let j = j.starttls().await?; + // Ok(j) + // } + // } + // } + + pub async fn connect_user(jid: impl AsRef<str>) -> Result<Self> { + let jid: JID = JID::from_str(jid.as_ref())?; + let server = jid.domainpart.clone(); + Self::connect(&server).await + } + + #[instrument] + pub async fn connect(server: impl AsRef<str> + std::fmt::Debug) -> Result<Self> { + info!("connecting to {}", server.as_ref()); + let sockets = Self::get_sockets(server.as_ref()).await; + debug!("discovered sockets: {:?}", sockets); + for (socket_addr, tls) in sockets { + match tls { + true => { + if let Ok(connection) = Self::connect_tls(socket_addr, server.as_ref()).await { + info!("connected via encrypted stream to {}", socket_addr); + // let (readhalf, writehalf) = tokio::io::split(connection); + return Ok(Self::Encrypted(connection)); + } + } + false => { + if let Ok(connection) = Self::connect_unencrypted(socket_addr).await { + info!("connected via unencrypted stream to {}", socket_addr); + // let (readhalf, writehalf) = tokio::io::split(connection); + return Ok(Self::Unencrypted(connection)); + } + } + } + } + Err(Error::Connection) + } + + #[instrument] + async fn get_sockets(address: &str) -> Vec<(SocketAddr, bool)> { + let mut socket_addrs = Vec::new(); + + // if it's a socket/ip then just return that + + // socket + trace!("checking if address is a socket address"); + if let Ok(socket_addr) = SocketAddr::from_str(address) { + debug!("{} is a socket address", address); + match socket_addr.port() { + 5223 => socket_addrs.push((socket_addr, true)), + _ => socket_addrs.push((socket_addr, false)), + } + + return socket_addrs; + } + // ip + trace!("checking if address is an ip"); + if let Ok(ip) = IpAddr::from_str(address) { + debug!("{} is an ip", address); + socket_addrs.push((SocketAddr::new(ip, 5222), false)); + socket_addrs.push((SocketAddr::new(ip, 5223), true)); + return socket_addrs; + } + + // otherwise resolve + debug!("resolving {}", address); + if let Ok(resolver) = trust_dns_resolver::AsyncResolver::tokio_from_system_conf() { + if let Ok(lookup) = resolver + .srv_lookup(format!("_xmpp-client._tcp.{}", address)) + .await + { + for srv in lookup { + resolver + .lookup_ip(srv.target().to_owned()) + .await + .map(|ips| { + for ip in ips { + socket_addrs.push((SocketAddr::new(ip, srv.port()), false)) + } + }); + } + } + if let Ok(lookup) = resolver + .srv_lookup(format!("_xmpps-client._tcp.{}", address)) + .await + { + for srv in lookup { + resolver + .lookup_ip(srv.target().to_owned()) + .await + .map(|ips| { + for ip in ips { + socket_addrs.push((SocketAddr::new(ip, srv.port()), true)) + } + }); + } + } + + // in case cannot connect through SRV records + resolver.lookup_ip(address).await.map(|ips| { + for ip in ips { + socket_addrs.push((SocketAddr::new(ip, 5222), false)); + socket_addrs.push((SocketAddr::new(ip, 5223), true)); + } + }); + } + socket_addrs + } + + /// establishes a connection to the server + #[instrument] + pub async fn connect_tls(socket_addr: SocketAddr, domain_name: &str) -> Result<Tls> { + let socket = TcpStream::connect(socket_addr) + .await + .map_err(|_| Error::Connection)?; + let connector = TlsConnector::new().map_err(|_| Error::Connection)?; + tokio_native_tls::TlsConnector::from(connector) + .connect(domain_name, socket) + .await + .map_err(|_| Error::Connection) + } + + #[instrument] + pub async fn connect_unencrypted(socket_addr: SocketAddr) -> Result<Unencrypted> { + TcpStream::connect(socket_addr) + .await + .map_err(|_| Error::Connection) + } +} diff --git a/luz/src/connection/ws.rs b/luz/src/connection/ws.rs new file mode 100644 index 0000000..fc983b1 --- /dev/null +++ b/luz/src/connection/ws.rs @@ -0,0 +1,35 @@ +use tokio::sync::mpsc; +use wasm_bindgen::closure::Closure; +use wasm_bindgen::JsCast; +use web_sys::WebSocket; + +use crate::Result; + +pub type Ws = WebSocket; + +#[derive(Debug)] +pub struct Connection(pub Ws); + +impl Connection { + pub async fn connect(server: impl AsRef<str> + std::fmt::Debug) -> Result<Self> { + // TODO: get the connection url here + let ws = WebSocket::new_with_str("wss://xmpp.bunny.garden/ws", "xmpp").unwrap(); + let (send, mut recv) = mpsc::unbounded_channel(); + let onopen = Closure::<dyn FnMut()>::new(Box::new(move || { + tracing::info!("socket opened"); + let send = send.clone(); + match send.send(()) { + Ok(()) => (), + Err(e) => { + tracing::error!("socket opened notify: {:?}", e); + } + } + }) as Box<dyn FnMut()>); + ws.set_onopen(Some(onopen.as_ref().unchecked_ref())); + + recv.recv().await.unwrap(); + + // TODO: check reply if it's xmpp too + Ok(Self(ws)) + } +} diff --git a/luz/src/error.rs b/luz/src/error.rs index ec60778..6b08c15 100644 --- a/luz/src/error.rs +++ b/luz/src/error.rs @@ -10,6 +10,7 @@ use thiserror::Error; #[derive(Error, Debug, Clone)] pub enum Error { + #[cfg(not(target_arch = "wasm32"))] #[error("connection")] Connection, #[error("utf8 decode: {0}")] diff --git a/luz/src/jabber_stream.rs b/luz/src/jabber_stream.rs index ef21921..4332366 100644 --- a/luz/src/jabber_stream.rs +++ b/luz/src/jabber_stream.rs @@ -1,135 +1,53 @@ +pub mod bound_stream; +#[cfg(not(target_arch = "wasm32"))] +mod tcp; +#[cfg(target_arch = "wasm32")] +mod ws; + use std::str::{self, FromStr}; use std::sync::Arc; use jid::JID; use peanuts::element::IntoElement; +#[cfg(target_arch = "wasm32")] +use peanuts::reader::WebSocketOnMessageRead; use peanuts::{Reader, Writer}; use rsasl::prelude::{Mechname, SASLClient, SASLConfig}; use stanza::bind::{Bind, BindType, FullJidType, ResourceType}; use stanza::client::iq::{Iq, IqType, Query}; use stanza::client::Stanza; +#[cfg(target_arch = "wasm32")] +use stanza::rfc_7395::Open; use stanza::sasl::{Auth, Challenge, Mechanisms, Response, ServerResponse}; use stanza::starttls::{Proceed, StartTls}; use stanza::stream::{Features, Stream}; use stanza::XML_VERSION; -use tokio::io::{AsyncRead, AsyncWrite, ReadHalf, WriteHalf}; +#[cfg(not(target_arch = "wasm32"))] use tokio_native_tls::native_tls::TlsConnector; use tracing::{debug, instrument}; - -use crate::connection::{Tls, Unencrypted}; +#[cfg(target_arch = "wasm32")] +use web_sys::{wasm_bindgen::JsCast, WebSocket}; + +use crate::connection::Connection; +#[cfg(not(target_arch = "wasm32"))] +use crate::connection::Tls; +#[cfg(target_arch = "wasm32")] +use crate::connection::Ws; use crate::error::Error; use crate::Result; -pub mod bound_stream; - -// open stream (streams started) -pub struct JabberStream<S> { - reader: JabberReader<S>, - writer: JabberWriter<S>, -} - -impl<S> JabberStream<S> { - fn split(self) -> (JabberReader<S>, JabberWriter<S>) { - let reader = self.reader; - let writer = self.writer; - (reader, writer) - } -} - -pub struct JabberReader<S>(Reader<ReadHalf<S>>); - -impl<S> JabberReader<S> { - // TODO: consider taking a readhalf and creating peanuts::Reader here, only one inner - fn new(reader: Reader<ReadHalf<S>>) -> Self { - Self(reader) - } - - fn unsplit(self, writer: JabberWriter<S>) -> JabberStream<S> { - JabberStream { - reader: self, - writer, - } - } - - fn into_inner(self) -> Reader<ReadHalf<S>> { - self.0 - } -} - -impl<S> JabberReader<S> -where - S: AsyncRead + Unpin, -{ - pub async fn try_close(&mut self) -> Result<()> { - self.read_end_tag().await?; - Ok(()) - } -} - -impl<S> std::ops::Deref for JabberReader<S> { - type Target = Reader<ReadHalf<S>>; - - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -impl<S> std::ops::DerefMut for JabberReader<S> { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.0 - } -} - -pub struct JabberWriter<S>(Writer<WriteHalf<S>>); - -impl<S> JabberWriter<S> { - fn new(writer: Writer<WriteHalf<S>>) -> Self { - Self(writer) - } - - fn unsplit(self, reader: JabberReader<S>) -> JabberStream<S> { - JabberStream { - reader, - writer: self, - } - } - - fn into_inner(self) -> Writer<WriteHalf<S>> { - self.0 - } -} - -impl<S> JabberWriter<S> -where - S: AsyncWrite + Unpin + Send, -{ - pub async fn try_close(&mut self) -> Result<()> { - self.write_end().await?; - Ok(()) - } -} - -impl<S> std::ops::Deref for JabberWriter<S> { - type Target = Writer<WriteHalf<S>>; - - fn deref(&self) -> &Self::Target { - &self.0 - } -} +#[cfg(not(target_arch = "wasm32"))] +pub use tcp::{JabberReader, JabberStream, JabberWriter}; +#[cfg(target_arch = "wasm32")] +pub use ws::{JabberReader, JabberStream, JabberWriter}; -impl<S> std::ops::DerefMut for JabberWriter<S> { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.0 - } -} - -impl<S> JabberStream<S> -where - S: AsyncRead + AsyncWrite + Unpin + Send + std::fmt::Debug, - JabberStream<S>: std::fmt::Debug, -{ +impl JabberStream { #[instrument] - pub async fn sasl(mut self, mechanisms: Mechanisms, sasl_config: Arc<SASLConfig>) -> Result<S> { + pub async fn sasl( + mut self, + mechanisms: Mechanisms, + sasl_config: Arc<SASLConfig>, + ) -> Result<Connection> { let sasl = SASLClient::new(sasl_config); let mut offered_mechs: Vec<&Mechname> = Vec::new(); for mechanism in &mechanisms.mechanisms { @@ -212,10 +130,18 @@ where } } } - let writer = self.writer.into_inner().into_inner(); - let reader = self.reader.into_inner().into_inner(); - let stream = reader.unsplit(writer); - Ok(stream) + #[cfg(not(target_arch = "wasm32"))] + { + let writer = self.writer.into_inner().into_inner(); + let reader = self.reader.into_inner().into_inner(); + let stream = reader.unsplit(writer); + Ok(stream) + } + #[cfg(target_arch = "wasm32")] + { + let writer = self.writer.into_inner().into_inner(); + Ok(Connection(writer)) + } } #[instrument] @@ -312,8 +238,45 @@ where } } + #[cfg(target_arch = "wasm32")] + #[instrument] + pub async fn start_stream(connection: Connection, server: &mut String) -> Result<Self> { + // client to server + let writer = connection.0; + let (onmessage, reader) = WebSocketOnMessageRead::new(); + writer.set_onmessage(Some(onmessage.as_ref().unchecked_ref())); + onmessage.forget(); + let mut reader = JabberReader::new(Reader::new_unendable(reader)); + let mut writer = JabberWriter::new(Writer::<WebSocket>::new_unendable(writer)); + + // open element + let open = Open { + from: None, + id: None, + lang: None, + version: Some("1.0".to_string()), + to: Some(JID::from_str(server.as_ref())?), + }; + writer.write(&open).await?; + + // server to client + + // may or may not send a declaration + let _decl = reader.read_prolog().await?; + + // receive open element and validate + let open: Open = reader.read().await?; + debug!("got open: {:?}", open); + if let Some(from) = open.from { + *server = from.to_string(); + } + + Ok(Self { reader, writer }) + } + + #[cfg(not(target_arch = "wasm32"))] #[instrument] - pub async fn start_stream(connection: S, server: &mut String) -> Result<Self> { + pub async fn start_stream(connection: Connection, server: &mut String) -> Result<Self> { // client to server let (reader, writer) = tokio::io::split(connection); let mut reader = JabberReader::new(Reader::new(reader)); @@ -354,20 +317,28 @@ where Ok((features, self)) } - pub fn into_inner(self) -> S { - self.reader + pub fn into_inner(self) -> Connection { + let writer = self.writer.into_inner(); + + #[cfg(not(target_arch = "wasm32"))] + let output = self + .reader .into_inner() .into_inner() - .unsplit(self.writer.into_inner().into_inner()) + .unsplit(writer.into_inner()); + + #[cfg(target_arch = "wasm32")] + let output = Connection(writer.into_inner()); + + output } pub async fn send_stanza(&mut self, stanza: &Stanza) -> Result<()> { self.writer.write(stanza).await?; Ok(()) } -} -impl JabberStream<Unencrypted> { + #[cfg(not(target_arch = "wasm32"))] #[instrument] pub async fn starttls(mut self, domain: impl AsRef<str> + std::fmt::Debug) -> Result<Tls> { self.writer @@ -381,34 +352,23 @@ impl JabberStream<Unencrypted> { .into_inner() .into_inner() .unsplit(self.writer.into_inner().into_inner()); - if let Ok(tls_stream) = tokio_native_tls::TlsConnector::from(connector) - .connect(domain.as_ref(), stream) - .await - { - return Ok(tls_stream); - } else { - return Err(Error::Connection); + match stream { + Connection::Encrypted(_tls_stream) => return Err(Error::AlreadyTls), + Connection::Unencrypted(tcp_stream) => { + if let Ok(tls_stream) = tokio_native_tls::TlsConnector::from(connector) + .connect(domain.as_ref(), tcp_stream) + .await + { + return Ok(tls_stream); + } else { + return Err(Error::Connection); + } + } } } } - -impl std::fmt::Debug for JabberStream<Tls> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("Jabber") - .field("connection", &"tls") - .finish() - } -} - -impl std::fmt::Debug for JabberStream<Unencrypted> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("Jabber") - .field("connection", &"unencrypted") - .finish() - } -} - #[cfg(test)] + mod tests { use test_log::test; diff --git a/luz/src/jabber_stream/bound_stream.rs b/luz/src/jabber_stream/bound_stream.rs index 25b79ff..13f8b2d 100644 --- a/luz/src/jabber_stream/bound_stream.rs +++ b/luz/src/jabber_stream/bound_stream.rs @@ -4,84 +4,75 @@ use tokio::io::{AsyncRead, AsyncWrite}; use super::{JabberReader, JabberStream, JabberWriter}; -pub struct BoundJabberStream<S>(JabberStream<S>); +pub struct BoundJabberStream(JabberStream); -impl<S> Deref for BoundJabberStream<S> -where - S: AsyncWrite + AsyncRead + Unpin + Send, -{ - type Target = JabberStream<S>; +impl Deref for BoundJabberStream { + type Target = JabberStream; fn deref(&self) -> &Self::Target { &self.0 } } -impl<S> DerefMut for BoundJabberStream<S> -where - S: AsyncWrite + AsyncRead + Unpin + Send, -{ +impl DerefMut for BoundJabberStream { fn deref_mut(&mut self) -> &mut Self::Target { &mut self.0 } } -impl<S> BoundJabberStream<S> { - pub fn split(self) -> (BoundJabberReader<S>, BoundJabberWriter<S>) { +impl BoundJabberStream { + pub fn split(self) -> (BoundJabberReader, BoundJabberWriter) { let (reader, writer) = self.0.split(); (BoundJabberReader(reader), BoundJabberWriter(writer)) } } -pub struct BoundJabberReader<S>(JabberReader<S>); +pub struct BoundJabberReader(JabberReader); -impl<S> BoundJabberReader<S> { - pub fn unsplit(self, writer: BoundJabberWriter<S>) -> BoundJabberStream<S> { +impl BoundJabberReader { + pub fn unsplit(self, writer: BoundJabberWriter) -> BoundJabberStream { BoundJabberStream(self.0.unsplit(writer.0)) } } -impl<S> std::ops::Deref for BoundJabberReader<S> { - type Target = JabberReader<S>; +impl std::ops::Deref for BoundJabberReader { + type Target = JabberReader; fn deref(&self) -> &Self::Target { &self.0 } } -impl<S> std::ops::DerefMut for BoundJabberReader<S> { +impl std::ops::DerefMut for BoundJabberReader { fn deref_mut(&mut self) -> &mut Self::Target { &mut self.0 } } -pub struct BoundJabberWriter<S>(JabberWriter<S>); +pub struct BoundJabberWriter(JabberWriter); -impl<S> BoundJabberWriter<S> { - pub fn unsplit(self, reader: BoundJabberReader<S>) -> BoundJabberStream<S> { +impl BoundJabberWriter { + pub fn unsplit(self, reader: BoundJabberReader) -> BoundJabberStream { BoundJabberStream(self.0.unsplit(reader.0)) } } -impl<S> std::ops::Deref for BoundJabberWriter<S> { - type Target = JabberWriter<S>; +impl std::ops::Deref for BoundJabberWriter { + type Target = JabberWriter; fn deref(&self) -> &Self::Target { &self.0 } } -impl<S> std::ops::DerefMut for BoundJabberWriter<S> { +impl std::ops::DerefMut for BoundJabberWriter { fn deref_mut(&mut self) -> &mut Self::Target { &mut self.0 } } -impl<S> JabberStream<S> -where - S: AsyncWrite + AsyncRead + Unpin + Send, -{ - pub fn to_bound_jabber(self) -> BoundJabberStream<S> { +impl JabberStream { + pub fn to_bound_jabber(self) -> BoundJabberStream { BoundJabberStream(self) } } diff --git a/luz/src/jabber_stream/tcp.rs b/luz/src/jabber_stream/tcp.rs new file mode 100644 index 0000000..77305e3 --- /dev/null +++ b/luz/src/jabber_stream/tcp.rs @@ -0,0 +1,103 @@ +use peanuts::loggable::Loggable; +use peanuts::{Reader, Writer}; +use tokio::io::{ReadHalf, WriteHalf}; + +use crate::{Connection, Result}; + +// open stream (streams started) +#[derive(Debug)] +pub struct JabberStream { + pub(super) reader: JabberReader, + pub(super) writer: JabberWriter, +} + +impl JabberStream { + pub fn split(self) -> (JabberReader, JabberWriter) { + let reader = self.reader; + let writer = self.writer; + (reader, writer) + } +} + +#[derive(Debug)] +pub struct JabberReader(Reader<ReadHalf<Connection>>); + +impl JabberReader { + // TODO: consider taking a readhalf and creating peanuts::Reader here, only one inner + pub fn new(reader: Reader<ReadHalf<Connection>>) -> Self { + Self(reader) + } + + pub fn unsplit(self, writer: JabberWriter) -> JabberStream { + JabberStream { + reader: self, + writer, + } + } + + pub fn into_inner(self) -> Reader<ReadHalf<Connection>> { + self.0 + } +} + +impl JabberReader { + pub async fn try_close(&mut self) -> Result<()> { + self.read_end_tag().await?; + Ok(()) + } +} + +impl std::ops::Deref for JabberReader { + type Target = Reader<ReadHalf<Connection>>; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl std::ops::DerefMut for JabberReader { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +#[derive(Debug)] +pub struct JabberWriter(Writer<Loggable<WriteHalf<Connection>>>); + +impl JabberWriter { + pub fn new(writer: Writer<Loggable<WriteHalf<Connection>>>) -> Self { + Self(writer) + } + + pub fn unsplit(self, reader: JabberReader) -> JabberStream { + JabberStream { + reader, + writer: self, + } + } + + pub fn into_inner(self) -> Writer<Loggable<WriteHalf<Connection>>> { + self.0 + } +} + +impl JabberWriter { + pub async fn try_close(&mut self) -> Result<()> { + self.write_end().await?; + Ok(()) + } +} + +impl std::ops::Deref for JabberWriter { + type Target = Writer<Loggable<WriteHalf<Connection>>>; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl std::ops::DerefMut for JabberWriter { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} diff --git a/luz/src/jabber_stream/ws.rs b/luz/src/jabber_stream/ws.rs new file mode 100644 index 0000000..35b6f60 --- /dev/null +++ b/luz/src/jabber_stream/ws.rs @@ -0,0 +1,105 @@ +use peanuts::loggable::Loggable; +use peanuts::reader::WebSocketOnMessageRead; +use peanuts::{Reader, Writer}; +use stanza::rfc_7395::Close; +use web_sys::WebSocket; + +use crate::{Connection, Result}; + +// open stream (streams started) +#[derive(Debug)] +pub struct JabberStream { + pub(super) reader: JabberReader, + pub(super) writer: JabberWriter, +} + +impl JabberStream { + pub fn split(self) -> (JabberReader, JabberWriter) { + let reader = self.reader; + let writer = self.writer; + (reader, writer) + } +} + +#[derive(Debug)] +pub struct JabberReader(Reader<WebSocketOnMessageRead>); + +impl JabberReader { + pub fn new(reader: Reader<WebSocketOnMessageRead>) -> Self { + Self(reader) + } + + pub fn unsplit(self, writer: JabberWriter) -> JabberStream { + JabberStream { + reader: self, + writer, + } + } + + pub fn into_inner(self) -> Reader<WebSocketOnMessageRead> { + self.0 + } +} + +impl JabberReader { + pub async fn try_close(&mut self) -> Result<()> { + let close: Close = self.read().await?; + Ok(()) + } +} + +impl std::ops::Deref for JabberReader { + type Target = Reader<WebSocketOnMessageRead>; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl std::ops::DerefMut for JabberReader { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +#[derive(Debug)] +pub struct JabberWriter(Writer<WebSocket>); + +impl JabberWriter { + pub fn new(writer: Writer<WebSocket>) -> Self { + Self(writer) + } + + pub fn unsplit(self, reader: JabberReader) -> JabberStream { + JabberStream { + reader, + writer: self, + } + } + + pub fn into_inner(self) -> Writer<WebSocket> { + self.0 + } +} + +impl JabberWriter { + pub async fn try_close(&mut self) -> Result<()> { + let close = Close::default(); + self.write(&close).await?; + Ok(()) + } +} + +impl std::ops::Deref for JabberWriter { + type Target = Writer<WebSocket>; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl std::ops::DerefMut for JabberWriter { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} |