use std::sync::Arc; use askama::Template; use axum::{ body::Body, extract::{Form, FromRequest, Path, Query, RequestParts}, http::{header, StatusCode}, response::{Html, IntoResponse, Redirect, Response}, routing::{get, get_service, post}, Extension, Json, }; use axum_extra::extract::{cookie::Cookie, CookieJar}; use db::{Database, Db, Party, Ticket, User}; use rand::Rng; use serde::{Deserialize, Serialize}; use sqlx::postgres::PgPoolOptions; use tower_http::services::ServeDir; mod db; #[axum::async_trait] impl FromRequest for User { /// If the extractor fails it'll use this "rejection" type. A rejection is /// a kind of error that can be converted into a response. type Rejection = Redirect; /// Perform the extraction. async fn from_request(req: &mut RequestParts) -> Result { let jar = req .extract::() .await .or_else(|_| Err(Redirect::to("/login")))?; let token = jar.get("token").ok_or_else(|| Redirect::to("/login"))?; let db = req .extensions() .get::() .ok_or_else(|| Redirect::to("/login"))? .clone(); if let Ok(uuid) = uuid::Uuid::parse_str(&token.value()) { let user = db.get_user_token_by_token(uuid).await; if let Some(user) = user { if let Some(user) = db.get_user_by_token(user).await { Ok(user) } else { Err(Redirect::to("/login")) } } else { Err(Redirect::to("/login")) } } else { Err(Redirect::to("/login")) } } } #[derive(Template)] #[template(path = "index.html")] struct IndexPage { user: User, parties: Vec, } async fn index(Extension(pool): Extension, user: User) -> impl IntoResponse { let parties = pool.get_all_parties().await; let page = IndexPage { user, parties }; Html(page.render().unwrap()) } #[derive(Template)] #[template(path = "login.html")] struct LoginPage { not_found: bool, } #[derive(Debug, Clone, Serialize, Deserialize)] struct LoginQuery { not_found: Option, } async fn login_page(Query(query): Query) -> impl IntoResponse { Html( LoginPage { not_found: query.not_found.unwrap_or(false), } .render() .unwrap(), ) } #[derive(Debug, Clone, Serialize, Deserialize)] struct LoginForm { username: String, password: String, } async fn login( Form(user): Form, Extension(pool): Extension, jar: CookieJar, ) -> impl IntoResponse { let u = pool.get_user_by_username(&user.username).await; if let Some(u) = u { if u.verify_password(&user.password) { let token = pool .add_new_user_token(&u, chrono::Utc::now() + chrono::Duration::weeks(5)) .await; let cookie = token.to_cookie(); (jar.add(cookie), Redirect::to("/")) } else { (jar, Redirect::to("/login?not_found=true")) } } else { (jar, Redirect::to("/login?not_found=true")) } } #[derive(Debug, Clone, Serialize, Deserialize)] struct ScanRequest { code: String, check: bool, party: i32, } #[derive(Debug, Clone, Serialize, Deserialize)] enum ScanState { NotFound, Found, Added, AlreadyScanned, } #[derive(Debug, Clone, Serialize, Deserialize)] struct ScanResponse { code: String, state: ScanState, } async fn scan_card( _: User, Extension(pool): Extension, Json(scan): Json, ) -> impl IntoResponse { let scanned = pool.is_student_in_party(scan.party, &scan.code).await; if scan.check { if let Some(mut ticket) = scanned { if ticket.inside { Json(ScanResponse { code: scan.code, state: ScanState::AlreadyScanned, }) } else { ticket.inside = true; pool.update_ticket(&ticket).await; Json(ScanResponse { code: scan.code, state: ScanState::Found, }) } } else { Json(ScanResponse { code: scan.code, state: ScanState::NotFound, }) } } else { if let Some(_) = scanned { Json(ScanResponse { code: scan.code, state: ScanState::AlreadyScanned, }) } else { pool.add_ticket(scan.party, &scan.code).await; Json(ScanResponse { code: scan.code, state: ScanState::Added, }) } } } #[derive(Debug, Clone, Serialize, Deserialize)] struct TicketsQuery { party: i32, } async fn get_tickets( _: User, Extension(pool): Extension, Query(q): Query, ) -> impl IntoResponse { let tickets = pool.get_all_tickets_for_party(q.party).await; Json(tickets) } async fn get_parties(_: User, Extension(pool): Extension) -> impl IntoResponse { let parties = pool.get_all_parties().await; Json(parties) } #[derive(Debug, Clone, Serialize, Deserialize)] struct CreatePartyParams { name: String, } async fn create_party( _: User, Extension(pool): Extension, Json(party): Json, ) -> impl IntoResponse { let party = pool.add_party(&party.name).await; Json(party) } async fn get_party_by_name( _: User, Extension(pool): Extension, Json(party): Json, ) -> impl IntoResponse { let party = pool.get_party_by_name(&party.name).await; Json(party) } #[derive(Template)] #[template(path = "party.html")] struct PartyPage { name: String, id: i32, } async fn party_page( _: User, Extension(pool): Extension, Path(party): Path, ) -> impl IntoResponse { let party = pool.get_party(party).await; if let Some(party) = party { let page = PartyPage { name: party.name, id: party.id, }; Html(page.render().unwrap()).into_response() } else { Redirect::to("/").into_response() } } async fn export_party( _: User, Extension(pool): Extension, Path(party): Path, ) -> impl IntoResponse { let tickets = pool.get_all_tickets_for_party(party).await; let mut csv = String::new(); csv.push_str("leerlingnummer,binnen\n"); for ticket in tickets { csv.push_str(&format!("{},{}\n", ticket.student, ticket.inside)); } let mut resp = Response::new(Body::from(csv)); resp.headers_mut().insert( header::CONTENT_DISPOSITION, header::HeaderValue::from_static("attachment; filename=tickets.csv"), ); resp.headers_mut().insert( header::CONTENT_TYPE, header::HeaderValue::from_static("text/csv"), ); resp } #[derive(Template)] #[template(path = "party_goers.html")] struct PartyGoersPage { id: i32, name: String, guests: Vec, } async fn party_goers( _: User, Extension(pool): Extension, Path(party): Path, ) -> impl IntoResponse { let party = pool.get_party(party).await; if let Some(party) = party { let guests = pool.get_all_tickets_for_party(party.id).await; let page = PartyGoersPage { id: party.id, name: party.name, guests, }; Html(page.render().unwrap()).into_response() } else { Redirect::to("/").into_response() } } #[derive(Template)] #[template(path = "password.html")] struct PasswordPage { wrong_password: bool, wrong_dupe_password: bool, } #[derive(Debug, Clone, Serialize, Deserialize)] struct PasswordParams { wrong: Option, dupe: Option, } async fn change_password(Query(params): Query) -> impl IntoResponse { let page = PasswordPage { wrong_password: params.wrong.unwrap_or(false), wrong_dupe_password: params.dupe.unwrap_or(false), }; Html(page.render().unwrap()).into_response() } #[derive(Debug, Clone, Serialize, Deserialize)] struct ChangePasswordForm { password: String, password_new: String, password_new_again: String, } async fn change_password_post( mut user: User, Extension(pool): Extension, Form(form): Form, ) -> impl IntoResponse { if !user.verify_password(&form.password) { return Redirect::to("/password?wrong=true").into_response(); } if form.password_new != form.password_new_again { return Redirect::to("/password?dupe=true").into_response(); } user.update_password(&form.password_new); pool.save_user(&user).await; Redirect::to("/").into_response() } async fn logout(_: User, jar: CookieJar) -> impl IntoResponse { (jar.remove(Cookie::named("token")), Redirect::to("/")) } #[derive(Template)] #[template(path = "add_user.html")] struct AddUserPage {} async fn add_user_page(_: User) -> impl IntoResponse { Html(AddUserPage {}.render().unwrap()) } async fn add_user( _: User, Extension(pool): Extension, Form(form): Form, ) -> impl IntoResponse { pool.add_new_user(&form.username, &form.password).await; Redirect::to("/") } #[derive(Debug, Clone, Serialize, Deserialize)] struct TicketDeleteRequest { ticket_id: i32, } async fn remove_ticket( _: User, Extension(pool): Extension, Json(ticket): Json, ) -> impl IntoResponse { pool.remove_ticket(ticket.ticket_id).await; "OK" } #[tokio::main] async fn main() { let url = if let Some(url) = std::env::var("DATABASE_URL").ok() { url } else { "postgres://postgres:postgres@localhost:5432/postgres".to_string() }; let pool = PgPoolOptions::new() .connect(&url) .await .unwrap(); db::setup_db(&pool).await; let pool = Arc::new(Db::new(pool)); if let None = pool.get_user_by_username("admin").await { let password = if let Some(pass) = std::env::var("ADMIN_PASSWORD").ok() { pass } else { rand::thread_rng() .sample_iter(rand::distributions::Alphanumeric) .take(32) .map(char::from) .collect::() }; println!("User admin was created with password {{{password}}}! Save it somewhere!"); pool.add_new_user("admin", &password).await; } let router = axum::Router::new() .route("/", get(index)) .route("/login", get(login_page).post(login)) .route("/logout", get(logout)) .route("/password", get(change_password).post(change_password_post)) .route("/add_user", get(add_user_page).post(add_user)) .route("/party/:id", get(party_page)) .route("/party/:id/export", get(export_party)) .route("/party/:id/lijst", get(party_goers)) .route("/api/ticket", get(get_tickets).post(scan_card)) .route("/api/ticket/delete", post(remove_ticket)) .route("/api/party", get(get_party_by_name).post(create_party)) .route("/api/party/list", get(get_parties)) .nest( "/static", get_service(ServeDir::new("static")).handle_error(handle_error), ) .layer(Extension(pool)); axum::Server::bind(&"0.0.0.0:8080".parse().unwrap()) .serve(router.into_make_service()) .await .unwrap(); } async fn handle_error(_err: std::io::Error) -> impl IntoResponse { (StatusCode::INTERNAL_SERVER_ERROR, "Something went wrong...") }