diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/connection.rs | 15 | ||||
| -rw-r--r-- | src/jabber.rs | 85 | ||||
| -rw-r--r-- | src/lib.rs | 3 | ||||
| -rw-r--r-- | src/stanza/starttls.rs | 162 | ||||
| -rw-r--r-- | src/stanza/stream.rs | 73 | 
5 files changed, 290 insertions, 48 deletions
| diff --git a/src/connection.rs b/src/connection.rs index 89f382f..2b70747 100644 --- a/src/connection.rs +++ b/src/connection.rs @@ -27,8 +27,11 @@ impl Connection {          match self {              Connection::Encrypted(j) => Ok(j),              Connection::Unencrypted(mut j) => { +                j.start_stream().await?;                  info!("upgrading connection to tls"); -                Ok(j.starttls().await?) +                j.get_features().await?; +                let j = j.starttls().await?; +                Ok(j)              }          }      } @@ -179,4 +182,14 @@ mod tests {      async fn connect() {          Connection::connect("blos.sm").await.unwrap();      } + +    #[test(tokio::test)] +    async fn test_tls() { +        Connection::connect("blos.sm") +            .await +            .unwrap() +            .ensure_tls() +            .await +            .unwrap(); +    }  } diff --git a/src/jabber.rs b/src/jabber.rs index afe840b..87a2b44 100644 --- a/src/jabber.rs +++ b/src/jabber.rs @@ -1,14 +1,18 @@  use std::str;  use std::sync::Arc; +use peanuts::element::{FromElement, IntoElement};  use peanuts::{Reader, Writer};  use rsasl::prelude::SASLConfig;  use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufReader, ReadHalf, WriteHalf}; +use tokio_native_tls::native_tls::TlsConnector;  use tracing::{debug, info, trace}; +use trust_dns_resolver::proto::rr::domain::IntoLabel;  use crate::connection::{Tls, Unencrypted};  use crate::error::Error; -use crate::stanza::stream::Stream; +use crate::stanza::starttls::{Proceed, StartTls}; +use crate::stanza::stream::{Features, Stream};  use crate::stanza::XML_VERSION;  use crate::Result;  use crate::JID; @@ -62,7 +66,6 @@ where          // opening stream element          let server = self.server.clone().try_into()?;          let stream = Stream::new_client(None, server, None, "en".to_string()); -        // TODO: nicer function to serialize to xml writer          self.writer.write_start(&stream).await?;          // server to client @@ -72,57 +75,53 @@ where          // receive stream element and validate          let stream: Stream = self.reader.read_start().await?; +        debug!("got stream: {:?}", stream);          if let Some(from) = stream.from {              self.server = from.to_string()          }          Ok(())      } -} -// pub async fn get_features(&mut self) -> Result<Vec<StreamFeature>> { -//     Element::read(&mut self.reader).await?.try_into() -// } +    pub async fn get_features(&mut self) -> Result<Features> { +        debug!("getting features"); +        let features: Features = self.reader.read().await?; +        debug!("got features: {:?}", features); +        Ok(features) +    } + +    pub fn into_inner(self) -> S { +        self.reader.into_inner().unsplit(self.writer.into_inner()) +    } +}  impl Jabber<Unencrypted> { -    pub async fn starttls(&mut self) -> Result<Jabber<Tls>> { -        todo!() +    pub async fn starttls(mut self) -> Result<Jabber<Tls>> { +        self.writer +            .write_full(&StartTls { required: false }) +            .await?; +        let proceed: Proceed = self.reader.read().await?; +        debug!("got proceed: {:?}", proceed); +        let connector = TlsConnector::new().unwrap(); +        let stream = self.reader.into_inner().unsplit(self.writer.into_inner()); +        if let Ok(tlsstream) = tokio_native_tls::TlsConnector::from(connector) +            .connect(&self.server, stream) +            .await +        { +            let (read, write) = tokio::io::split(tlsstream); +            let mut client = Jabber::new( +                read, +                write, +                self.jid.to_owned(), +                self.auth.to_owned(), +                self.server.to_owned(), +            ); +            client.start_stream().await?; +            return Ok(client); +        } else { +            return Err(Error::Connection); +        }      } -    //     let mut starttls_element = BytesStart::new("starttls"); -    //     starttls_element.push_attribute(("xmlns", "urn:ietf:params:xml:ns:xmpp-tls")); -    //     self.writer -    //         .write_event_async(Event::Empty(starttls_element)) -    //         .await -    //         .unwrap(); -    //     let mut buf = Vec::new(); -    //     match self.reader.read_event_into_async(&mut buf).await.unwrap() { -    //         Event::Empty(e) => match e.name() { -    //             QName(b"proceed") => { -    //                 let connector = TlsConnector::new().unwrap(); -    //                 let stream = self -    //                     .reader -    //                     .into_inner() -    //                     .into_inner() -    //                     .unsplit(self.writer.into_inner()); -    //                 if let Ok(tlsstream) = tokio_native_tls::TlsConnector::from(connector) -    //                     .connect(&self.jabber.server, stream) -    //                     .await -    //                 { -    //                     let (read, write) = tokio::io::split(tlsstream); -    //                     let reader = Reader::from_reader(BufReader::new(read)); -    //                     let writer = Writer::new(write); -    //                     let mut client = -    //                         super::encrypted::JabberClient::new(reader, writer, self.jabber); -    //                     client.start_stream().await?; -    //                     return Ok(client); -    //                 } -    //             } -    //             QName(_) => return Err(JabberError::TlsNegotiation), -    //         }, -    //         _ => return Err(JabberError::TlsNegotiation), -    //     } -    //     Err(JabberError::TlsNegotiation) -    // }  }  impl std::fmt::Debug for Jabber<Tls> { @@ -8,9 +8,6 @@ pub mod jabber;  pub mod jid;  pub mod stanza; -#[macro_use] -extern crate lazy_static; -  pub use connection::Connection;  pub use error::Error;  pub use jabber::Jabber; diff --git a/src/stanza/starttls.rs b/src/stanza/starttls.rs index 8b13789..874ae66 100644 --- a/src/stanza/starttls.rs +++ b/src/stanza/starttls.rs @@ -1 +1,163 @@ +use std::collections::{HashMap, HashSet}; +use peanuts::{ +    element::{Content, FromElement, IntoElement, Name, NamespaceDeclaration}, +    Element, +}; + +pub const XMLNS: &str = "urn:ietf:params:xml:ns:xmpp-tls"; + +#[derive(Debug)] +pub struct StartTls { +    pub required: bool, +} + +impl IntoElement for StartTls { +    fn into_element(&self) -> peanuts::Element { +        let content; +        if self.required == true { +            let element = Content::Element(Element { +                name: Name { +                    namespace: Some(XMLNS.to_string()), +                    local_name: "required".to_string(), +                }, +                namespace_declarations: HashSet::new(), +                attributes: HashMap::new(), +                content: Vec::new(), +            }); +            content = vec![element]; +        } else { +            content = Vec::new(); +        } +        let mut namespace_declarations = HashSet::new(); +        namespace_declarations.insert(NamespaceDeclaration { +            prefix: None, +            namespace: XMLNS.to_string(), +        }); +        Element { +            name: Name { +                namespace: Some(XMLNS.to_string()), +                local_name: "starttls".to_string(), +            }, +            namespace_declarations, +            attributes: HashMap::new(), +            content, +        } +    } +} + +impl FromElement for StartTls { +    fn from_element(element: peanuts::Element) -> peanuts::Result<Self> { +        let Name { +            namespace, +            local_name, +        } = element.name; +        if namespace.as_deref() == Some(XMLNS) && &local_name == "starttls" { +            let mut required = false; +            if element.content.len() == 1 { +                match element.content.first().unwrap() { +                    Content::Element(element) => { +                        let Name { +                            namespace, +                            local_name, +                        } = &element.name; + +                        if namespace.as_deref() == Some(XMLNS) && local_name == "required" { +                            required = true +                        } else { +                            return Err(peanuts::Error::UnexpectedElement(element.name.clone())); +                        } +                    } +                    c => return Err(peanuts::Error::UnexpectedContent((*c).clone())), +                } +            } else { +                return Err(peanuts::Error::UnexpectedNumberOfContents( +                    element.content.len(), +                )); +            } +            return Ok(StartTls { required }); +        } else { +            return Err(peanuts::Error::IncorrectName(Name { +                namespace, +                local_name, +            })); +        } +    } +} + +#[derive(Debug)] +pub struct Proceed; + +impl IntoElement for Proceed { +    fn into_element(&self) -> Element { +        let mut namespace_declarations = HashSet::new(); +        namespace_declarations.insert(NamespaceDeclaration { +            prefix: None, +            namespace: XMLNS.to_string(), +        }); +        Element { +            name: Name { +                namespace: Some(XMLNS.to_string()), +                local_name: "proceed".to_string(), +            }, +            namespace_declarations, +            attributes: HashMap::new(), +            content: Vec::new(), +        } +    } +} + +impl FromElement for Proceed { +    fn from_element(element: Element) -> peanuts::Result<Self> { +        let Name { +            namespace, +            local_name, +        } = element.name; +        if namespace.as_deref() == Some(XMLNS) && &local_name == "proceed" { +            return Ok(Proceed); +        } else { +            return Err(peanuts::Error::IncorrectName(Name { +                namespace, +                local_name, +            })); +        } +    } +} + +pub struct Failure; + +impl IntoElement for Failure { +    fn into_element(&self) -> Element { +        let mut namespace_declarations = HashSet::new(); +        namespace_declarations.insert(NamespaceDeclaration { +            prefix: None, +            namespace: XMLNS.to_string(), +        }); +        Element { +            name: Name { +                namespace: Some(XMLNS.to_string()), +                local_name: "failure".to_string(), +            }, +            namespace_declarations, +            attributes: HashMap::new(), +            content: Vec::new(), +        } +    } +} + +impl FromElement for Failure { +    fn from_element(element: Element) -> peanuts::Result<Self> { +        let Name { +            namespace, +            local_name, +        } = element.name; +        if namespace.as_deref() == Some(XMLNS) && &local_name == "failure" { +            return Ok(Failure); +        } else { +            return Err(peanuts::Error::IncorrectName(Name { +                namespace, +                local_name, +            })); +        } +    } +} diff --git a/src/stanza/stream.rs b/src/stanza/stream.rs index ac4badc..4516682 100644 --- a/src/stanza/stream.rs +++ b/src/stanza/stream.rs @@ -6,12 +6,15 @@ use peanuts::{element::Name, Element};  use crate::{Error, JID}; +use super::starttls::StartTls; +  pub const XMLNS: &str = "http://etherx.jabber.org/streams";  pub const XMLNS_CLIENT: &str = "jabber:client";  // MUST be qualified by stream namespace  // #[derive(XmlSerialize, XmlDeserialize)]  // #[peanuts(xmlns = XMLNS)] +#[derive(Debug)]  pub struct Stream {      pub from: Option<JID>,      to: Option<JID>, @@ -93,7 +96,7 @@ impl IntoElement for Stream {              attributes.insert(                  Name {                      namespace: None, -                    local_name: "version".to_string(), +                    local_name: "id".to_string(),                  },                  id.clone(),              ); @@ -158,3 +161,71 @@ impl<'s> Stream {          }      }  } + +#[derive(Debug)] +pub struct Features { +    features: Vec<Feature>, +} + +impl IntoElement for Features { +    fn into_element(&self) -> Element { +        let mut content = Vec::new(); +        for feature in &self.features { +            match feature { +                Feature::StartTls(start_tls) => { +                    content.push(Content::Element(start_tls.into_element())) +                } +                Feature::Sasl => {} +                Feature::Bind => {} +                Feature::Unknown => {} +            } +        } +        Element { +            name: Name { +                namespace: Some(XMLNS.to_string()), +                local_name: "features".to_string(), +            }, +            namespace_declarations: HashSet::new(), +            attributes: HashMap::new(), +            content, +        } +    } +} + +impl FromElement for Features { +    fn from_element(element: Element) -> peanuts::Result<Self> { +        let Name { +            namespace, +            local_name, +        } = element.name; +        if namespace.as_deref() == Some(XMLNS) && &local_name == "features" { +            let mut features = Vec::new(); +            for feature in element.content { +                match feature { +                    Content::Element(element) => { +                        if let Ok(start_tls) = FromElement::from_element(element) { +                            features.push(Feature::StartTls(start_tls)) +                        } else { +                            features.push(Feature::Unknown) +                        } +                    } +                    c => return Err(peanuts::Error::UnexpectedContent(c.clone())), +                } +            } +            return Ok(Self { features }); +        } else { +            return Err(peanuts::Error::IncorrectName(Name { +                namespace, +                local_name, +            })); +        } +    } +} + +#[derive(Debug)] +pub enum Feature { +    StartTls(StartTls), +    Sasl, +    Bind, +    Unknown, +} | 
