summaryrefslogblamecommitdiffstats
path: root/src/main.rs
blob: 2736c5a1976caa53a2b9c0eb3a1abdb8ddbd1247 (plain) (tree)
1
2
3
4
5
6
7
8
9
10




                                      

                                                                                     
                                                        

                                   
                                    
                                      

                                                                 
                             


                                                                                         
                                    

























                                                                




                                               







                                                                            
                                  
















                                 





                                                                                        



                                       
                                                                           













                                                            
                       




                                                                
















































                                                                                                  



                                      
                                                                          





                      
                               


                 
























































                                                                                           





















                                                        




















                                                                        




                                                                       
                          






















































                                                                                         

                       


           











                                                






                                                 
                                             


















                                                                        
#[macro_use]
mod actix_ructe;

use std::time::{Duration, SystemTime};

use actix_session::storage::CookieSessionStore;
use actix_session::{Session, SessionGetError, SessionInsertError, SessionMiddleware};
use actix_web::body::{BoxBody, EitherBody, MessageBody};
use actix_web::cookie::time::Error;
use actix_web::cookie::Key;
use actix_web::dev::ServiceResponse;
use actix_web::http::header::LOCATION;
use actix_web::http::{header, StatusCode};
use actix_web::middleware::{ErrorHandlerResponse, ErrorHandlers};
use actix_web::web::Redirect;
use actix_web::{get, post, web, App, HttpResponse, HttpServer, Responder, ResponseError};
use bcrypt::{hash, verify, DEFAULT_COST};
use serde::Deserialize;
use sqlx::postgres::PgDatabaseError;
use sqlx::{postgres::PgPoolOptions, Pool, Postgres};
use templates::statics::StaticFile;

/// How far in the future static assets are marked as expiring (180 days).
static FAR: Duration = Duration::from_secs(60 * 60 * 24 * 180);

/// Crate-wide result type for request handlers; the error side is always [`PinussyError`].
type Result<T> = core::result::Result<T, PinussyError>;

/// Shared application state, cloned into each worker's `App` through `web::Data`.
#[derive(Clone)]
struct Pinussy {
    /// Postgres connection pool used by the handlers that query the database.
    db: Pool<Postgres>,
}

/// Connects to Postgres, applies pending migrations and serves the app on `127.0.0.1:8080`.
///
/// The connection string is read from the `DATABASE_URL` environment variable, falling back to
/// the local development database when it is unset.
///
/// # Errors
///
/// Returns an I/O error if the database cannot be reached, a migration fails, or the listening
/// socket cannot be bound.
#[actix_web::main]
async fn main() -> std::io::Result<()> {
    let database_url = std::env::var("DATABASE_URL")
        .unwrap_or_else(|_| "postgres://pinussy:pinussy@localhost/pinussy".to_owned());

    let pool = PgPoolOptions::new()
        .max_connections(5)
        .connect(&database_url)
        .await
        .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;

    sqlx::migrate!("./migrations")
        .run(&pool)
        .await
        .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;

    let pinussy = Pinussy { db: pool };

    // Generate the session key once, OUTSIDE the factory closure. `HttpServer::new` runs the
    // closure once per worker thread. A key generated inside it would therefore differ between
    // workers, and a session cookie issued by one worker would be rejected by the others.
    // TODO: load the key from configuration so sessions also survive a restart.
    let session_key = Key::generate();

    HttpServer::new(move || {
        App::new()
            .wrap(SessionMiddleware::new(
                // TODO: postgres session store
                CookieSessionStore::default(),
                session_key.clone(),
            ))
            .wrap(
                ErrorHandlers::new()
                    .handler(StatusCode::NOT_FOUND, render_404)
                    .handler(StatusCode::INTERNAL_SERVER_ERROR, render_500),
            )
            .app_data(web::Data::new(pinussy.clone()))
            .service(web::resource("/static/{filename}").to(static_file))
            .service(home)
            // .service(home_auth)
            .service(get_login)
            .service(post_login)
            .service(get_signup)
            .service(post_signup)
            .service(get_users)
        // .service(get_pins)
        // .service(post_pin)
        // .service(get_pin)
        // .service(post_board)
        // .service(get_board)
    })
    .bind(("127.0.0.1", 8080))?
    .run()
    .await
}

/// Landing page. Echoes the id stored in the session under `"user_id"` when present, and
/// otherwise answers with a plain greeting.
///
/// # Errors
///
/// Fails if the session value cannot be read or deserialized as an `i32`.
#[get("/")]
async fn home(session: Session) -> Result<HttpResponse> {
    let response = match session.get::<i32>("user_id")? {
        Some(id) => HttpResponse::Ok().body(format!("you are logged in as {}", id)),
        None => HttpResponse::Ok().body("Hello world!"),
    };
    Ok(response)
}

