From 467f4ade19fa34983de7e6f6d81c6b4d5ff140fe Mon Sep 17 00:00:00 2001
From: Dirkjan Ochtman <dirkjan@ochtman.nl>
Date: Thu, 10 Jan 2019 17:19:28 +0100
Subject: Specify a trait that handles the output format's escaping

---
 askama_escape/benches/all.rs |  12 ++--
 askama_escape/src/lib.rs     | 163 ++++++++++++++++++++++++++++---------------
 2 files changed, 114 insertions(+), 61 deletions(-)

(limited to 'askama_escape')

diff --git a/askama_escape/benches/all.rs b/askama_escape/benches/all.rs
index af28c43..a98f2d7 100644
--- a/askama_escape/benches/all.rs
+++ b/askama_escape/benches/all.rs
@@ -1,7 +1,7 @@
 #[macro_use]
 extern crate criterion;
 
-use askama_escape::MarkupDisplay;
+use askama_escape::{Html, MarkupDisplay};
 use criterion::Criterion;
 
 criterion_main!(benches);
@@ -68,10 +68,10 @@ quis lacus at, gravida maximus elit. Duis tristique, nisl nullam.
     "#;
 
     b.iter(|| {
-        format!("{}", MarkupDisplay::from(string_long));
-        format!("{}", MarkupDisplay::from(string_short));
-        format!("{}", MarkupDisplay::from(empty));
-        format!("{}", MarkupDisplay::from(no_escape));
-        format!("{}", MarkupDisplay::from(no_escape_long));
+        format!("{}", MarkupDisplay::new_unsafe(string_long, Html));
+        format!("{}", MarkupDisplay::new_unsafe(string_short, Html));
+        format!("{}", MarkupDisplay::new_unsafe(empty, Html));
+        format!("{}", MarkupDisplay::new_unsafe(no_escape, Html));
+        format!("{}", MarkupDisplay::new_unsafe(no_escape_long, Html));
     });
 }
diff --git a/askama_escape/src/lib.rs b/askama_escape/src/lib.rs
index 48d43ca..01da4ed 100644
--- a/askama_escape/src/lib.rs
+++ b/askama_escape/src/lib.rs
@@ -2,58 +2,76 @@ use std::fmt::{self, Display, Formatter};
 use std::io::{self, prelude::*};
 use std::str;
 
-#[derive(Debug, PartialEq)]
-pub enum MarkupDisplay<T>
+pub struct MarkupDisplay<E, T>
 where
+    E: Escaper,
     T: Display,
 {
-    Safe(T),
-    Unsafe(T),
+    value: DisplayValue<T>,
+    escaper: E,
 }
 
-impl<T> MarkupDisplay<T>
+impl<E, T> MarkupDisplay<E, T>
 where
