diff options
Diffstat (limited to 'askama_escape/src')
-rw-r--r-- | askama_escape/src/lib.rs | 163 |
1 files changed, 108 insertions, 55 deletions
diff --git a/askama_escape/src/lib.rs b/askama_escape/src/lib.rs index 48d43ca..01da4ed 100644 --- a/askama_escape/src/lib.rs +++ b/askama_escape/src/lib.rs @@ -2,58 +2,76 @@ use std::fmt::{self, Display, Formatter}; use std::io::{self, prelude::*}; use std::str; -#[derive(Debug, PartialEq)] -pub enum MarkupDisplay<T> +pub struct MarkupDisplay<E, T> where + E: Escaper, T: Display, { - Safe(T), - Unsafe(T), + value: DisplayValue<T>, + escaper: E, } -impl<T> MarkupDisplay<T> +impl<E, T> MarkupDisplay<E, T> where + E: Escaper, T: Display, { - pub fn mark_safe(self) -> MarkupDisplay<T> { - match self { - MarkupDisplay::Unsafe(t) => MarkupDisplay::Safe(t), - _ => self, + pub fn new_unsafe(value: T, escaper: E) -> Self { + Self { + value: DisplayValue::Unsafe(value), + escaper, } } -} -impl<T> From<T> for MarkupDisplay<T> -where - T: Display, -{ - fn from(t: T) -> MarkupDisplay<T> { - MarkupDisplay::Unsafe(t) + pub fn new_safe(value: T, escaper: E) -> Self { + Self { + value: DisplayValue::Safe(value), + escaper, + } + } + + pub fn mark_safe(mut self) -> MarkupDisplay<E, T> { + self.value = match self.value { + DisplayValue::Unsafe(t) => DisplayValue::Safe(t), + _ => self.value, + }; + self } } -impl<T> Display for MarkupDisplay<T> +impl<E, T> Display for MarkupDisplay<E, T> where + E: Escaper, T: Display, { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - match *self { - MarkupDisplay::Unsafe(ref t) => { - let mut w = EscapeWriter { fmt: f }; - write!(w, "{}", t).map_err(|_e| fmt::Error) - } - MarkupDisplay::Safe(ref t) => t.fmt(f), + fn fmt(&self, fmt: &mut Formatter<'_>) -> fmt::Result { + match self.value { + DisplayValue::Unsafe(ref t) => write!( + EscapeWriter { + fmt, + escaper: &self.escaper + }, + "{}", + t + ) + .map_err(|_| fmt::Error), + DisplayValue::Safe(ref t) => t.fmt(fmt), } } } -pub struct EscapeWriter<'a, 'b: 'a> { +pub struct EscapeWriter<'a, 'b: 'a, E> { fmt: &'a mut fmt::Formatter<'b>, + escaper: &'a E, } -impl io::Write for EscapeWriter<'_, '_> { +impl<E> io::Write for EscapeWriter<'_, '_, E> +where + E: Escaper, +{ fn write(&mut self, bytes: &[u8]) -> io::Result<usize> { - write_escaped_str(self.fmt, bytes) + self.escaper + .write_escaped_bytes(self.fmt, bytes) .map_err(|e| io::Error::new(io::ErrorKind::Other, e.to_string()))?; Ok(bytes.len()) } @@ -63,12 +81,35 @@ impl io::Write for EscapeWriter<'_, '_> { } } -pub fn escape(s: &str) -> Escaped<'_> { +pub fn escape<E>(s: &str, escaper: E) -> Escaped<'_, E> +where + E: Escaper, +{ Escaped { bytes: s.as_bytes(), + escaper, + } +} + +pub struct Escaped<'a, E> +where + E: Escaper, +{ + bytes: &'a [u8], + escaper: E, +} + +impl<'a, E> ::std::fmt::Display for Escaped<'a, E> +where + E: Escaper, +{ + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + self.escaper.write_escaped_bytes(fmt, self.bytes) } } +pub struct Html; + macro_rules! escaping_body { ($start:ident, $i:ident, $fmt:ident, $bytes:ident, $quote:expr) => {{ if $start < $i { @@ -79,33 +120,45 @@ macro_rules! escaping_body { }}; } -pub struct Escaped<'a> { - bytes: &'a [u8], +impl Escaper for Html { + fn write_escaped_bytes(&self, fmt: &mut fmt::Formatter<'_>, bytes: &[u8]) -> fmt::Result { + let mut start = 0; + for (i, b) in bytes.iter().enumerate() { + if b.wrapping_sub(b'"') <= FLAG { + match *b { + b'<' => escaping_body!(start, i, fmt, bytes, "<"), + b'>' => escaping_body!(start, i, fmt, bytes, ">"), + b'&' => escaping_body!(start, i, fmt, bytes, "&"), + b'"' => escaping_body!(start, i, fmt, bytes, """), + b'\'' => escaping_body!(start, i, fmt, bytes, "'"), + b'/' => escaping_body!(start, i, fmt, bytes, "/"), + _ => (), + } + } + } + fmt.write_str(unsafe { str::from_utf8_unchecked(&bytes[start..]) }) + } } -impl<'a> ::std::fmt::Display for Escaped<'a> { - fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { - write_escaped_str(fmt, self.bytes) +pub struct Text; + +impl Escaper for Text { + fn write_escaped_bytes(&self, fmt: &mut fmt::Formatter<'_>, bytes: &[u8]) -> fmt::Result { + fmt.write_str(unsafe { str::from_utf8_unchecked(bytes) }) } } -fn write_escaped_str(fmt: &mut fmt::Formatter<'_>, bytes: &[u8]) -> fmt::Result { - let mut start = 0; - for (i, b) in bytes.iter().enumerate() { - if b.wrapping_sub(b'"') <= FLAG { - match *b { - b'<' => escaping_body!(start, i, fmt, bytes, "<"), - b'>' => escaping_body!(start, i, fmt, bytes, ">"), - b'&' => escaping_body!(start, i, fmt, bytes, "&"), - b'"' => escaping_body!(start, i, fmt, bytes, """), - b'\'' => escaping_body!(start, i, fmt, bytes, "'"), - b'/' => escaping_body!(start, i, fmt, bytes, "/"), - _ => (), - } - } - } - fmt.write_str(unsafe { str::from_utf8_unchecked(&bytes[start..]) })?; - Ok(()) +#[derive(Debug, PartialEq)] +enum DisplayValue<T> +where + T: Display, +{ + Safe(T), + Unsafe(T), +} + +pub trait Escaper { + fn write_escaped_bytes(&self, fmt: &mut fmt::Formatter<'_>, bytes: &[u8]) -> fmt::Result; } const FLAG: u8 = b'>' - b'"'; @@ -115,10 +168,10 @@ mod tests { use super::*; #[test] fn test_escape() { - assert_eq!(escape("").to_string(), ""); - assert_eq!(escape("<&>").to_string(), "<&>"); - assert_eq!(escape("bla&").to_string(), "bla&"); - assert_eq!(escape("<foo").to_string(), "<foo"); - assert_eq!(escape("bla&h").to_string(), "bla&h"); + assert_eq!(escape("", Html).to_string(), ""); + assert_eq!(escape("<&>", Html).to_string(), "<&>"); + assert_eq!(escape("bla&", Html).to_string(), "bla&"); + assert_eq!(escape("<foo", Html).to_string(), "<foo"); + assert_eq!(escape("bla&h", Html).to_string(), "bla&h"); } } |