/// Serves the signup page without any notification banner.
#[get("/signup")]
async fn get_signup() -> HttpResponse {
    let page = render!(templates::signup_html, None).unwrap();
    HttpResponse::Ok().body(page)
}

/// URL-encoded form body submitted to `POST /signup`.
#[derive(Deserialize)]
struct SignupForm {
    /// Requested username. A duplicate is detected through the database's unique-violation
    /// error (see `post_signup`).
    username: String,
    /// Plain-text password. Only its bcrypt hash is written to the `users` table.
    password: String,
}

#[post("/signup")]
async fn post_signup(
    state: web::Data<Pinussy>,
    form: web::Form<SignupForm>,
) -> Result<HttpResponse> {
    let password_hash = hash(&form.password, DEFAULT_COST)?;
    match sqlx::query!(
        "insert into users(username, password) values ($1, $2)",
        &form.username,
        password_hash
    )
    .execute(&state.db)
    .await
    {
        Ok(_) => {
            return Ok(HttpResponse::Ok().body(
                render!(
                    templates::signup_html,
                    Some(Notification {
                        kind: NotificationKind::Info,
                        message: format!("you have successfully registered as {}", &form.username)
                    })
                )
                .unwrap(),
            ))
        }
        Err(e) => {
            match e {
                sqlx::Error::Database(e) => {
                    if e.is_unique_violation() {
                        return Ok(HttpResponse::Conflict().body(
                            render!(
                                templates::signup_html,
                                Some(Notification {
                                    kind: NotificationKind::Error,
                                    message: format!(
                                        "error: the username \"{}\" already exists",
                                        &form.username
                                    )
                                })
                            )
                            .unwrap(),
                        ));
                    }
                }
                // TODO: log error
                _ => {}
            }
            return Ok(HttpResponse::InternalServerError().body(
                render!(
                    templates::signup_html,
                    Some(Notification {
                        kind: NotificationKind::Error,
                        message: "there was an internal server error. please try again later."
                            .to_owned()
                    })
                )
                .unwrap(),
            ));
        }
    };
}

/// Serves the login page without any notification banner.
#[get("/login")]
async fn get_login() -> HttpResponse {
    let page = render!(templates::login_html, None).unwrap();
    HttpResponse::Ok().body(page)
}

/// URL-encoded form body for the login page (`POST /login`).
#[derive(Deserialize)]
struct LoginForm {
    /// Username to log in as.
    username: String,
    /// Plain-text password entered by the user.
    password: String,
    /// Presumably the "remember me" checkbox. It is `None` when the field is absent from the
    /// submitted form.
    /// NOTE(review): no code in this file reads this value yet.
    rememberme: Option<String>,
}

#[post("/login")]
async fn post_login(
    state: web::Data<Pinussy>,
    session: Session,
    form: web::Form<SignupForm>,
) -> Result<HttpResponse> {
    match sqlx::query!(
        "select id, password from users where username = $1",
        &form.username
    )
    .fetch_one(&state.db)
    .await
    {
        Ok(user) => {
            let password_hash: String = user.password;
            if verify(&form.password, &password_hash)? {
                session.insert("user_id", user.id)?;
                return Ok(HttpResponse::SeeOther()
                    .insert_header((LOCATION, "/"))
                    .finish());
            } else {
                return Ok(HttpResponse::Unauthorized().body(
                    render!(
                        templates::login_html,
                        Some(Notification {
                            kind: NotificationKind::Error,
                            message: "that password is incorrect".to_owned()
                        })
                    )
                    .unwrap(),
                ));
            }
        }
        Err(sqlx::Error::RowNotFound) => {
            return Ok(HttpResponse::NotFound().body(
                render!(
                    templates::login_html,
                    Some(Notification {
                        kind: NotificationKind::Error,
                        message: format!("the user \"{}\" does not exist", &form.username)
                    })
                )
                .unwrap(),
            ));
        }
        Err(_) => {
            return Ok(HttpResponse::InternalServerError().body(
                render!(
                    templates::login_html,
                    Some(Notification {
                        kind: NotificationKind::Error,
                        message: "internal server error. please try again later".to_owned()
                    })
                )
                .unwrap(),
            ));
        }
    }
}

/// Visibility setting mapped to the Postgres enum type `privacy`.
///
/// The variants are stored as the lowercase labels `private`, `unlisted` and `public`.
#[derive(sqlx::Type)]
#[sqlx(type_name = "privacy", rename_all = "lowercase")]
enum Privacy {
    Private,
    Unlisted,
    Public,
}

/// One row of the `users` table, as loaded by `select * from users` in `get_users`.
#[derive(sqlx::FromRow)]
pub struct User {
    id: i32,
    username: String,
    /// bcrypt hash of the password, as written by `post_signup`.
    password: String,
    email: Option<String>,
    bio: Option<String>,
    site: Option<String>,
    privacy: Privacy,
    /// Presumably grants administrator rights. TODO confirm.
    admin: bool,
}

