diff options
author | 2025-04-13 10:38:58 +0100 | |
---|---|---|
committer | 2025-04-13 10:38:58 +0100 | |
commit | 1c55b582a8e874b5dc776950c92a1ecf622d03fe (patch) | |
tree | cc3f21fa0a8d1fa6ea32ab0c3fd391c14c3ba29f /src/reader.rs | |
parent | c658ab440f8e69ac406b18732dbf276c084926b6 (diff) | |
download | peanuts-1c55b582a8e874b5dc776950c92a1ecf622d03fe.tar.gz peanuts-1c55b582a8e874b5dc776950c92a1ecf622d03fe.tar.bz2 peanuts-1c55b582a8e874b5dc776950c92a1ecf622d03fe.zip |
feat: websocket-framed reader and writer
Diffstat (limited to 'src/reader.rs')
-rw-r--r-- | src/reader.rs | 184 |
1 files changed, 168 insertions, 16 deletions
diff --git a/src/reader.rs b/src/reader.rs index c4d85f7..0cca93f 100644 --- a/src/reader.rs +++ b/src/reader.rs @@ -1,12 +1,25 @@ +#[cfg(target_arch = "wasm32")] +use std::io::Write; + use circular::Buffer; +#[cfg(target_arch = "wasm32")] +use js_sys::{ArrayBuffer, Uint8Array}; use nom::Err; use std::{ collections::{HashMap, HashSet, VecDeque}, str, }; use tokio::io::{AsyncRead, AsyncReadExt}; +#[cfg(target_arch = "wasm32")] +use tokio::sync::mpsc; use tracing::{debug, info, trace}; +#[cfg(target_arch = "wasm32")] +use wasm_bindgen::{closure::Closure, JsCast}; +#[cfg(target_arch = "wasm32")] +use web_sys::{Blob, MessageEvent}; +#[cfg(target_arch = "wasm32")] +use crate::error::WebsocketError; use crate::{ declaration::{Declaration, VersionInfo}, element::{Content, Element, FromElement, Name, NamespaceDeclaration}, @@ -26,9 +39,124 @@ pub struct Reader<R> { // to have names reference namespaces could depth: Vec<Name>, namespace_declarations: Vec<HashSet<NamespaceDeclaration>>, + unendable: bool, root_ended: bool, } +/// Represents a WebSocket Message, after converting from JavaScript type. +/// from https://github.com/najamelan/ws_stream_wasm/blob/dev/src/ws_message.rs +#[cfg(target_arch = "wasm32")] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum WsMessage { + /// The data of the message is a string. + /// + Text(String), + + /// The message contains binary data. + /// + Binary(Vec<u8>), +} + +/// This will convert the JavaScript event into a WsMessage. Note that this +/// will only work if the connection is set to use the binary type ArrayBuffer. +/// On binary type Blob, this will panic. +/// from https://github.com/najamelan/ws_stream_wasm/blob/dev/src/ws_message.rs +#[cfg(target_arch = "wasm32")] +impl TryFrom<MessageEvent> for WsMessage { + type Error = WebsocketError; + + fn try_from(evt: MessageEvent) -> std::result::Result<Self, Self::Error> { + match evt.data() { + d if d.is_instance_of::<ArrayBuffer>() => { + let buffy = Uint8Array::new(d.unchecked_ref()); + let mut v = vec![0; buffy.length() as usize]; + + buffy.copy_to(&mut v); // FIXME: get rid of this copy + + Ok(WsMessage::Binary(v)) + } + + // We don't allow invalid encodings. In principle if needed, + // we could add a variant to WsMessage with a CString or an OsString + // to allow the user to access this data. However until there is a usecase, + // I'm not inclined, amongst other things because the conversion from Js isn't very + // clear and it would require a bunch of testing for something that's a rather bad + // idea to begin with. If you need data that is not a valid string, use a binary + // message. + // + d if d.is_string() => match d.as_string() { + Some(text) => Ok(WsMessage::Text(text)), + None => Err(WebsocketError::InvalidEncoding), + }, + + // We have set the binary mode to array buffer (WsMeta::connect), so normally this shouldn't happen. + // That is as long as this is used within the context of the WsMeta constructor. + // + d if d.is_instance_of::<Blob>() => Err(WebsocketError::CantDecodeBlob), + + // should never happen. + // + _ => Err(WebsocketError::UnknownDataType), + } + } +} + +#[cfg(target_arch = "wasm32")] +pub struct WebSocketOnMessageRead { + queue: mpsc::UnboundedReceiver<WsMessage>, +} + +#[cfg(target_arch = "wasm32")] +impl WebSocketOnMessageRead { + pub fn new() -> (Closure<dyn FnMut(MessageEvent)>, Self) { + let (send, recv) = mpsc::unbounded_channel(); + let on_msg = Closure::wrap(Box::new(move |msg_evt: MessageEvent| { + let msg_evt = msg_evt.try_into(); + match msg_evt { + Ok(msg_evt) => match send.send(msg_evt) { + Ok(()) => {} + Err(e) => { + tracing::error!("message event send error: {:?}", e); + } + }, + Err(e) => { + tracing::error!("websocket receive error: {}", e); + } + } + }) as Box<dyn FnMut(MessageEvent)>); + + (on_msg, Self { queue: recv }) + } +} + +#[cfg(target_arch = "wasm32")] +impl Readable for WebSocketOnMessageRead { + async fn read_buf(&mut self, buffer: &mut Buffer) -> Result<usize> { + let mut queue = Vec::new(); + self.queue.recv_many(&mut queue, 10).await; + let mut bytes = 0; + for msg in queue { + match msg { + WsMessage::Text(s) => { + let text = s.as_bytes(); + bytes += buffer.write(text)?; + } + WsMessage::Binary(v) => { + bytes += buffer.write(&v)?; + } + } + } + Ok(bytes) + } +} + +pub trait Readable { + fn read_buf( + &mut self, + buffer: &mut Buffer, + ) -> impl std::future::Future<Output = Result<usize>> + Send; +} + impl<R> Reader<R> { pub fn new(reader: R) -> Self { let mut default_declarations = HashSet::new(); @@ -46,6 +174,28 @@ impl<R> Reader<R> { depth: Vec::new(), // TODO: make sure reserved namespaces are never overwritten namespace_declarations: vec![default_declarations], + unendable: false, + root_ended: false, + } + } + + pub fn new_unendable(reader: R) -> Self { + let mut default_declarations = HashSet::new(); + default_declarations.insert(NamespaceDeclaration { + prefix: Some("xml".to_string()), + namespace: XML_NS.to_string(), + }); + default_declarations.insert(NamespaceDeclaration { + prefix: Some("xmlns".to_string()), + namespace: XMLNS_NS.to_string(), + }); + Self { + inner: reader, + buffer: Buffer::with_capacity(MAX_STANZA_SIZE), + depth: Vec::new(), + // TODO: make sure reserved namespaces are never overwritten + namespace_declarations: vec![default_declarations], + unendable: true, root_ended: false, } } @@ -55,16 +205,18 @@ impl<R> Reader<R> { } } -impl<R> Reader<R> +impl<R> Readable for R where - R: AsyncRead + Unpin, + R: AsyncRead + Unpin + Send, { - pub async fn read_buf<'s>(&mut self) -> Result<usize> { - Ok(self.inner.read_buf(&mut self.buffer).await?) + async fn read_buf(&mut self, buffer: &mut Buffer) -> Result<usize> { + Ok(tokio::io::AsyncReadExt::read_buf(self, buffer).await?) } +} +impl<R: Readable> Reader<R> { pub async fn read_prolog<'s>(&'s mut self) -> Result<Option<Declaration>> { - if self.root_ended { + if !self.unendable && self.root_ended { return Err(Error::RootElementEnded); } loop { @@ -104,7 +256,7 @@ where } std::result::Result::Err(e) => match e { Err::Incomplete(_) => { - self.read_buf().await?; + self.inner.read_buf(&mut self.buffer).await?; } // TODO: better error Err::Error(e) => { @@ -131,7 +283,7 @@ where } pub async fn read_start_tag<'s>(&'s mut self) -> Result<Element> { - if self.root_ended { + if !self.unendable && self.root_ended { return Err(Error::RootElementEnded); } loop { @@ -151,7 +303,7 @@ where } std::result::Result::Err(e) => match e { Err::Incomplete(_) => { - self.read_buf().await?; + self.inner.read_buf(&mut self.buffer).await?; } // TODO: better error Err::Error(e) => { @@ -166,7 +318,7 @@ where } pub async fn read_end_tag<'s>(&'s mut self) -> Result<()> { - if self.root_ended { + if !self.unendable && self.root_ended { return Err(Error::RootElementEnded); } loop { @@ -189,7 +341,7 @@ where } std::result::Result::Err(e) => match e { Err::Incomplete(_) => { - self.read_buf().await?; + self.inner.read_buf(&mut self.buffer).await?; } // TODO: better error Err::Error(e) => { @@ -204,7 +356,7 @@ where } pub async fn read_element<'s>(&'s mut self) -> Result<Element> { - if self.root_ended { + if !self.unendable && self.root_ended { return Err(Error::RootElementEnded); } loop { @@ -224,7 +376,7 @@ where } std::result::Result::Err(e) => match e { Err::Incomplete(_) => { - self.read_buf().await?; + self.inner.read_buf(&mut self.buffer).await?; } // TODO: better error Err::Error(e) => { @@ -239,7 +391,7 @@ where } pub async fn read_content<'s>(&'s mut self) -> Result<Content> { - if self.root_ended { + if !self.unendable && self.root_ended { return Err(Error::RootElementEnded); } let mut last_char = false; @@ -256,7 +408,7 @@ where } std::result::Result::Err(e) => match e { Err::Incomplete(_) => { - self.read_buf().await?; + self.inner.read_buf(&mut self.buffer).await?; } _ => match xml::ContentItem::parse(input) { Ok((rest, content_item)) => match content_item { @@ -313,7 +465,7 @@ where }, std::result::Result::Err(e) => match e { Err::Incomplete(_) => { - self.read_buf().await?; + self.inner.read_buf(&mut self.buffer).await?; } // TODO: better error Err::Error(e) => { @@ -378,7 +530,7 @@ where }, std::result::Result::Err(e) => match e { Err::Incomplete(_) => { - self.read_buf().await?; + self.inner.read_buf(&mut self.buffer).await?; } // TODO: better error Err::Error(e) => { |