+    E: Escaper,
     T: Display,
 {
-    pub fn mark_safe(self) -> MarkupDisplay<T> {
-        match self {
-            MarkupDisplay::Unsafe(t) => MarkupDisplay::Safe(t),
-            _ => self,
+    pub fn new_unsafe(value: T, escaper: E) -> Self {
+        Self {
+            value: DisplayValue::Unsafe(value),
+            escaper,
         }
     }
-}
 
-impl<T> From<T> for MarkupDisplay<T>
-where
-    T: Display,
-{
-    fn from(t: T) -> MarkupDisplay<T> {
-        MarkupDisplay::Unsafe(t)
+    pub fn new_safe(value: T, escaper: E) -> Self {
+        Self {
+            value: DisplayValue::Safe(value),
+            escaper,
+        }
+    }
+
+    pub fn mark_safe(mut self) -> MarkupDisplay<E, T> {
+        self.value = match self.value {
+            DisplayValue::Unsafe(t) => DisplayValue::Safe(t),
+            _ => self.value,
+        };
+        self
     }
 }
 
-impl<T> Display for MarkupDisplay<T>
+impl<E, T> Display for MarkupDisplay<E, T>
 where
+    E: Escaper,
     T: Display,
 {
-    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
-        match *self {
-            MarkupDisplay::Unsafe(ref t) => {
-                let mut w = EscapeWriter { fmt: f };
-                write!(w, "{}", t).map_err(|_e| fmt::Error)
-            }
-            MarkupDisplay::Safe(ref t) => t.fmt(f),
+    fn fmt(&self, fmt: &mut Formatter<'_>) -> fmt::Result {
+        match self.value {
+            DisplayValue::Unsafe(ref t) => write!(
+                EscapeWriter {
+                    fmt,
+                    escaper: &self.escaper
+                },
+                "{}",
+                t
+            )
+            .map_err(|_| fmt::Error),
+            DisplayValue::Safe(ref t) => t.fmt(fmt),
         }
     }
 }
 
-pub struct EscapeWriter<'a, 'b: 'a> {
+pub struct EscapeWriter<'a, 'b: 'a, E> {
     fmt: &'a mut fmt::Formatter<'b>,
+    escaper: &'a E,
 }
 
-impl io::Write for EscapeWriter<'_, '_> {
+impl<E> io::Write for EscapeWriter<'_, '_, E>
+where
+    E: Escaper,
+{
     fn write(&mut self, bytes: &[u8]) -> io::Result<usize> {
-        write_escaped_str(self.fmt, bytes)
+        self.escaper
+            .write_escaped_bytes(self.fmt, bytes)
             .map_err(|e| io::Error::new(io::ErrorKind::Other, e.to_string()))?;
         Ok(bytes.len())
     }
@@ -63,12 +81,35 @@ impl io::Write for EscapeWriter<'_, '_> {
     }
 }
 
-pub fn escape(s: &str) -> Escaped<'_> {
+pub fn escape<E>(s: &str, escaper: E) -> Escaped<'_, E>
+where
+    E: Escaper,
+{
     Escaped {
         bytes: s.as_bytes(),
+        escaper,
+    }
+}
+
+pub struct Escaped<'a, E>
+where
+    E: Escaper,
+{
+    bytes: &'a [u8],
+    escaper: E,
+}
+
+impl<'a, E> ::std::fmt::Display for Escaped<'a, E>
+where
+    E: Escaper,
+{
+    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
+        self.escaper.write_escaped_bytes(fmt, self.bytes)
     }
 }
 
+pub struct Html;
+
 macro_rules! escaping_body {
     ($start:ident, $i:ident, $fmt:ident, $bytes:ident, $quote:expr) => {{
         if $start < $i {
@@ -79,33 +120,45 @@ macro_rules! escaping_body {
     }};
 }
 
-pub struct Escaped<'a> {
-    bytes: &'a [u8],
+impl Escaper for Html {
+    fn write_escaped_bytes(&self, fmt: &mut fmt::Formatter<'_>, bytes: &[u8]) -> fmt::Result {
+        let mut start = 0;
+        for (i, b) in bytes.iter().enumerate() {
+            if b.wrapping_sub(b'"') <= FLAG {
+                match *b {
+                    b'<' => escaping_body!(start, i, fmt, bytes, "&lt;"),
+                    b'>' => escaping_body!(start, i, fmt, bytes, "&gt;"),
+                    b'&' => escaping_body!(start, i, fmt, bytes, "&amp;"),
+                    b'"' => escaping_body!(start, i, fmt, bytes, "&quot;"),
+                    b'\'' => escaping_body!(start, i, fmt, bytes, "&#x27;"),
+                    b'/' => escaping_body!(start, i, fmt, bytes, "&#x2f;"),
+                    _ => (),
+                }
+            }
+        }
+        fmt.write_str(unsafe { str::from_utf8_unchecked(&bytes[start..]) })
+    }
 }
 
-impl<'a> ::std::fmt::Display for Escaped<'a> {
-    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
-        write_escaped_str(fmt, self.bytes)
+pub struct Text;
+
+impl Escaper for Text {
+    fn write_escaped_bytes(&self, fmt: &mut fmt::Formatter<'_>, bytes: &[u8]) -> fmt::Result {
+        fmt.write_str(unsafe { str::from_utf8_unchecked(bytes) })
     }
 }
 
-fn write_escaped_str(fmt: &mut fmt::Formatter<'_>, bytes: &[u8]) -> fmt::Result {
-    let mut start = 0;
-    for (i, b) in bytes.iter().enumerate() {
-        if b.wrapping_sub(b'"') <= FLAG {
-            match *b {
-                b'<' => escaping_body!(start, i, fmt, bytes, "&lt;"),
-                b'>' => escaping_body!(start, i, fmt, bytes, "&gt;"),
-                b'&' => escaping_body!(start, i, fmt, bytes, "&amp;"),
-                b'"' => escaping_body!(start, i, fmt, bytes, "&quot;"),
-                b'\'' => escaping_body!(start, i, fmt, bytes, "&#x27;"),
-                b'/' => escaping_body!(start, i, fmt, bytes, "&#x2f;"),
-                _ => (),
-            }
-        }
-    }
-    fmt.write_str(unsafe { str::from_utf8_unchecked(&bytes[start..]) })?;
-    Ok(())
+#[derive(Debug, PartialEq)]
+enum DisplayValue<T>
+where
+    T: Display,
+{
+    Safe(T),
+    Unsafe(T),
+}
+
+pub trait Escaper {
+    fn write_escaped_bytes(&self, fmt: &mut fmt::Formatter<'_>, bytes: &[u8]) -> fmt::Result;
 }
 
 const FLAG: u8 = b'>' - b'"';
@@ -115,10 +168,10 @@ mod tests {
     use super::*;
     #[test]
     fn test_escape() {
-        assert_eq!(escape("").to_string(), "");
-        assert_eq!(escape("<&>").to_string(), "&lt;&amp;&gt;");
-        assert_eq!(escape("bla&").to_string(), "bla&amp;");
-        assert_eq!(escape("<foo").to_string(), "&lt;foo");
-        assert_eq!(escape("bla&h").to_string(), "bla&amp;h");
+        assert_eq!(escape("", Html).to_string(), "");
+        assert_eq!(escape("<&>", Html).to_string(), "&lt;&amp;&gt;");
+        assert_eq!(escape("bla&", Html).to_string(), "bla&amp;");
+        assert_eq!(escape("<foo", Html).to_string(), "&lt;foo");
+        assert_eq!(escape("bla&h", Html).to_string(), "bla&amp;h");
     }
 }
-- 
cgit