From 467f4ade19fa34983de7e6f6d81c6b4d5ff140fe Mon Sep 17 00:00:00 2001
From: Dirkjan Ochtman <dirkjan@ochtman.nl>
Date: Thu, 10 Jan 2019 17:19:28 +0100
Subject: Specify a trait that handles the output format's escaping

---
 askama/Cargo.toml                 |   1 +
 askama/src/lib.rs                 |   2 +
 askama_derive/src/generator.rs    |  19 +++--
 askama_derive/src/input.rs        |  48 ++++-------
 askama_escape/benches/all.rs      |  12 +--
 askama_escape/src/lib.rs          | 163 +++++++++++++++++++++++++-------------
 askama_shared/src/filters/json.rs |  13 +--
 askama_shared/src/filters/mod.rs  |  35 ++++----
 8 files changed, 172 insertions(+), 121 deletions(-)

diff --git a/askama/Cargo.toml b/askama/Cargo.toml
index 22d129d..2a2e1f6 100644
--- a/askama/Cargo.toml
+++ b/askama/Cargo.toml
@@ -28,6 +28,7 @@ with-gotham = ["gotham", "askama_derive/gotham", "hyper", "mime_guess"]
 
 [dependencies]
 askama_derive = { version = "0.8.0", path = "../askama_derive" }
+askama_escape = { version = "0.1.0", path = "../askama_escape" }
 askama_shared = { version = "0.7.2", path = "../askama_shared" }
 iron = { version = ">= 0.5, < 0.7", optional = true }
 rocket = { version = "0.4", optional = true }
diff --git a/askama/src/lib.rs b/askama/src/lib.rs
index 19d37a3..d15d1b6 100644
--- a/askama/src/lib.rs
+++ b/askama/src/lib.rs
@@ -417,6 +417,8 @@ use std::fs::{self, DirEntry};
 use std::io;
 use std::path::Path;
 
