diff options
Diffstat (limited to 'examples/gallery/src/civitai.rs')
-rw-r--r-- | examples/gallery/src/civitai.rs | 148 |
1 files changed, 148 insertions, 0 deletions
diff --git a/examples/gallery/src/civitai.rs b/examples/gallery/src/civitai.rs new file mode 100644 index 00000000..d1163013 --- /dev/null +++ b/examples/gallery/src/civitai.rs @@ -0,0 +1,148 @@ +use bytes::Bytes; +use serde::Deserialize; +use tokio::task; + +use std::fmt; +use std::io; +use std::sync::Arc; + +#[derive(Debug, Clone, Deserialize)] +pub struct Image { + pub id: Id, + url: String, +} + +impl Image { + pub async fn list() -> Result<Vec<Self>, Error> { + let client = reqwest::Client::new(); + + #[derive(Deserialize)] + struct Response { + items: Vec<Image>, + } + + let response: Response = client + .get(endpoint("/images")) + .query(&[ + ("sort", "Most Reactions"), + ("period", "Day"), + ("nsfw", "None"), + ("limit", "99"), + ]) + .send() + .await? + .error_for_status()? + .json() + .await?; + + Ok(response.items) + } + + pub async fn download(self, size: Size) -> Result<Rgba, Error> { + let client = reqwest::Client::new(); + + let bytes = client + .get(match size { + Size::Original => self.url, + Size::Thumbnail => self + .url + .split("/") + .map(|part| { + if part.starts_with("width=") { + "width=640" + } else { + part + } + }) + .collect::<Vec<_>>() + .join("/"), + }) + .send() + .await? + .error_for_status()? + .bytes() + .await?; + + let image = task::spawn_blocking(move || { + Ok::<_, Error>( + image::ImageReader::new(io::Cursor::new(bytes)) + .with_guessed_format()? + .decode()? + .to_rgba8(), + ) + }) + .await??; + + Ok(Rgba { + width: image.width(), + height: image.height(), + pixels: Bytes::from(image.into_raw()), + }) + } +} + +#[derive( + Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Deserialize, +)] +pub struct Id(u32); + +#[derive(Clone)] +pub struct Rgba { + pub width: u32, + pub height: u32, + pub pixels: Bytes, +} + +impl fmt::Debug for Rgba { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Rgba") + .field("width", &self.width) + .field("height", &self.height) + .finish() + } +} + +#[derive(Debug, Clone, Copy)] +pub enum Size { + Original, + Thumbnail, +} + +fn endpoint(path: &str) -> String { + const API_URL: &str = "https://civitai.com/api/v1"; + + format!("{API_URL}{path}") +} + +#[derive(Debug, Clone)] +#[allow(dead_code)] +pub enum Error { + RequestFailed(Arc<reqwest::Error>), + IOFailed(Arc<io::Error>), + JoinFailed(Arc<task::JoinError>), + ImageDecodingFailed(Arc<image::ImageError>), +} + +impl From<reqwest::Error> for Error { + fn from(error: reqwest::Error) -> Self { + Self::RequestFailed(Arc::new(error)) + } +} + +impl From<io::Error> for Error { + fn from(error: io::Error) -> Self { + Self::IOFailed(Arc::new(error)) + } +} + +impl From<task::JoinError> for Error { + fn from(error: task::JoinError) -> Self { + Self::JoinFailed(Arc::new(error)) + } +} + +impl From<image::ImageError> for Error { + fn from(error: image::ImageError) -> Self { + Self::ImageDecodingFailed(Arc::new(error)) + } +} |