use std::str; use std::sync::Arc; use quick_xml::{events::Event, se::Serializer, NsReader, Writer}; use rsasl::prelude::SASLConfig; use serde::Serialize; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufReader, ReadHalf, WriteHalf}; use tracing::{debug, info, trace}; use crate::connection::{Tls, Unencrypted}; use crate::error::JabberError; use crate::stanza::stream::Stream; use crate::stanza::DECLARATION; use crate::Result; use crate::JID; pub struct Jabber where S: AsyncRead + AsyncWrite + Unpin, { reader: NsReader>>, writer: WriteHalf, jid: Option, auth: Option>, server: String, } impl Jabber where S: AsyncRead + AsyncWrite + Unpin, { pub fn new( reader: ReadHalf, writer: WriteHalf, jid: Option, auth: Option>, server: String, ) -> Self { let reader = NsReader::from_reader(BufReader::new(reader)); Self { reader, writer, jid, auth, server, } } } impl Jabber where S: AsyncRead + AsyncWrite + Unpin, { // pub async fn negotiate(self) -> Result> {} pub async fn start_stream(&mut self) -> Result<()> { // client to server // declaration let mut xmlwriter = Writer::new(&mut self.writer); xmlwriter.write_event_async(DECLARATION.clone()).await?; // opening stream element let server = &self.server.to_owned().try_into()?; let stream_element = Stream::new_client(None, server, None, "en"); // TODO: nicer function to serialize to xml writer let mut buffer = String::new(); let ser = Serializer::with_root(&mut buffer, Some("stream:stream")).expect("stream name"); stream_element.serialize(ser).unwrap(); trace!("sent: {}", buffer); self.writer.write_all(buffer.as_bytes()).await.unwrap(); // server to client // may or may not send a declaration let mut buf = Vec::new(); let mut first_event = self.reader.read_resolved_event_into_async(&mut buf).await?; trace!("received: {:?}", first_event); match first_event { (quick_xml::name::ResolveResult::Unbound, Event::Decl(e)) => { if let Ok(version) = e.version() { if version.as_ref() == b"1.0" { first_event = self.reader.read_resolved_event_into_async(&mut buf).await?; trace!("received: {:?}", first_event); } else { // todo: error todo!() } } else { first_event = self.reader.read_resolved_event_into_async(&mut buf).await?; trace!("received: {:?}", first_event); } } _ => (), } // receive stream element and validate match first_event { (quick_xml::name::ResolveResult::Bound(ns), Event::Start(e)) => { if ns.0 == crate::stanza::stream::XMLNS.as_bytes() { e.attributes().try_for_each(|attr| -> Result<()> { let attr = attr?; match attr.key.into_inner() { b"from" => { self.server = str::from_utf8(&attr.value)?.to_owned(); Ok(()) } _ => Ok(()), } }); return Ok(()); } else { return Err(JabberError::BadStream); } } // TODO: errors for incorrect namespace _ => Err(JabberError::BadStream), } } } // pub async fn get_features(&mut self) -> Result> { // Element::read(&mut self.reader).await?.try_into() // } impl Jabber { pub async fn starttls(&mut self) -> Result> { todo!() } // let mut starttls_element = BytesStart::new("starttls"); // starttls_element.push_attribute(("xmlns", "urn:ietf:params:xml:ns:xmpp-tls")); // self.writer // .write_event_async(Event::Empty(starttls_element)) // .await // .unwrap(); // let mut buf = Vec::new(); // match self.reader.read_event_into_async(&mut buf).await.unwrap() { // Event::Empty(e) => match e.name() { // QName(b"proceed") => { // let connector = TlsConnector::new().unwrap(); // let stream = self // .reader // .into_inner() // .into_inner() // .unsplit(self.writer.into_inner()); // if let Ok(tlsstream) = tokio_native_tls::TlsConnector::from(connector) // .connect(&self.jabber.server, stream) // .await // { // let (read, write) = tokio::io::split(tlsstream); // let reader = Reader::from_reader(BufReader::new(read)); // let writer = Writer::new(write); // let mut client = // super::encrypted::JabberClient::new(reader, writer, self.jabber); // client.start_stream().await?; // return Ok(client); // } // } // QName(_) => return Err(JabberError::TlsNegotiation), // }, // _ => return Err(JabberError::TlsNegotiation), // } // Err(JabberError::TlsNegotiation) // } } impl std::fmt::Debug for Jabber { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Jabber") .field("connection", &"tls") .field("jid", &self.jid) .field("auth", &self.auth) .field("server", &self.server) .finish() } } impl std::fmt::Debug for Jabber { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Jabber") .field("connection", &"unencrypted") .field("jid", &self.jid) .field("auth", &self.auth) .field("server", &self.server) .finish() } } #[cfg(test)] mod tests { use super::*; use crate::connection::Connection; use test_log::test; #[test(tokio::test)] async fn start_stream() { let connection = Connection::connect("blos.sm").await.unwrap(); match connection { Connection::Encrypted(mut c) => c.start_stream().await.unwrap(), Connection::Unencrypted(mut c) => c.start_stream().await.unwrap(), } } }