use std::fmt::{self, Display, Formatter};
use std::io::{self, prelude::*};
use std::str;

/// A displayable value tagged with whether it may be written out verbatim.
///
/// `Safe` values are formatted as-is; `Unsafe` values are HTML-escaped while
/// they are being formatted.
#[derive(Debug, PartialEq)]
pub enum MarkupDisplay<T>
where
    T: Display,
{
    /// The wrapped value is written without any escaping.
    Safe(T),
    /// The wrapped value is HTML-escaped when written.
    Unsafe(T),
}

impl<T> MarkupDisplay<T>
where
    T: Display,
{
    /// Returns the same value marked as `Safe`, so it will no longer be
    /// escaped when displayed. A value that is already `Safe` is unchanged.
    pub fn mark_safe(self) -> MarkupDisplay<T> {
        match self {
            MarkupDisplay::Unsafe(t) | MarkupDisplay::Safe(t) => MarkupDisplay::Safe(t),
        }
    }
}

impl<T> From<T> for MarkupDisplay<T>
where
    T: Display,
{
    /// Wraps a value as `Unsafe`: anything not explicitly marked safe gets
    /// escaped on output.
    fn from(t: T) -> MarkupDisplay<T> {
        MarkupDisplay::Unsafe(t)
    }
}

impl<T> Display for MarkupDisplay<T>
where
    T: Display,
{
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        match self {
            MarkupDisplay::Unsafe(t) => {
                // Route the inner value's output through `EscapeWriter` so that
                // every chunk it produces is escaped before reaching `f`.
                let mut w = EscapeWriter { fmt: f };
                write!(w, "{}", t).map_err(|_e| fmt::Error)
            }
            MarkupDisplay::Safe(t) => t.fmt(f),
        }
    }
}

/// An `io::Write` adapter that HTML-escapes everything written to it and
/// forwards the result to the wrapped formatter.
///
/// The field is private and no constructor is exposed here, so instances are
/// only created by `MarkupDisplay`'s `Display` implementation.
pub struct EscapeWriter<'a, 'b: 'a> {
    fmt: &'a mut fmt::Formatter<'b>,
}

impl io::Write for EscapeWriter<'_, '_> {
    /// Escapes `bytes` into the wrapped formatter and reports the whole input
    /// as consumed.
    ///
    /// `bytes` must be valid UTF-8: `Escaped` reinterprets it as `str` without
    /// checking.
    ///
    /// # Errors
    ///
    /// A formatter error is converted into an `io::Error` of kind `Other`.
    fn write(&mut self, bytes: &[u8]) -> io::Result<usize> {
        let escaped = Escaped { bytes };
        escaped
            .fmt(self.fmt)
            .map_err(|e| io::Error::new(io::ErrorKind::Other, e.to_string()))?;
        Ok(bytes.len())
    }

    /// Nothing is buffered, so flushing is a no-op.
    fn flush(&mut self) -> io::Result<()> {
        Ok(())
    }
}

/// Returns a `Display`-able view of `s` with the HTML-significant characters
/// `<`, `>`, `&`, `"`, `'` and `/` replaced by entities. Nothing is allocated
/// until the result is formatted.
pub fn escape(s: &str) -> Escaped<'_> {
    Escaped {
        bytes: s.as_bytes(),
    }
}

// Shared body of each match arm in `Escaped::fmt`: write the pending
// unescaped run `bytes[start..i]`, then the replacement entity, then move
// `start` just past the byte that was replaced.
macro_rules! escaping_body {
    ($start:ident, $i:ident, $fmt:ident, $_self:ident, $quote:expr) => {{
        if $start < $i {
            // SAFETY: `bytes` is valid UTF-8 (see `Escaped`), and both `$start`
            // and `$i` are either 0 or sit next to an ASCII byte, so the
            // sub-slice begins and ends on character boundaries.
            $fmt.write_str(unsafe { str::from_utf8_unchecked(&$_self.bytes[$start..$i]) })?;
        }
        $fmt.write_str($quote)?;
        $start = $i + 1;
    }};
}

/// A borrowed byte string that is HTML-escaped when displayed.
///
/// Invariant: `bytes` must be valid UTF-8. It is filled either from a `&str`
/// (`escape`) or from the bytes handed to `EscapeWriter::write`.
pub struct Escaped<'a> {
    bytes: &'a [u8],
}

impl Display for Escaped<'_> {
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Start of the run of bytes that has not been written out yet.
        let mut start = 0;
        for (i, b) in self.bytes.iter().enumerate() {
            // Cheap pre-filter: every byte that needs escaping lies in the
            // range b'"'..=b'>', so one wrapping subtraction and comparison
            // rejects most input bytes before the match.
            if b.wrapping_sub(b'"') <= FLAG {
                match *b {
                    b'<' => escaping_body!(start, i, fmt, self, "&lt;"),
                    b'>' => escaping_body!(start, i, fmt, self, "&gt;"),
                    b'&' => escaping_body!(start, i, fmt, self, "&amp;"),
                    b'"' => escaping_body!(start, i, fmt, self, "&quot;"),
                    b'\'' => escaping_body!(start, i, fmt, self, "&#x27;"),
                    b'/' => escaping_body!(start, i, fmt, self, "&#x2f;"),
                    _ => (),
                }
            }
        }
        // SAFETY: `bytes` is valid UTF-8 and `start` is either 0 or one past an
        // ASCII byte, so the remaining tail begins on a character boundary.
        fmt.write_str(unsafe { str::from_utf8_unchecked(&self.bytes[start..]) })?;
        Ok(())
    }
}

/// Width of the byte range `b'"'..=b'>'` that contains every escaped byte.
const FLAG: u8 = b'>' - b'"';

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_escape() {
        assert_eq!(escape("").to_string(), "");
        assert_eq!(escape("<&>").to_string(), "&lt;&amp;&gt;");
        assert_eq!(escape("bla&").to_string(), "bla&amp;");
        // NOTE(review): the reviewed source is cut off at this point, in the
        // middle of a further `assert_eq!(escape("...` statement. Whatever
        // followed was not visible and has not been reconstructed.
    }
}