From 57dc7ef9ca6f093908ad0eabbf81476293e218ed Mon Sep 17 00:00:00 2001 From: Dirkjan Ochtman Date: Mon, 5 Nov 2018 13:24:15 +0100 Subject: Escape into Formatter --- askama_shared/src/escaping.rs | 103 ++++++++++++++++++++++-------------------- 1 file changed, 53 insertions(+), 50 deletions(-) (limited to 'askama_shared/src/escaping.rs') diff --git a/askama_shared/src/escaping.rs b/askama_shared/src/escaping.rs index 930dbbe..8bb8f0b 100644 --- a/askama_shared/src/escaping.rs +++ b/askama_shared/src/escaping.rs @@ -1,4 +1,5 @@ use std::fmt::{self, Display, Formatter}; +use std::str; #[derive(Debug, PartialEq)] pub enum MarkupDisplay @@ -19,11 +20,6 @@ where _ => self, } } - pub fn unsafe_string(&self) -> String { - match *self { - MarkupDisplay::Safe(ref t) | MarkupDisplay::Unsafe(ref t) => format!("{}", t), - } - } } impl From for MarkupDisplay @@ -41,58 +37,65 @@ where { fn fmt(&self, f: &mut Formatter) -> fmt::Result { match *self { - MarkupDisplay::Unsafe(_) => write!(f, "{}", escape(self.unsafe_string())), + MarkupDisplay::Unsafe(ref t) => escape(&t.to_string()).fmt(f), MarkupDisplay::Safe(ref t) => t.fmt(f), } } } const FLAG: u8 = b'>' - b'"'; -pub fn escape(s: String) -> String { - let mut found = None; - for (i, b) in s.as_bytes().iter().enumerate() { - if b.wrapping_sub(b'"') <= FLAG { - match *b { - b'<' | b'>' | b'&' | b'"' | b'\'' | b'/' => { - found = Some(i); - break; - } - _ => (), - }; - } + +pub fn escape(s: &str) -> Escaped { + Escaped { + bytes: s.as_bytes(), } +} - if let Some(found) = found { - let bytes = s.as_bytes(); - let mut res = Vec::with_capacity(s.len() + 6); - res.extend(&bytes[0..found]); - for c in bytes[found..].iter() { - match *c { - b'<' => { - res.extend(b"<"); - } - b'>' => { - res.extend(b">"); - } - b'&' => { - res.extend(b"&"); - } - b'"' => { - res.extend(b"""); +pub struct Escaped<'a> { + bytes: &'a [u8], +} + +enum State { + Empty, + Unescaped(usize), +} + +impl<'a> ::std::fmt::Display for Escaped<'a> { + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + use self::State::*; + let mut state = Empty; + for (i, b) in self.bytes.iter().enumerate() { + let next = if b.wrapping_sub(b'"') <= FLAG { + match *b { + b'<' => Some("<"), + b'>' => Some(">"), + b'&' => Some("&"), + b'"' => Some("""), + b'\'' => Some("'"), + b'/' => Some("/"), + _ => None, } - b'\'' => { - res.extend(b"'"); + } else { + None + }; + state = match (state, next) { + (Empty, None) => Unescaped(i), + (s @ Unescaped(_), None) => s, + (Empty, Some(escaped)) => { + fmt.write_str(escaped)?; + Empty } - b'/' => { - res.extend(b"/"); + (Unescaped(start), Some(escaped)) => { + fmt.write_str(unsafe { str::from_utf8_unchecked(&self.bytes[start..i]) })?; + fmt.write_str(escaped)?; + Empty } - _ => res.push(*c), - } + }; } - - String::from_utf8(res).unwrap() - } else { - s + if let Unescaped(start) = state { + fmt.write_str(unsafe { str::from_utf8_unchecked(&self.bytes[start..]) })?; + } + Ok(()) } } @@ -101,10 +104,10 @@ mod tests { use super::*; #[test] fn test_escape() { - assert_eq!(escape("".to_string()), ""); - assert_eq!(escape("<&>".to_string()), "<&>"); - assert_eq!(escape("bla&".to_string()), "bla&"); - assert_eq!(escape("").to_string(), "<&>"); + assert_eq!(escape("bla&").to_string(), "bla&"); + assert_eq!(escape("