/// Severity of a [`Notification`] banner.
pub enum NotificationKind {
    Info,
    Warning,
    Error,
}

impl std::fmt::Display for NotificationKind {
    /// Writes the lowercase name of the kind: `info`, `warning` or `error`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match *self {
            Self::Info => "info",
            Self::Warning => "warning",
            Self::Error => "error",
        };
        f.write_str(label)
    }
}

/// Message banner passed to the page templates, e.g. after a signup or login attempt.
pub struct Notification {
    /// Severity of the message.
    kind: NotificationKind,
    /// Human-readable text. Presumably rendered verbatim by the template; confirm.
    message: String,
}

/// Renders the list of all users.
///
/// # Errors
///
/// Database failures are propagated as [`PinussyError::Database`] instead of panicking the
/// worker. Its `ResponseError` impl answers with the default status, `500`.
#[get("/users")]
async fn get_users(state: web::Data<Pinussy>) -> Result<HttpResponse> {
    // NOTE(review): `select *` also loads password hashes and e-mail addresses into the
    // template context; make sure `users_html` does not render them.
    let users: Vec<User> = sqlx::query_as("select * from users")
        .fetch_all(&state.db)
        .await?;
    Ok(HttpResponse::Ok().body(render!(templates::users_html, users).unwrap()))
}

/// Serves an asset from the generated `templates::statics` module by file name.
///
/// Known files get their MIME type and a far-future `Expires` header (see `FAR`) so browsers
/// cache them. Unknown names get `404 Not Found`.
async fn static_file(path: web::Path<String>) -> HttpResponse {
    let filename = path.into_inner();
    match StaticFile::get(&filename) {
        Some(asset) => HttpResponse::Ok()
            .insert_header(header::Expires((SystemTime::now() + FAR).into()))
            .insert_header(header::ContentType(asset.mime.clone()))
            .body(asset.content),
        None => HttpResponse::NotFound()
            .reason("No such static file.")
            .finish(),
    }
}

/// `ErrorHandlers` hook for `404 Not Found` responses: swaps in the HTML error page.
fn render_404(res: ServiceResponse) -> actix_web::Result<ErrorHandlerResponse<BoxBody>> {
    const MESSAGE: &str = "The resource you requested can't be found.";
    let page = error_response(res, StatusCode::NOT_FOUND, MESSAGE);
    Ok(page)
}

/// `ErrorHandlers` hook for `500 Internal Server Error` responses: swaps in the HTML error
/// page.
fn render_500(res: ServiceResponse) -> actix_web::Result<ErrorHandlerResponse<BoxBody>> {
    const MESSAGE: &str = "Sorry, Something went wrong.  This is probably not your fault.";
    let page = error_response(res, StatusCode::INTERNAL_SERVER_ERROR, MESSAGE);
    Ok(page)
}

/// Replaces the body of `res` with the `error` template rendered for `status_code` and
/// `message`, and marks it as UTF-8 HTML.
///
/// Only the `Content-Type` header and the body change. The response's own status is not
/// modified here.
fn error_response(
    mut res: ServiceResponse,
    status_code: StatusCode,
    message: &str,
) -> ErrorHandlerResponse<BoxBody> {
    let html = header::HeaderValue::from_static(mime::TEXT_HTML_UTF_8.as_ref());
    res.headers_mut().insert(header::CONTENT_TYPE, html);

    // The previous body is discarded in favour of the rendered template.
    let page = res.map_body(|_head, _old_body| {
        let rendered = render!(templates::error_html, status_code, message).unwrap();
        EitherBody::right(MessageBody::boxed(rendered))
    });
    ErrorHandlerResponse::Response(page)
}

#[derive(Debug)]
enum PinussyError {
    Database(sqlx::Error),
    SessionInsertError,
    SessionGetError,
    Bcrypt,
}

impl From<SessionInsertError> for PinussyError {
    fn from(_: SessionInsertError) -> Self {
        Self::SessionInsertError
    }
}

impl From<SessionGetError> for PinussyError {
    fn from(_: SessionGetError) -> Self {
        Self::SessionGetError
    }
}

impl From<sqlx::Error> for PinussyError {
    fn from(e: sqlx::Error) -> Self {
        Self::Database(e)
    }
}

impl From<bcrypt::BcryptError> for PinussyError {
    fn from(_: bcrypt::BcryptError) -> Self {
        Self::Bcrypt
    }
}

impl std::fmt::Display for PinussyError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{:?}", self)
    }
}

impl std::error::Error for PinussyError {}

impl ResponseError for PinussyError {
    fn error_response(&self) -> HttpResponse<BoxBody> {
        HttpResponse::new(self.status_code())
    }
}

include!(concat!(env!("OUT_DIR"), "/templates.rs"));