From cba88c5def1075bd9a94cc38db6a3d9634cd2692 Mon Sep 17 00:00:00 2001 From: prescientmoon Date: Fri, 6 Sep 2024 17:31:20 +0200 Subject: [PATCH] Set up testing infrastructure --- .gitignore | 1 + Cargo.lock | 107 +++++++++-------- Cargo.toml | 3 + src/arcaea/jacket.rs | 1 + src/arcaea/play.rs | 11 +- src/commands/chart.rs | 8 +- src/commands/discord.rs | 210 ++++++++++++++++++++++++++++++++++ src/commands/mod.rs | 1 + src/commands/score.rs | 163 +++++++++++++++++--------- src/commands/stats.rs | 20 ++-- src/commands/utils/mod.rs | 2 +- src/context.rs | 53 +++++---- src/logs.rs | 30 ++--- src/main.rs | 1 + src/recognition/hyperglass.rs | 5 +- src/recognition/recognize.rs | 18 +-- src/recognition/ui.rs | 4 +- src/user.rs | 32 +++++- 18 files changed, 494 insertions(+), 176 deletions(-) create mode 100644 src/commands/discord.rs diff --git a/.gitignore b/.gitignore index b836868..c338dc9 100644 --- a/.gitignore +++ b/.gitignore @@ -11,3 +11,4 @@ target backups dump.sql schema.sql +test diff --git a/Cargo.lock b/Cargo.lock index 19e477b..11bb81b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -328,7 +328,7 @@ dependencies = [ "num-traits", "serde", "wasm-bindgen", - "windows-targets 0.52.5", + "windows-targets 0.52.6", ] [[package]] @@ -1316,7 +1316,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0c2a198fb6b0eada2a8df47933734e6d35d350665a33a3593d7164fa52c75c19" dependencies = [ "cfg-if", - "windows-targets 0.52.5", + "windows-targets 0.48.5", ] [[package]] @@ -1655,7 +1655,7 @@ dependencies = [ "libc", "redox_syscall", "smallvec", - "windows-targets 0.52.5", + "windows-targets 0.52.6", ] [[package]] @@ -2307,9 +2307,9 @@ dependencies = [ [[package]] name = "serde" -version = "1.0.204" +version = "1.0.209" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bc76f558e0cbb2a839d37354c575f1dc3fdc6546b5be373ba43d95f231bf7c12" +checksum = "99fce0ffe7310761ca6bf9faf5115afbc19688edd00171d81b1bb1b116c63e09" dependencies = [ "serde_derive", ] @@ -2325,9 +2325,9 @@ dependencies = [ [[package]] name = "serde_derive" -version = "1.0.204" +version = "1.0.209" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e0cd7e117be63d3c3678776753929474f3b04a43a080c744d6b0ae2a8c28e222" +checksum = "a5831b979fd7b5439637af1752d535ff49f4860c0f341d1baeb6faf0f4242170" dependencies = [ "proc-macro2", "quote", @@ -2347,9 +2347,9 @@ dependencies = [ [[package]] name = "serde_spanned" -version = "0.6.6" +version = "0.6.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "79e674e01f999af37c49f70a6ede167a8a60b2503e56c5599532a65baa5969a0" +checksum = "eb5b1b31579f3811bf615c144393417496f152e12ac8b7663bf664f4a815306d" dependencies = [ "serde", ] @@ -2427,7 +2427,10 @@ dependencies = [ "r2d2_sqlite", "rusqlite", "rusqlite_migration", + "serde", + "tempfile", "tokio", + "toml", ] [[package]] @@ -2611,14 +2614,15 @@ checksum = "e1fc403891a21bcfb7c37834ba66a547a8f402146eba7265b5a6d88059c9ff2f" [[package]] name = "tempfile" -version = "3.10.1" +version = "3.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "85b77fafb263dd9d05cbeac119526425676db3784113aa9295c88498cbf8bff1" +checksum = "04cbcdd0c794ebb0d4cf35e88edd2f7d2c4c3e9a5a6dab322839b321c6a87a64" dependencies = [ "cfg-if", "fastrand", + "once_cell", "rustix", - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] @@ -2786,9 +2790,9 @@ dependencies = [ [[package]] name = "toml" -version = "0.8.14" +version = "0.8.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6f49eb2ab21d2f26bd6db7bf383edc527a7ebaee412d17af4d40fdccd442f335" +checksum = "a1ed1f98e3fdc28d6d910e6737ae6ab1a93bf1985935a1193e68f93eeb68d24e" dependencies = [ "serde", "serde_spanned", @@ -2798,18 +2802,18 @@ dependencies = [ [[package]] name = "toml_datetime" -version = "0.6.6" +version = "0.6.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4badfd56924ae69bcc9039335b2e017639ce3f9b001c393c1b2d1ef846ce2cbf" +checksum = "0dd7358ecb8fc2f8d014bf86f6f638ce72ba252a2c3a2572f2a795f1d23efb41" dependencies = [ "serde", ] [[package]] name = "toml_edit" -version = "0.22.14" +version = "0.22.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f21c7aaf97f1bd9ca9d4f9e73b0a6c74bd5afef56f2bc931943a6e1c37e04e38" +checksum = "583c44c02ad26b0c3f3066fe629275e50627026c51ac2e595cca4c230ce1ce1d" dependencies = [ "indexmap", "serde", @@ -3217,7 +3221,7 @@ version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "33ab640c8d7e35bf8ba19b884ba838ceb4fba93a4e8c65a9059d08afcfc683d9" dependencies = [ - "windows-targets 0.52.5", + "windows-targets 0.52.6", ] [[package]] @@ -3235,7 +3239,16 @@ version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" dependencies = [ - "windows-targets 0.52.5", + "windows-targets 0.52.6", +] + +[[package]] +name = "windows-sys" +version = "0.59.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b" +dependencies = [ + "windows-targets 0.52.6", ] [[package]] @@ -3255,18 +3268,18 @@ dependencies = [ [[package]] name = "windows-targets" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6f0713a46559409d202e70e28227288446bf7841d3211583a4b53e3f6d96e7eb" +checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" dependencies = [ - "windows_aarch64_gnullvm 0.52.5", - "windows_aarch64_msvc 0.52.5", - "windows_i686_gnu 0.52.5", + "windows_aarch64_gnullvm 0.52.6", + "windows_aarch64_msvc 0.52.6", + "windows_i686_gnu 0.52.6", "windows_i686_gnullvm", - "windows_i686_msvc 0.52.5", - "windows_x86_64_gnu 0.52.5", - "windows_x86_64_gnullvm 0.52.5", - "windows_x86_64_msvc 0.52.5", + "windows_i686_msvc 0.52.6", + "windows_x86_64_gnu 0.52.6", + "windows_x86_64_gnullvm 0.52.6", + "windows_x86_64_msvc 0.52.6", ] [[package]] @@ -3277,9 +3290,9 @@ checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" [[package]] name = "windows_aarch64_gnullvm" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7088eed71e8b8dda258ecc8bac5fb1153c5cffaf2578fc8ff5d61e23578d3263" +checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" [[package]] name = "windows_aarch64_msvc" @@ -3289,9 +3302,9 @@ checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" [[package]] name = "windows_aarch64_msvc" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9985fd1504e250c615ca5f281c3f7a6da76213ebd5ccc9561496568a2752afb6" +checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" [[package]] name = "windows_i686_gnu" @@ -3301,15 +3314,15 @@ checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" [[package]] name = "windows_i686_gnu" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "88ba073cf16d5372720ec942a8ccbf61626074c6d4dd2e745299726ce8b89670" +checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" [[package]] name = "windows_i686_gnullvm" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87f4261229030a858f36b459e748ae97545d6f1ec60e5e0d6a3d32e0dc232ee9" +checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" [[package]] name = "windows_i686_msvc" @@ -3319,9 +3332,9 @@ checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" [[package]] name = "windows_i686_msvc" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "db3c2bf3d13d5b658be73463284eaf12830ac9a26a90c717b7f771dfe97487bf" +checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" [[package]] name = "windows_x86_64_gnu" @@ -3331,9 +3344,9 @@ checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" [[package]] name = "windows_x86_64_gnu" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4e4246f76bdeff09eb48875a0fd3e2af6aada79d409d33011886d3e1581517d9" +checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" [[package]] name = "windows_x86_64_gnullvm" @@ -3343,9 +3356,9 @@ checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" [[package]] name = "windows_x86_64_gnullvm" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "852298e482cd67c356ddd9570386e2862b5673c85bd5f88df9ab6802b334c596" +checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" [[package]] name = "windows_x86_64_msvc" @@ -3355,15 +3368,15 @@ checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" [[package]] name = "windows_x86_64_msvc" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bec47e5bfd1bff0eeaf6d8b485cc1074891a197ab4225d504cb7a1ab88b02bf0" +checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" [[package]] name = "winnow" -version = "0.6.13" +version = "0.6.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "59b5e5f6c299a3c7890b876a2a587f3115162487e704907d9b6cd29473052ba1" +checksum = "68a9bda4691f099d435ad181000724da8e5899daa10713c2d432552b9ccd3a6f" dependencies = [ "memchr", ] diff --git a/Cargo.toml b/Cargo.toml index f68af09..df25aa4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,6 +18,9 @@ r2d2_sqlite = "0.25.0" r2d2 = "0.8.10" rusqlite_migration = {version="1.3.0-alpha-without-tokio.1", features = ["from-directory"]} include_dir = "0.7.4" +serde = "1.0.209" +toml = "0.8.19" +tempfile = "3.12.0" [profile.dev.package."*"] opt-level = 3 diff --git a/src/arcaea/jacket.rs b/src/arcaea/jacket.rs index dc91cfd..78fa625 100644 --- a/src/arcaea/jacket.rs +++ b/src/arcaea/jacket.rs @@ -73,6 +73,7 @@ impl ImageVec { // }}} } +#[derive(Clone)] pub struct JacketCache { jackets: Vec<(u32, ImageVec)>, } diff --git a/src/arcaea/play.rs b/src/arcaea/play.rs index 206e5ba..0b608e5 100644 --- a/src/arcaea/play.rs +++ b/src/arcaea/play.rs @@ -1,4 +1,5 @@ use std::array; +use std::num::NonZeroU64; use chrono::NaiveDateTime; use chrono::Utc; @@ -6,9 +7,7 @@ use num::traits::Euclid; use num::CheckedDiv; use num::Rational32; use num::Zero; -use poise::serenity_prelude::{ - Attachment, AttachmentId, CreateAttachment, CreateEmbed, CreateEmbedAuthor, Timestamp, -}; +use poise::serenity_prelude::{CreateAttachment, CreateEmbed, CreateEmbedAuthor, Timestamp}; use rusqlite::Row; use crate::arcaea::chart::{Chart, Song}; @@ -21,7 +20,7 @@ use super::score::{Score, ScoringSystem}; // {{{ Create play #[derive(Debug, Clone)] pub struct CreatePlay { - discord_attachment_id: Option, + discord_attachment_id: Option, // Scoring details score: Score, @@ -41,8 +40,8 @@ impl CreatePlay { } #[inline] - pub fn with_attachment(mut self, attachment: &Attachment) -> Self { - self.discord_attachment_id = Some(attachment.id); + pub fn with_attachment(mut self, attachment_id: NonZeroU64) -> Self { + self.discord_attachment_id = Some(attachment_id); self } diff --git a/src/commands/chart.rs b/src/commands/chart.rs index b0e1518..0e3cc6c 100644 --- a/src/commands/chart.rs +++ b/src/commands/chart.rs @@ -110,12 +110,12 @@ async fn info( /// Show the best score on a given chart #[poise::command(prefix_command, slash_command, user_cooldown = 1)] async fn best( - ctx: Context<'_>, + mut ctx: Context<'_>, #[rest] #[description = "Name of chart to show (difficulty at the end)"] name: String, ) -> Result<(), Error> { - let user = get_user!(&ctx); + let user = get_user!(&mut ctx); let (song, chart) = guess_song_and_chart(&ctx.data(), &name)?; let play = ctx @@ -164,13 +164,13 @@ async fn best( /// Show the best score on a given chart #[poise::command(prefix_command, slash_command, user_cooldown = 10)] async fn plot( - ctx: Context<'_>, + mut ctx: Context<'_>, scoring_system: Option, #[rest] #[description = "Name of chart to show (difficulty at the end)"] name: String, ) -> Result<(), Error> { - let user = get_user!(&ctx); + let user = get_user!(&mut ctx); let scoring_system = scoring_system.unwrap_or_default(); let (song, chart) = guess_song_and_chart(&ctx.data(), &name)?; diff --git a/src/commands/discord.rs b/src/commands/discord.rs new file mode 100644 index 0000000..63c112a --- /dev/null +++ b/src/commands/discord.rs @@ -0,0 +1,210 @@ +use std::num::NonZeroU64; + +use poise::serenity_prelude::{futures::future::join_all, CreateAttachment, CreateMessage}; + +use crate::{ + context::{Error, UserContext}, + timed, +}; + +// {{{ Trait +pub trait MessageContext { + /// Get the user context held by the message + fn data(&self) -> &UserContext; + fn author_id(&self) -> u64; + + /// Reply to the current message + async fn reply(&mut self, text: &str) -> Result<(), Error>; + + /// Deliver a message containing references to files. + async fn send_files( + &mut self, + attachments: impl IntoIterator, + message: CreateMessage, + ) -> Result<(), Error>; + + // {{{ Input attachments + type Attachment; + + fn is_image(attachment: &Self::Attachment) -> bool; + fn filename(attachment: &Self::Attachment) -> &str; + fn attachment_id(attachment: &Self::Attachment) -> NonZeroU64; + + /// Downloads a single file + async fn download(&self, attachment: &Self::Attachment) -> Result, Error>; + + /// Downloads every image + async fn download_images<'a>( + &self, + attachments: &'a Vec, + ) -> Result)>, Error> { + let download_tasks = attachments + .iter() + .filter(|file| Self::is_image(file)) + .map(|file| async move { (file, self.download(file).await) }); + + let downloaded = timed!("dowload_files", { join_all(download_tasks).await }); + downloaded + .into_iter() + .map(|(file, bytes)| Ok((file, bytes?))) + .collect::>() + } + // }}} +} +// }}} +// {{{ Poise implementation +impl<'a, 'b> MessageContext + for poise::Context<'a, UserContext, Box> +{ + type Attachment = poise::serenity_prelude::Attachment; + + fn data(&self) -> &UserContext { + Self::data(*self) + } + + fn author_id(&self) -> u64 { + self.author().id.get() + } + + async fn reply(&mut self, text: &str) -> Result<(), Error> { + Self::reply(*self, text).await?; + Ok(()) + } + + async fn send_files( + &mut self, + attachments: impl IntoIterator, + message: CreateMessage, + ) -> Result<(), Error> { + self.channel_id() + .send_files(self.http(), attachments, message) + .await?; + Ok(()) + } + + // {{{ Input attachments + fn attachment_id(attachment: &Self::Attachment) -> NonZeroU64 { + NonZeroU64::new(attachment.id.get()).unwrap() + } + + fn filename(attachment: &Self::Attachment) -> &str { + &attachment.filename + } + + fn is_image(attachment: &Self::Attachment) -> bool { + attachment.dimensions().is_some() + } + + async fn download(&self, attachment: &Self::Attachment) -> Result, Error> { + let res = poise::serenity_prelude::Attachment::download(attachment).await?; + Ok(res) + } + // }}} +} +// }}} +// {{{ Testing context +pub mod mock { + use std::{env, fs, path::PathBuf}; + + use super::*; + + pub struct MockContext { + pub user_id: u64, + pub data: UserContext, + pub messages: Vec<(CreateMessage, Vec)>, + } + + impl MockContext { + pub fn new(data: UserContext) -> Self { + Self { + data, + user_id: 666, + messages: vec![], + } + } + + pub fn write_to(&self, path: &PathBuf) -> Result<(), Error> { + if env::var("SHIMMERING_TEST_REGEN").unwrap_or_default() == "1" { + fs::remove_dir_all(path)?; + } + + fs::create_dir_all(path)?; + for (i, (message, attachments)) in self.messages.iter().enumerate() { + let dir = path.join(format!("{i}")); + fs::create_dir_all(&dir)?; + let message_file = dir.join("message.toml"); + + if message_file.exists() { + assert_eq!( + toml::to_string_pretty(message)?, + fs::read_to_string(message_file)? + ); + } else { + fs::write(&message_file, toml::to_string_pretty(message)?)?; + } + + for attachment in attachments { + let path = dir.join(&attachment.filename); + + if path.exists() { + assert_eq!(&attachment.data, &fs::read(path)?); + } else { + fs::write(&path, &attachment.data)?; + } + } + } + + Ok(()) + } + } + + impl MessageContext for MockContext { + fn author_id(&self) -> u64 { + self.user_id + } + + fn data(&self) -> &UserContext { + &self.data + } + + async fn reply(&mut self, text: &str) -> Result<(), Error> { + self.messages + .push((CreateMessage::new().content(text), Vec::new())); + Ok(()) + } + + async fn send_files( + &mut self, + attachments: impl IntoIterator, + message: CreateMessage, + ) -> Result<(), Error> { + self.messages + .push((message, attachments.into_iter().collect())); + Ok(()) + } + + // {{{ Input attachments + type Attachment = PathBuf; + + fn filename(attachment: &Self::Attachment) -> &str { + attachment.file_name().unwrap().to_str().unwrap() + } + + // This is a dumb implementation, but it works for testing... + fn is_image(attachment: &Self::Attachment) -> bool { + let ext = attachment.extension().unwrap(); + ext == "png" || ext == "jpg" || ext == "webp" + } + + fn attachment_id(_attachment: &Self::Attachment) -> NonZeroU64 { + NonZeroU64::new(666).unwrap() + } + + async fn download(&self, attachment: &Self::Attachment) -> Result, Error> { + let res = tokio::fs::read(attachment).await?; + Ok(res) + } + // }}} + } +} +// }}} diff --git a/src/commands/mod.rs b/src/commands/mod.rs index 053c7a6..58f8644 100644 --- a/src/commands/mod.rs +++ b/src/commands/mod.rs @@ -1,6 +1,7 @@ use crate::context::{Context, Error}; pub mod chart; +pub mod discord; pub mod score; pub mod stats; pub mod utils; diff --git a/src/commands/score.rs b/src/commands/score.rs index 8971eb7..b569494 100644 --- a/src/commands/score.rs +++ b/src/commands/score.rs @@ -1,15 +1,14 @@ -use std::time::Instant; - use crate::arcaea::play::{CreatePlay, Play}; use crate::arcaea::score::Score; use crate::context::{Context, Error}; use crate::recognition::recognize::{ImageAnalyzer, ScoreKind}; use crate::user::{discord_id_to_discord_user, User}; -use crate::{edit_reply, get_user, timed}; +use crate::{get_user, timed}; use image::DynamicImage; -use poise::serenity_prelude::futures::future::join_all; +use poise::serenity_prelude as serenity; use poise::serenity_prelude::CreateMessage; -use poise::{serenity_prelude as serenity, CreateReply}; + +use super::discord::MessageContext; // {{{ Score /// Score management @@ -24,13 +23,13 @@ pub async fn score(_ctx: Context<'_>) -> Result<(), Error> { } // }}} // {{{ Score magic -/// Identify scores from attached images. -#[poise::command(prefix_command, slash_command)] -pub async fn magic( - ctx: Context<'_>, - #[description = "Images containing scores"] files: Vec, +// {{{ Implementation +async fn magic_impl( + ctx: &mut C, + files: Vec, ) -> Result<(), Error> { - let user = get_user!(&ctx); + let user = get_user!(ctx); + let files = ctx.download_images(&files).await?; if files.len() == 0 { ctx.reply("No images found attached to message").await?; @@ -39,30 +38,9 @@ pub async fn magic( let mut embeds = Vec::with_capacity(files.len()); let mut attachments = Vec::with_capacity(files.len()); - let handle = ctx - .reply(format!("Processed 0/{} scores", files.len())) - .await?; - let mut analyzer = ImageAnalyzer::default(); - // {{{ Download files - let download_tasks = files - .iter() - .filter(|file| file.dimensions().is_some()) - .map(|file| async move { (file, file.download().await) }); - - let downloaded = timed!("dowload_files", { join_all(download_tasks).await }); - - if downloaded.len() < files.len() { - ctx.reply("One or more of the attached files are not images!") - .await?; - } - // }}} - - for (i, (file, bytes)) in downloaded.into_iter().enumerate() { - let bytes = bytes?; - - let start = Instant::now(); + for (i, (attachment, bytes)) in files.into_iter().enumerate() { // {{{ Preapare image let mut image = timed!("decode image", { image::load_from_memory(&bytes)? }); let mut grayscale_image = timed!("grayscale image", { @@ -109,7 +87,14 @@ pub async fn magic( // edit_reply!(ctx, handle, "Image {}: reading score", i + 1).await?; let score = timed!("read_score", { - analyzer.read_score(ctx.data(), Some(chart.note_count), &grayscale_image, kind)? + analyzer + .read_score(ctx.data(), Some(chart.note_count), &grayscale_image, kind) + .map_err(|err| { + format!( + "Could not read score for chart {} [{:?}]: {err}", + song.title, chart.difficulty + ) + })? }); // {{{ Build play @@ -117,7 +102,7 @@ pub async fn magic( Score::resolve_distibution_ambiguities(score, note_distribution, chart.note_count); let play = CreatePlay::new(score) - .with_attachment(file) + .with_attachment(C::attachment_id(attachment)) .with_fars(maybe_fars) .with_max_recall(max_recall) .save(&ctx.data(), &user, &chart)?; @@ -136,41 +121,113 @@ pub async fn magic( if let Err(err) = result { analyzer - .send_discord_error(ctx, &image, &file.filename, err) + .send_discord_error(ctx, &image, C::filename(&attachment), err) .await?; } - - let took = start.elapsed(); - - edit_reply!( - ctx, - handle, - "Processed {}/{} scores. Last score took {took:?} to process.", - i + 1, - files.len() - ) - .await?; } - handle.delete(ctx).await?; - if embeds.len() > 0 { - ctx.channel_id() - .send_files(ctx.http(), attachments, CreateMessage::new().embeds(embeds)) + ctx.send_files(attachments, CreateMessage::new().embeds(embeds)) .await?; } Ok(()) } // }}} +// {{{ Tests +#[cfg(test)] +mod magic_tests { + use std::{path::PathBuf, process::Command, str::FromStr}; + + use r2d2_sqlite::SqliteConnectionManager; + + use crate::{ + commands::discord::mock::MockContext, + context::{connect_db, get_shared_context}, + }; + + use super::*; + + macro_rules! with_ctx { + ($test_path:expr, $f:expr) => {{ + let mut data = (*get_shared_context().await).clone(); + let dir = tempfile::tempdir()?; + let path = dir.path().join("db.sqlite"); + println!("path {path:?}"); + data.db = connect_db(SqliteConnectionManager::file(path)); + + Command::new("scripts/import-charts.py") + .env("SHIMMERING_DATA_DIR", dir.path().to_str().unwrap()) + .output() + .unwrap(); + + let mut ctx = MockContext::new(data); + User::create_from_context(&ctx)?; + + let res: Result<(), Error> = $f(&mut ctx).await; + res?; + + ctx.write_to(&PathBuf::from_str($test_path)?)?; + Ok(()) + }}; + } + + #[tokio::test] + async fn no_pics() -> Result<(), Error> { + with_ctx!("test/commands/score/magic/no_pics", async |ctx| { + magic_impl(ctx, vec![]).await?; + Ok(()) + }) + } + + #[tokio::test] + async fn basic_pic() -> Result<(), Error> { + with_ctx!("test/commands/score/magic/single_pic", async |ctx| { + magic_impl( + ctx, + vec![PathBuf::from_str("test/screenshots/alter_ego.jpg")?], + ) + .await?; + Ok(()) + }) + } + + #[tokio::test] + async fn weird_kerning() -> Result<(), Error> { + with_ctx!("test/commands/score/magic/weird_kerning", async |ctx| { + magic_impl( + ctx, + vec![ + PathBuf::from_str("test/screenshots/antithese_74_kerning.jpg")?, + PathBuf::from_str("test/screenshots/genocider_24_kerning.jpg")?, + ], + ) + .await?; + Ok(()) + }) + } +} +// }}} + +/// Identify scores from attached images. +#[poise::command(prefix_command, slash_command)] +pub async fn magic( + mut ctx: Context<'_>, + #[description = "Images containing scores"] files: Vec, +) -> Result<(), Error> { + magic_impl(&mut ctx, files).await?; + + Ok(()) +} +// }}} // {{{ Score delete /// Delete scores, given their IDs. #[poise::command(prefix_command, slash_command)] pub async fn delete( - ctx: Context<'_>, + mut ctx: Context<'_>, #[description = "Id of score to delete"] ids: Vec, ) -> Result<(), Error> { - let user = get_user!(&ctx); + let user = get_user!(&mut ctx); if ids.len() == 0 { ctx.reply("Empty ID list provided").await?; diff --git a/src/commands/stats.rs b/src/commands/stats.rs index a5d1b8a..6143f00 100644 --- a/src/commands/stats.rs +++ b/src/commands/stats.rs @@ -43,7 +43,7 @@ pub async fn stats(_ctx: Context<'_>) -> Result<(), Error> { // }}} // {{{ Render best plays async fn best_plays( - ctx: &Context<'_>, + ctx: &mut Context<'_>, user: &User, scoring_system: ScoringSystem, grid_size: (u32, u32), @@ -403,7 +403,7 @@ async fn best_plays( ImageBuffer::from_raw(width, height, drawer.canvas.buffer.into_vec()).unwrap(), ); - debug_image_log(&image)?; + debug_image_log(&image); if image.height() > 4096 { image = image.resize(4096, 4096, image::imageops::FilterType::Nearest); @@ -426,10 +426,10 @@ async fn best_plays( // {{{ B30 /// Show the 30 best scores #[poise::command(prefix_command, slash_command, user_cooldown = 30)] -pub async fn b30(ctx: Context<'_>, scoring_system: Option) -> Result<(), Error> { - let user = get_user!(&ctx); +pub async fn b30(mut ctx: Context<'_>, scoring_system: Option) -> Result<(), Error> { + let user = get_user!(&mut ctx); best_plays( - &ctx, + &mut ctx, &user, scoring_system.unwrap_or_default(), (5, 6), @@ -440,15 +440,15 @@ pub async fn b30(ctx: Context<'_>, scoring_system: Option) -> Res #[poise::command(prefix_command, slash_command, hide_in_help, global_cooldown = 5)] pub async fn bany( - ctx: Context<'_>, + mut ctx: Context<'_>, scoring_system: Option, width: u32, height: u32, ) -> Result<(), Error> { - let user = get_user!(&ctx); + let user = get_user!(&mut ctx); assert_is_pookie!(ctx, user); best_plays( - &ctx, + &mut ctx, &user, scoring_system.unwrap_or_default(), (width, height), @@ -460,8 +460,8 @@ pub async fn bany( // {{{ Meta /// Show stats about the bot itself. #[poise::command(prefix_command, slash_command, user_cooldown = 1)] -async fn meta(ctx: Context<'_>) -> Result<(), Error> { - let user = get_user!(&ctx); +async fn meta(mut ctx: Context<'_>) -> Result<(), Error> { + let user = get_user!(&mut ctx); let conn = ctx.data().db.get()?; let song_count: usize = conn .prepare_cached("SELECT count() as count FROM songs")? diff --git a/src/commands/utils/mod.rs b/src/commands/utils/mod.rs index beda3af..9f17135 100644 --- a/src/commands/utils/mod.rs +++ b/src/commands/utils/mod.rs @@ -35,7 +35,7 @@ macro_rules! reply_errors { match $value { Ok(v) => v, Err(err) => { - $ctx.reply(format!("{err}")).await?; + crate::commands::discord::MessageContext::reply($ctx, &format!("{err}")).await?; return Ok(()); } } diff --git a/src/context.rs b/src/context.rs index bb047b7..e9b9f14 100644 --- a/src/context.rs +++ b/src/context.rs @@ -19,6 +19,7 @@ pub type Context<'a> = poise::Context<'a, UserContext, Error>; pub type DbConnection = r2d2::Pool; // Custom user data passed to all command functions +#[derive(Clone)] pub struct UserContext { pub db: DbConnection, pub song_cache: SongCache, @@ -32,36 +33,34 @@ pub struct UserContext { pub kazesawa_bold_measurements: CharMeasurements, } +pub fn connect_db(manager: SqliteConnectionManager) -> DbConnection { + timed!("create_sqlite_pool", { + Pool::new(manager.with_init(|conn| { + static MIGRATIONS_DIR: Dir = include_dir!("$CARGO_MANIFEST_DIR/migrations"); + static MIGRATIONS: LazyLock = LazyLock::new(|| { + Migrations::from_directory(&MIGRATIONS_DIR).expect("Could not load migrations") + }); + + MIGRATIONS + .to_latest(conn) + .expect("Could not run migrations"); + + Ok(()) + })) + .expect("Could not open sqlite database.") + }) +} + impl UserContext { #[inline] pub async fn new() -> Result { timed!("create_context", { fs::create_dir_all(get_data_dir())?; - // {{{ Connect to database - let db = timed!("create_sqlite_pool", { - Pool::new( - SqliteConnectionManager::file(&format!( - "{}/db.sqlite", - get_data_dir().to_str().unwrap() - )) - .with_init(|conn| { - static MIGRATIONS_DIR: Dir = include_dir!("$CARGO_MANIFEST_DIR/migrations"); - static MIGRATIONS: LazyLock = LazyLock::new(|| { - Migrations::from_directory(&MIGRATIONS_DIR) - .expect("Could not load migrations") - }); - - MIGRATIONS - .to_latest(conn) - .expect("Could not run migrations"); - - Ok(()) - }), - ) - .expect("Could not open sqlite database.") - }); - // }}} + let db = connect_db(SqliteConnectionManager::file(&format!( + "{}/db.sqlite", + get_data_dir().to_str().unwrap() + ))); let mut song_cache = timed!("make_song_cache", { SongCache::new(&db)? }); let jacket_cache = timed!("make_jacket_cache", { JacketCache::new(&mut song_cache)? }); @@ -93,3 +92,9 @@ impl UserContext { }) } } + +pub async fn get_shared_context() -> &'static UserContext { + static CELL: tokio::sync::OnceCell = tokio::sync::OnceCell::const_new(); + CELL.get_or_init(async || UserContext::new().await.unwrap()) + .await +} diff --git a/src/logs.rs b/src/logs.rs index 745a6ba..565fb4a 100644 --- a/src/logs.rs +++ b/src/logs.rs @@ -10,7 +10,7 @@ use std::{env, ops::Deref, path::PathBuf, sync::OnceLock, time::Instant}; use image::{DynamicImage, EncodableLayout, ImageBuffer, PixelWithColorType}; -use crate::{assets::get_path, context::Error}; +use crate::assets::get_path; #[inline] fn should_save_debug_images() -> bool { @@ -31,30 +31,30 @@ fn get_startup_time() -> Instant { } #[inline] -pub fn debug_image_log(image: &DynamicImage) -> Result<(), Error> { +pub fn debug_image_log(image: &DynamicImage) { if should_save_debug_images() { - image.save(get_log_dir().join(format!( - "{:0>15}.png", - get_startup_time().elapsed().as_nanos() - )))?; + image + .save(get_log_dir().join(format!( + "{:0>15}.png", + get_startup_time().elapsed().as_nanos() + ))) + .unwrap(); } - - Ok(()) } #[inline] -pub fn debug_image_buffer_log(image: &ImageBuffer) -> Result<(), Error> +pub fn debug_image_buffer_log(image: &ImageBuffer) where P: PixelWithColorType, [P::Subpixel]: EncodableLayout, C: Deref, { if should_save_debug_images() { - image.save(get_log_dir().join(format!( - "./logs/{:0>15}.png", - get_startup_time().elapsed().as_nanos() - )))?; + image + .save(get_log_dir().join(format!( + "{:0>15}.png", + get_startup_time().elapsed().as_nanos() + ))) + .unwrap(); } - - Ok(()) } diff --git a/src/main.rs b/src/main.rs index 28ed47e..9d77321 100644 --- a/src/main.rs +++ b/src/main.rs @@ -6,6 +6,7 @@ #![feature(try_blocks)] #![feature(thread_local)] #![feature(generic_arg_infer)] +#![feature(lazy_cell_consume)] mod arcaea; mod assets; diff --git a/src/recognition/hyperglass.rs b/src/recognition/hyperglass.rs index 6ac0ee3..85edfe4 100644 --- a/src/recognition/hyperglass.rs +++ b/src/recognition/hyperglass.rs @@ -164,7 +164,7 @@ impl ComponentsWithBounds { binarisation_threshold, ThresholdType::Binary, ); - debug_image_buffer_log(&image)?; + debug_image_buffer_log(&image); let background = Luma([u8::MAX]); let components = connected_components(&image, Connectivity::Eight, background); @@ -223,6 +223,7 @@ impl ComponentsWithBounds { } // }}} // {{{ Char measurements +#[derive(Clone)] pub struct CharMeasurements { chars: Vec<(char, ComponentVec)>, @@ -258,7 +259,7 @@ impl CharMeasurements { .ok_or_else(|| "Failed to turn buffer into canvas")?; let image = DynamicImage::ImageRgb8(buffer); - debug_image_log(&image)?; + debug_image_log(&image); let components = ComponentsWithBounds::from_image(&image, 100)?; diff --git a/src/recognition/recognize.rs b/src/recognition/recognize.rs index 83020ef..9109f31 100644 --- a/src/recognition/recognize.rs +++ b/src/recognition/recognize.rs @@ -10,7 +10,8 @@ use crate::arcaea::chart::{Chart, Difficulty, Song, DIFFICULTY_MENU_PIXEL_COLORS use crate::arcaea::jacket::IMAGE_VEC_DIM; use crate::arcaea::score::Score; use crate::bitmap::{Color, Rect}; -use crate::context::{Context, Error, UserContext}; +use crate::commands::discord::MessageContext; +use crate::context::{Error, UserContext}; use crate::levenshtein::edit_distance; use crate::logs::debug_image_log; use crate::recognition::fuzzy_song_name::guess_chart_name; @@ -61,7 +62,7 @@ impl ImageAnalyzer { self.last_rect = Some((ui_rect, rect)); let result = self.crop(image, rect); - debug_image_log(&result)?; + debug_image_log(&result); Ok(result) } @@ -80,7 +81,7 @@ impl ImageAnalyzer { let result = self.crop(image, rect); let result = result.resize(size.0, size.1, FilterType::Nearest); - debug_image_log(&result)?; + debug_image_log(&result); Ok(result) } @@ -88,7 +89,7 @@ impl ImageAnalyzer { // {{{ Error handling pub async fn send_discord_error( &mut self, - ctx: Context<'_>, + ctx: &mut impl MessageContext, image: &DynamicImage, filename: &str, err: impl Display, @@ -112,14 +113,12 @@ impl ImageAnalyzer { )); let msg = CreateMessage::default().embed(embed); - ctx.channel_id() - .send_files(ctx.http(), [error_attachement], msg) - .await?; + ctx.send_files([error_attachement], msg).await?; } else { embed = embed.title("An error occurred"); let msg = CreateMessage::default().embed(embed); - ctx.channel_id().send_files(ctx.http(), [], msg).await?; + ctx.send_files([], msg).await?; } Ok(()) @@ -347,7 +346,8 @@ impl ImageAnalyzer { out[i] = ctx .kazesawa_bold_measurements .recognise(&image, "0123456789", Some(30))? - .parse()?; + .parse() + .unwrap_or(100000); // This will get discarded as making no sense } println!("Ditribution {out:?}"); diff --git a/src/recognition/ui.rs b/src/recognition/ui.rs index 8886820..2d279b4 100644 --- a/src/recognition/ui.rs +++ b/src/recognition/ui.rs @@ -60,7 +60,7 @@ impl UIMeasurementRect { pub const UI_RECT_COUNT: usize = 15; // }}} // {{{ Measurement -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct UIMeasurement { dimensions: [u32; 2], datapoints: [u32; UI_RECT_COUNT * 4], @@ -87,7 +87,7 @@ impl UIMeasurement { } // }}} // {{{ Measurements -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct UIMeasurements { pub measurements: Vec, } diff --git a/src/user.rs b/src/user.rs index 111b127..a32297f 100644 --- a/src/user.rs +++ b/src/user.rs @@ -3,7 +3,10 @@ use std::str::FromStr; use poise::serenity_prelude::UserId; use rusqlite::Row; -use crate::context::{Context, Error, UserContext}; +use crate::{ + commands::discord::MessageContext, + context::{Context, Error, UserContext}, +}; #[derive(Debug, Clone)] pub struct User { @@ -22,8 +25,31 @@ impl User { }) } - pub fn from_context(ctx: &Context<'_>) -> Result { - let id = ctx.author().id.get().to_string(); + pub fn create_from_context(ctx: &impl MessageContext) -> Result { + let discord_id = ctx.author_id().to_string(); + let user_id: u32 = ctx + .data() + .db + .get()? + .prepare_cached( + " + INSERT INTO users(discord_id) VALUES (?) + RETURNING id + ", + )? + .query_map([&discord_id], |row| row.get("id"))? + .next() + .ok_or_else(|| "Failed to create user")??; + + Ok(Self { + discord_id, + id: user_id, + is_pookie: false, + }) + } + + pub fn from_context(ctx: &impl MessageContext) -> Result { + let id = ctx.author_id(); let user = ctx .data() .db