+pub use askama_escape::{Html, Text};
+
 /// Main `Template` trait; implementations are generally derived
 pub trait Template {
     /// Helper method which allocates a new `String` and renders into it
diff --git a/askama_derive/src/generator.rs b/askama_derive/src/generator.rs
index 16a08bb..f6c83ec 100644
--- a/askama_derive/src/generator.rs
+++ b/askama_derive/src/generator.rs
@@ -732,14 +732,14 @@ impl<'a> Generator<'a> {
                 }
                 Writable::Expr(s) => {
                     use self::DisplayWrap::*;
-                    use super::input::EscapeMode::*;
                     let mut expr_buf = Buffer::new(0);
                     let wrapped = self.visit_expr(&mut expr_buf, s);
-                    let expression = match (wrapped, &self.input.escaping) {
-                        (Wrapped, &Html) | (Wrapped, &None) | (Unwrapped, &None) => expr_buf.buf,
-                        (Unwrapped, &Html) => {
-                            format!("::askama::MarkupDisplay::from(&{})", expr_buf.buf)
-                        }
+                    let expression = match wrapped {
+                        Wrapped => expr_buf.buf,
+                        Unwrapped => format!(
+                            "::askama::MarkupDisplay::new_unsafe(&{}, {})",
+                            expr_buf.buf, self.input.escaping
+                        ),
                     };
 
                     let id = expr_cache.entry(expression.clone()).or_insert_with(|| {
@@ -876,7 +876,12 @@ impl<'a> Generator<'a> {
             return DisplayWrap::Unwrapped;
         }
 
-        if filters::BUILT_IN_FILTERS.contains(&name) {
+        if name == "escape" || name == "safe" || name == "e" || name == "json" {
+            buf.write(&format!(
+                "::askama::filters::{}({}, &",
+                name, self.input.escaping
+            ));
+        } else if filters::BUILT_IN_FILTERS.contains(&name) {
             buf.write(&format!("::askama::filters::{}(&", name));
         } else {
             buf.write(&format!("filters::{}(&", name));
diff --git a/askama_derive/src/input.rs b/askama_derive/src/input.rs
index c23d30f..b584298 100644
--- a/askama_derive/src/input.rs
+++ b/askama_derive/src/input.rs
@@ -15,7 +15,7 @@ pub struct TemplateInput<'a> {
     pub syntax: &'a Syntax<'a>,
     pub source: Source,
     pub print: Print,
-    pub escaping: EscapeMode,
+    pub escaping: &'a str,
     pub ext: Option<String>,
     pub parent: Option<&'a syn::Type>,
     pub path: PathBuf,
@@ -91,7 +91,7 @@ impl<'a> TemplateInput<'a> {
                         }
                         "escape" => {
                             if let syn::Lit::Str(ref s) = pair.lit {
-                                escaping = Some(s.value().into());
+                                escaping = Some(s.value());
                             } else {
                                 panic!("escape value must be string literal");
                             }
@@ -165,12 +165,24 @@ impl<'a> TemplateInput<'a> {
             },
         );
 
+        let escaping = escaping.unwrap_or_else(|| {
+            path.extension()
+                .map(|s| s.to_str().unwrap())
+                .unwrap_or("none")
+                .to_string()
+        });
+        let escaping = match escaping.as_str() {
+            "html" | "htm" | "xml" => "::askama::Html",
+            "txt" | "none" => "::askama::Text",
+            val => panic!("unknown value '{}' for escape mode", val),
+        };
+
         TemplateInput {
             ast,
             config,
             source,
             print,
-            escaping: escaping.unwrap_or_else(|| EscapeMode::from_path(&path)),
+            escaping,
             ext,
             parent,
             path,
@@ -184,34 +196,6 @@ pub enum Source {
     Source(String),
 }
 
-#[derive(PartialEq)]
-pub enum EscapeMode {
-    Html,
-    None,
-}
-
-impl From<String> for EscapeMode {
-    fn from(s: String) -> EscapeMode {
-        use self::EscapeMode::*;
-        match s.as_ref() {
-            "html" => Html,
-            "none" => None,
-            v => panic!("invalid value for escape option: {}", v),
-        }
-    }
-}
-
-impl EscapeMode {
-    fn from_path(path: &PathBuf) -> EscapeMode {
-        let extension = path.extension().map(|s| s.to_str().unwrap()).unwrap_or("");
-        if HTML_EXTENSIONS.contains(&extension) {
-            EscapeMode::Html
-        } else {
-            EscapeMode::None
-        }
-    }
-}
-
 #[derive(PartialEq)]
 pub enum Print {
     All,
@@ -232,5 +216,3 @@ impl From<String> for Print {
         }
     }
 }
-
-const HTML_EXTENSIONS: [&str; 3] = ["html", "htm", "xml"];
diff --git a/askama_escape/benches/all.rs b/askama_escape/benches/all.rs
index af28c43..a98f2d7 100644
--- a/askama_escape/benches/all.rs
+++ b/askama_escape/benches/all.rs
@@ -1,7 +1,7 @@
 #[macro_use]
 extern crate criterion;
 
-use askama_escape::MarkupDisplay;
+use askama_escape::{Html, MarkupDisplay};
 use criterion::Criterion;
 
 criterion_main!(benches);
@@ -68,10 +68,10 @@ quis lacus at, gravida maximus elit. Duis tristique, nisl nullam.
     "#;
 
     b.iter(|| {
-        format!("{}", MarkupDisplay::from(string_long));
-        format!("{}", MarkupDisplay::from(string_short));
-        format!("{}", MarkupDisplay::from(empty));
-        format!("{}", MarkupDisplay::from(no_escape));
-        format!("{}", MarkupDisplay::from(no_escape_long));
+        format!("{}", MarkupDisplay::new_unsafe(string_long, Html));
+        format!("{}", MarkupDisplay::new_unsafe(string_short, Html));
+        format!("{}", MarkupDisplay::new_unsafe(empty, Html));
+        format!("{}", MarkupDisplay::new_unsafe(no_escape, Html));
+        format!("{}", MarkupDisplay::new_unsafe(no_escape_long, Html));
     });
 }
diff --git a/askama_escape/src/lib.rs b/askama_escape/src/lib.rs
index 48d43ca..01da4ed 100644
--- a/askama_escape/src/lib.rs
+++ b/askama_escape/src/lib.rs
@@ -2,58 +2,76 @@ use std::fmt::{self, Display, Formatter};
 use std::io::{self, prelude::*};
 use std::str;
 
-#[derive(Debug, PartialEq)]
-pub enum MarkupDisplay<T>
+pub struct MarkupDisplay<E, T>
 where
+    E: Escaper,
     T: Display,
 {
-    Safe(T),
-    Unsafe(T),
+    value: DisplayValue<T>,
+    escaper: E,
 }
 
-impl<T> MarkupDisplay<T>
+impl<E, T> MarkupDisplay<E, T>
 where
+    E: Escaper,
     T: Display,
 {
-    pub fn mark_safe(self) -> MarkupDisplay<T> {
-        match self {
-            MarkupDisplay::Unsafe(t) => MarkupDisplay::Safe(t),
-            _ => self,
+    pub fn new_unsafe(value: T, escaper: E) -> Self {
+        Self {
+            value: DisplayValue::Unsafe(value),
+            escaper,
         }
     }
-}
 
-impl<T> From<T> for MarkupDisplay<T>
-where
-    T: Display,
-{
-    fn from(t: T) -> MarkupDisplay<T> {
-        MarkupDisplay::Unsafe(t)
+    pub fn new_safe(value: T, escaper: E) -> Self {
+        Self {
+            value: DisplayValue::Safe(value),
+            escaper,
+        }
+    }
+
+    pub fn mark_safe(mut self) -> MarkupDisplay<E, T> {
+        self.value = match self.value {
+            DisplayValue::Unsafe(t) => DisplayValue::Safe(t),
+            _ => self.value,
+        };
+        self
     }
 }
 
-impl<T> Display for MarkupDisplay<T>
+impl<E, T> Display for MarkupDisplay<E, T>
 where
+    E: Escaper,
     T: Display,
 {
-    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
-        match *self {
-            MarkupDisplay::Unsafe(ref t) => {
-                let mut w = EscapeWriter { fmt: f };
-                write!(w, "{}", t).map_err(|_e| fmt::Error)
-            }
-            MarkupDisplay::Safe(ref t) => t.fmt(f),
+    fn fmt(&self, fmt: &mut Formatter<'_>) -> fmt::Result {
+        match self.value {
+            DisplayValue::Unsafe(ref t) => write!(
+                EscapeWriter {
+                    fmt,
+                    escaper: &self.escaper
+                },
+                "{}",
+                t
+            )
+            .map_err(|_| fmt::Error),
+            DisplayValue::Safe(ref t) => t.fmt(fmt),
         }
     }
 }
 
-pub struct EscapeWriter<'a, 'b: 'a> {
+pub struct EscapeWriter<'a, 'b: 'a, E> {
     fmt: &'a mut fmt::Formatter<'b>,
+    escaper: &'a E,
 }
 
-impl io::Write for EscapeWriter<'_, '_> {
+impl<E> io::Write for EscapeWriter<'_, '_, E>
+where
+    E: Escaper,
+{
     fn write(&mut self, bytes: &[u8]) -> io::Result<usize> {
-        write_escaped_str(self.fmt, bytes)
+        self.escaper
+            .write_escaped_bytes(self.fmt, bytes)
             .map_err(|e| io::Error::new(io::ErrorKind::Other, e.to_string()))?;
         Ok(bytes.len())
     }
@@ -63,12 +81,35 @@ impl io::Write for EscapeWriter<'_, '_> {
     }
 }
 
-pub fn escape(s: &str) -> Escaped<'_> {
+pub fn escape<E>(s: &str, escaper: E) -> Escaped<'_, E>
+where
+    E: Escaper,
+{
     Escaped {
         bytes: s.as_bytes(),
+        escaper,
+    }
+}
+
+pub struct Escaped<'a, E>
+where
+    E: Escaper,
+{
+    bytes: &'a [u8],
+    escaper: E,
+}
+
+impl<'a, E> ::std::fmt::Display for Escaped<'a, E>
+where
+    E: Escaper,
+{
+    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
+        self.escaper.write_escaped_bytes(fmt, self.bytes)
     }
 }
 
+pub struct Html;
+
 macro_rules! escaping_body {
     ($start:ident, $i:ident, $fmt:ident, $bytes:ident, $quote:expr) => {{
         if $start < $i {
@@ -79,33 +120,45 @@ macro_rules! escaping_body {
     }};
 }
 
-pub struct Escaped<'a> {
-    bytes: &'a [u8],
+impl Escaper for Html {
+    fn write_escaped_bytes(&self, fmt: &mut fmt::Formatter<'_>, bytes: &[u8]) -> fmt::Result {
+        let mut start = 0;
+        for (i, b) in bytes.iter().enumerate() {
+            if b.wrapping_sub(b'"') <= FLAG {
+                match *b {
+                    b'<' => escaping_body!(start, i, fmt, bytes, "&lt;"),
+                    b'>' => escaping_body!(start, i, fmt, bytes, "&gt;"),
+                    b'&' => escaping_body!(start, i, fmt, bytes, "&amp;"),
+                    b'"' => escaping_body!(start, i, fmt, bytes, "&quot;"),
+                    b'\'' => escaping_body!(start, i, fmt, bytes, "&#x27;"),
+                    b'/' => escaping_body!(start, i, fmt, bytes, "&#x2f;"),
+                    _ => (),
+                }
+            }
+        }
+        fmt.write_str(unsafe { str::from_utf8_unchecked(&bytes[start..]) })
+    }
 }
 
-impl<'a> ::std::fmt::Display for Escaped<'a> {
-    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
-        write_escaped_str(fmt, self.bytes)
+pub struct Text;
+
+impl Escaper for Text {
+    fn write_escaped_bytes(&self, fmt: &mut fmt::Formatter<'_>, bytes: &[u8]) -> fmt::Result {
+        fmt.write_str(unsafe { str::from_utf8_unchecked(bytes) })
     }
 }
 
-fn write_escaped_str(fmt: &mut fmt::Formatter<'_>, bytes: &[u8]) -> fmt::Result {
-    let mut start = 0;
-    for (i, b) in bytes.iter().enumerate() {
-        if b.wrapping_sub(b'"') <= FLAG {
-            match *b {
-                b'<' => escaping_body!(start, i, fmt, bytes, "&lt;"),
-                b'>' => escaping_body!(start, i, fmt, bytes, "&gt;"),
-                b'&' => escaping_body!(start, i, fmt, bytes, "&amp;"),
-                b'"' => escaping_body!(start, i, fmt, bytes, "&quot;"),
-                b'\'' => escaping_body!(start, i, fmt, bytes, "&#x27;"),
-                b'/' => escaping_body!(start, i, fmt, bytes, "&#x2f;"),
-                _ => (),
-            }
-        }
-    }
-    fmt.write_str(unsafe { str::from_utf8_unchecked(&bytes[start..]) })?;
-    Ok(())
+#[derive(Debug, PartialEq)]
+enum DisplayValue<T>
+where
+    T: Display,
+{
+    Safe(T),
+    Unsafe(T),
+}
+
+pub trait Escaper {
+    fn write_escaped_bytes(&self, fmt: &mut fmt::Formatter<'_>, bytes: &[u8]) -> fmt::Result;
 }
 
 const FLAG: u8 = b'>' - b'"';
@@ -115,10 +168,10 @@ mod tests {
     use super::*;
     #[test]
     fn test_escape() {
-        assert_eq!(escape("").to_string(), "");
-        assert_eq!(escape("<&>").to_string(), "&lt;&amp;&gt;");
-        assert_eq!(escape("bla&").to_string(), "bla&amp;");
-        assert_eq!(escape("<foo").to_string(), "&lt;foo");
-        assert_eq!(escape("bla&h").to_string(), "bla&amp;h");
+        assert_eq!(escape("", Html).to_string(), "");
+        assert_eq!(escape("<&>", Html).to_string(), "&lt;&amp;&gt;");
+        assert_eq!(escape("bla&", Html).to_string(), "bla&amp;");
+        assert_eq!(escape("<foo", Html).to_string(), "&lt;foo");
+        assert_eq!(escape("bla&h", Html).to_string(), "bla&amp;h");
     }
 }
diff --git a/askama_shared/src/filters/json.rs b/askama_shared/src/filters/json.rs
index 8cdd4e6..ba7e61b 100644
--- a/askama_shared/src/filters/json.rs
+++ b/askama_shared/src/filters/json.rs
@@ -1,5 +1,5 @@
 use crate::error::{Error, Result};
-use askama_escape::MarkupDisplay;
+use askama_escape::{Escaper, MarkupDisplay};
 use serde::Serialize;
 
 /// Serialize to JSON (requires `serde_json` feature)
@@ -8,9 +8,9 @@ use serde::Serialize;
 ///
 /// This will panic if `S`'s implementation of `Serialize` decides to fail,
 /// or if `T` contains a map with non-string keys.
-pub fn json<S: Serialize>(s: &S) -> Result<MarkupDisplay<String>> {
+pub fn json<E: Escaper, S: Serialize>(e: E, s: &S) -> Result<MarkupDisplay<E, String>> {
     match serde_json::to_string_pretty(s) {
-        Ok(s) => Ok(MarkupDisplay::Safe(s)),
+        Ok(s) => Ok(MarkupDisplay::new_safe(s, e)),
         Err(e) => Err(Error::from(e)),
     }
 }
@@ -18,13 +18,14 @@ pub fn json<S: Serialize>(s: &S) -> Result<MarkupDisplay<String>> {
 #[cfg(test)]
 mod tests {
     use super::*;
+    use askama_escape::Html;
 
     #[test]
     fn test_json() {
-        assert_eq!(json(&true).unwrap().to_string(), "true");
-        assert_eq!(json(&"foo").unwrap().to_string(), r#""foo""#);
+        assert_eq!(json(Html, &true).unwrap().to_string(), "true");
+        assert_eq!(json(Html, &"foo").unwrap().to_string(), r#""foo""#);
         assert_eq!(
-            json(&vec!["foo", "bar"]).unwrap().to_string(),
+            json(Html, &vec!["foo", "bar"]).unwrap().to_string(),
             r#"[
   "foo",
   "bar"
diff --git a/askama_shared/src/filters/mod.rs b/askama_shared/src/filters/mod.rs
index 85ff8b2..ea702db 100644
--- a/askama_shared/src/filters/mod.rs
+++ b/askama_shared/src/filters/mod.rs
@@ -11,7 +11,7 @@ mod json;
 pub use self::json::json;
 
 use crate::error::Error::Fmt;
-use askama_escape::MarkupDisplay;
+use askama_escape::{Escaper, MarkupDisplay};
 use humansize::{file_size_opts, FileSize};
 use num_traits::cast::NumCast;
 use num_traits::Signed;
@@ -52,31 +52,38 @@ pub const BUILT_IN_FILTERS: [&str; 22] = [
 ///
 /// Use this is you want to allow markup in an expression, or if you know
 /// that the expression's contents don't need to be escaped.
-pub fn safe<D, I>(v: I) -> Result<MarkupDisplay<D>>
+///
+/// Askama will automatically insert the first (`Escaper`) argument,
+/// so this filter only takes a single argument of any type that implements
+/// `Display`.
+pub fn safe<E, T>(e: E, v: T) -> Result<MarkupDisplay<E, T>>
 where
-    D: fmt::Display,
-    MarkupDisplay<D>: From<I>,
+    E: Escaper,
+    T: fmt::Display,
 {
-    let res: MarkupDisplay<D> = v.into();
-    Ok(res.mark_safe())
+    Ok(MarkupDisplay::new_safe(v, e))
 }
 
 /// Escapes `&`, `<` and `>` in strings
-pub fn escape<D, I>(i: I) -> Result<MarkupDisplay<D>>
+///
+/// Askama will automatically insert the first (`Escaper`) argument,
+/// so this filter only takes a single argument of any type that implements
+/// `Display`.
+pub fn escape<E, T>(e: E, v: T) -> Result<MarkupDisplay<E, T>>
 where
-    D: fmt::Display,
-    MarkupDisplay<D>: From<I>,
+    E: Escaper,
+    T: fmt::Display,
 {
-    Ok(i.into())
+    Ok(MarkupDisplay::new_unsafe(v, e))
 }
 
 /// Alias for the `escape()` filter
-pub fn e<D, I>(i: I) -> Result<MarkupDisplay<D>>
+pub fn e<E, T>(e: E, v: T) -> Result<MarkupDisplay<E, T>>
 where
-    D: fmt::Display,
-    MarkupDisplay<D>: From<I>,
+    E: Escaper,
+    T: fmt::Display,
 {
-    escape(i)
+    escape(e, v)
 }
 
 /// Returns adequate string representation (in KB, ..) of number of bytes
-- 
cgit