aboutsummaryrefslogtreecommitdiffstats
path: root/askama_escape/src/lib.rs
diff options
context:
space:
mode:
Diffstat (limited to 'askama_escape/src/lib.rs')
-rw-r--r--askama_escape/src/lib.rs163
1 files changed, 108 insertions, 55 deletions
diff --git a/askama_escape/src/lib.rs b/askama_escape/src/lib.rs
index 48d43ca..01da4ed 100644
--- a/askama_escape/src/lib.rs
+++ b/askama_escape/src/lib.rs
@@ -2,58 +2,76 @@ use std::fmt::{self, Display, Formatter};
use std::io::{self, prelude::*};
use std::str;
-#[derive(Debug, PartialEq)]
-pub enum MarkupDisplay<T>
+pub struct MarkupDisplay<E, T>
where
+ E: Escaper,
T: Display,
{
- Safe(T),
- Unsafe(T),
+ value: DisplayValue<T>,
+ escaper: E,
}
-impl<T> MarkupDisplay<T>
+impl<E, T> MarkupDisplay<E, T>
where
+ E: Escaper,
T: Display,
{
- pub fn mark_safe(self) -> MarkupDisplay<T> {
- match self {
- MarkupDisplay::Unsafe(t) => MarkupDisplay::Safe(t),
- _ => self,
+ pub fn new_unsafe(value: T, escaper: E) -> Self {
+ Self {
+ value: DisplayValue::Unsafe(value),
+ escaper,
}
}
-}
-impl<T> From<T> for MarkupDisplay<T>
-where
- T: Display,
-{
- fn from(t: T) -> MarkupDisplay<T> {
- MarkupDisplay::Unsafe(t)
+ pub fn new_safe(value: T, escaper: E) -> Self {
+ Self {
+ value: DisplayValue::Safe(value),
+ escaper,
+ }
+ }
+
+ pub fn mark_safe(mut self) -> MarkupDisplay<E, T> {
+ self.value = match self.value {
+ DisplayValue::Unsafe(t) => DisplayValue::Safe(t),
+ _ => self.value,
+ };
+ self
}
}
-impl<T> Display for MarkupDisplay<T>
+impl<E, T> Display for MarkupDisplay<E, T>
where
+ E: Escaper,
T: Display,
{
- fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
- match *self {
- MarkupDisplay::Unsafe(ref t) => {
- let mut w = EscapeWriter { fmt: f };
- write!(w, "{}", t).map_err(|_e| fmt::Error)
- }
- MarkupDisplay::Safe(ref t) => t.fmt(f),
+ fn fmt(&self, fmt: &mut Formatter<'_>) -> fmt::Result {
+ match self.value {
+ DisplayValue::Unsafe(ref t) => write!(
+ EscapeWriter {
+ fmt,
+ escaper: &self.escaper
+ },
+ "{}",
+ t
+ )
+ .map_err(|_| fmt::Error),
+ DisplayValue::Safe(ref t) => t.fmt(fmt),
}
}
}
-pub struct EscapeWriter<'a, 'b: 'a> {
+pub struct EscapeWriter<'a, 'b: 'a, E> {
fmt: &'a mut fmt::Formatter<'b>,
+ escaper: &'a E,
}
-impl io::Write for EscapeWriter<'_, '_> {
+impl<E> io::Write for EscapeWriter<'_, '_, E>
+where
+ E: Escaper,
+{
fn write(&mut self, bytes: &[u8]) -> io::Result<usize> {
- write_escaped_str(self.fmt, bytes)
+ self.escaper
+ .write_escaped_bytes(self.fmt, bytes)
.map_err(|e| io::Error::new(io::ErrorKind::Other, e.to_string()))?;
Ok(bytes.len())
}
@@ -63,12 +81,35 @@ impl io::Write for EscapeWriter<'_, '_> {
}
}
-pub fn escape(s: &str) -> Escaped<'_> {
+pub fn escape<E>(s: &str, escaper: E) -> Escaped<'_, E>
+where
+ E: Escaper,
+{
Escaped {
bytes: s.as_bytes(),
+ escaper,
+ }
+}
+
+pub struct Escaped<'a, E>
+where
+ E: Escaper,
+{
+ bytes: &'a [u8],
+ escaper: E,
+}
+
+impl<'a, E> ::std::fmt::Display for Escaped<'a, E>
+where
+ E: Escaper,
+{
+ fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
+ self.escaper.write_escaped_bytes(fmt, self.bytes)
}
}
+pub struct Html;
+
macro_rules! escaping_body {
($start:ident, $i:ident, $fmt:ident, $bytes:ident, $quote:expr) => {{
if $start < $i {
@@ -79,33 +120,45 @@ macro_rules! escaping_body {
}};
}
-pub struct Escaped<'a> {
- bytes: &'a [u8],
+impl Escaper for Html {
+ fn write_escaped_bytes(&self, fmt: &mut fmt::Formatter<'_>, bytes: &[u8]) -> fmt::Result {
+ let mut start = 0;
+ for (i, b) in bytes.iter().enumerate() {
+ if b.wrapping_sub(b'"') <= FLAG {
+ match *b {
+ b'<' => escaping_body!(start, i, fmt, bytes, "&lt;"),
+ b'>' => escaping_body!(start, i, fmt, bytes, "&gt;"),
+ b'&' => escaping_body!(start, i, fmt, bytes, "&amp;"),
+ b'"' => escaping_body!(start, i, fmt, bytes, "&quot;"),
+ b'\'' => escaping_body!(start, i, fmt, bytes, "&#x27;"),
+ b'/' => escaping_body!(start, i, fmt, bytes, "&#x2f;"),
+ _ => (),
+ }
+ }
+ }
+ fmt.write_str(unsafe { str::from_utf8_unchecked(&bytes[start..]) })
+ }
}
-impl<'a> ::std::fmt::Display for Escaped<'a> {
- fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
- write_escaped_str(fmt, self.bytes)
+pub struct Text;
+
+impl Escaper for Text {
+ fn write_escaped_bytes(&self, fmt: &mut fmt::Formatter<'_>, bytes: &[u8]) -> fmt::Result {
+ fmt.write_str(unsafe { str::from_utf8_unchecked(bytes) })
}
}
-fn write_escaped_str(fmt: &mut fmt::Formatter<'_>, bytes: &[u8]) -> fmt::Result {
- let mut start = 0;
- for (i, b) in bytes.iter().enumerate() {
- if b.wrapping_sub(b'"') <= FLAG {
- match *b {
- b'<' => escaping_body!(start, i, fmt, bytes, "&lt;"),
- b'>' => escaping_body!(start, i, fmt, bytes, "&gt;"),
- b'&' => escaping_body!(start, i, fmt, bytes, "&amp;"),
- b'"' => escaping_body!(start, i, fmt, bytes, "&quot;"),
- b'\'' => escaping_body!(start, i, fmt, bytes, "&#x27;"),
- b'/' => escaping_body!(start, i, fmt, bytes, "&#x2f;"),
- _ => (),
- }
- }
- }
- fmt.write_str(unsafe { str::from_utf8_unchecked(&bytes[start..]) })?;
- Ok(())
+#[derive(Debug, PartialEq)]
+enum DisplayValue<T>
+where
+ T: Display,
+{
+ Safe(T),
+ Unsafe(T),
+}
+
+pub trait Escaper {
+ fn write_escaped_bytes(&self, fmt: &mut fmt::Formatter<'_>, bytes: &[u8]) -> fmt::Result;
}
const FLAG: u8 = b'>' - b'"';
@@ -115,10 +168,10 @@ mod tests {
use super::*;
#[test]
fn test_escape() {
- assert_eq!(escape("").to_string(), "");
- assert_eq!(escape("<&>").to_string(), "&lt;&amp;&gt;");
- assert_eq!(escape("bla&").to_string(), "bla&amp;");
- assert_eq!(escape("<foo").to_string(), "&lt;foo");
- assert_eq!(escape("bla&h").to_string(), "bla&amp;h");
+ assert_eq!(escape("", Html).to_string(), "");
+ assert_eq!(escape("<&>", Html).to_string(), "&lt;&amp;&gt;");
+ assert_eq!(escape("bla&", Html).to_string(), "bla&amp;");
+ assert_eq!(escape("<foo", Html).to_string(), "&lt;foo");
+ assert_eq!(escape("bla&h", Html).to_string(), "bla&amp;h");
}
}