1
Fork 0

Set up testing infrastructure

This commit is contained in:
prescientmoon 2024-09-06 17:31:20 +02:00
parent e74ddfd106
commit cba88c5def
Signed by: prescientmoon
SSH key fingerprint: SHA256:WFp/cO76nbarETAoQcQXuV+0h7XJsEsOCI0UsyPIy6U
18 changed files with 494 additions and 176 deletions

1
.gitignore vendored
View file

@ -11,3 +11,4 @@ target
backups
dump.sql
schema.sql
test

107
Cargo.lock generated
View file

@ -328,7 +328,7 @@ dependencies = [
"num-traits",
"serde",
"wasm-bindgen",
"windows-targets 0.52.5",
"windows-targets 0.52.6",
]
[[package]]
@ -1316,7 +1316,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0c2a198fb6b0eada2a8df47933734e6d35d350665a33a3593d7164fa52c75c19"
dependencies = [
"cfg-if",
"windows-targets 0.52.5",
"windows-targets 0.48.5",
]
[[package]]
@ -1655,7 +1655,7 @@ dependencies = [
"libc",
"redox_syscall",
"smallvec",
"windows-targets 0.52.5",
"windows-targets 0.52.6",
]
[[package]]
@ -2307,9 +2307,9 @@ dependencies = [
[[package]]
name = "serde"
version = "1.0.204"
version = "1.0.209"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bc76f558e0cbb2a839d37354c575f1dc3fdc6546b5be373ba43d95f231bf7c12"
checksum = "99fce0ffe7310761ca6bf9faf5115afbc19688edd00171d81b1bb1b116c63e09"
dependencies = [
"serde_derive",
]
@ -2325,9 +2325,9 @@ dependencies = [
[[package]]
name = "serde_derive"
version = "1.0.204"
version = "1.0.209"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e0cd7e117be63d3c3678776753929474f3b04a43a080c744d6b0ae2a8c28e222"
checksum = "a5831b979fd7b5439637af1752d535ff49f4860c0f341d1baeb6faf0f4242170"
dependencies = [
"proc-macro2",
"quote",
@ -2347,9 +2347,9 @@ dependencies = [
[[package]]
name = "serde_spanned"
version = "0.6.6"
version = "0.6.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "79e674e01f999af37c49f70a6ede167a8a60b2503e56c5599532a65baa5969a0"
checksum = "eb5b1b31579f3811bf615c144393417496f152e12ac8b7663bf664f4a815306d"
dependencies = [
"serde",
]
@ -2427,7 +2427,10 @@ dependencies = [
"r2d2_sqlite",
"rusqlite",
"rusqlite_migration",
"serde",
"tempfile",
"tokio",
"toml",
]
[[package]]
@ -2611,14 +2614,15 @@ checksum = "e1fc403891a21bcfb7c37834ba66a547a8f402146eba7265b5a6d88059c9ff2f"
[[package]]
name = "tempfile"
version = "3.10.1"
version = "3.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "85b77fafb263dd9d05cbeac119526425676db3784113aa9295c88498cbf8bff1"
checksum = "04cbcdd0c794ebb0d4cf35e88edd2f7d2c4c3e9a5a6dab322839b321c6a87a64"
dependencies = [
"cfg-if",
"fastrand",
"once_cell",
"rustix",
"windows-sys 0.52.0",
"windows-sys 0.59.0",
]
[[package]]
@ -2786,9 +2790,9 @@ dependencies = [
[[package]]
name = "toml"
version = "0.8.14"
version = "0.8.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6f49eb2ab21d2f26bd6db7bf383edc527a7ebaee412d17af4d40fdccd442f335"
checksum = "a1ed1f98e3fdc28d6d910e6737ae6ab1a93bf1985935a1193e68f93eeb68d24e"
dependencies = [
"serde",
"serde_spanned",
@ -2798,18 +2802,18 @@ dependencies = [
[[package]]
name = "toml_datetime"
version = "0.6.6"
version = "0.6.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4badfd56924ae69bcc9039335b2e017639ce3f9b001c393c1b2d1ef846ce2cbf"
checksum = "0dd7358ecb8fc2f8d014bf86f6f638ce72ba252a2c3a2572f2a795f1d23efb41"
dependencies = [
"serde",
]
[[package]]
name = "toml_edit"
version = "0.22.14"
version = "0.22.20"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f21c7aaf97f1bd9ca9d4f9e73b0a6c74bd5afef56f2bc931943a6e1c37e04e38"
checksum = "583c44c02ad26b0c3f3066fe629275e50627026c51ac2e595cca4c230ce1ce1d"
dependencies = [
"indexmap",
"serde",
@ -3217,7 +3221,7 @@ version = "0.52.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "33ab640c8d7e35bf8ba19b884ba838ceb4fba93a4e8c65a9059d08afcfc683d9"
dependencies = [
"windows-targets 0.52.5",
"windows-targets 0.52.6",
]
[[package]]
@ -3235,7 +3239,16 @@ version = "0.52.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d"
dependencies = [
"windows-targets 0.52.5",
"windows-targets 0.52.6",
]
[[package]]
name = "windows-sys"
version = "0.59.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b"
dependencies = [
"windows-targets 0.52.6",
]
[[package]]
@ -3255,18 +3268,18 @@ dependencies = [
[[package]]
name = "windows-targets"
version = "0.52.5"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6f0713a46559409d202e70e28227288446bf7841d3211583a4b53e3f6d96e7eb"
checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973"
dependencies = [
"windows_aarch64_gnullvm 0.52.5",
"windows_aarch64_msvc 0.52.5",
"windows_i686_gnu 0.52.5",
"windows_aarch64_gnullvm 0.52.6",
"windows_aarch64_msvc 0.52.6",
"windows_i686_gnu 0.52.6",
"windows_i686_gnullvm",
"windows_i686_msvc 0.52.5",
"windows_x86_64_gnu 0.52.5",
"windows_x86_64_gnullvm 0.52.5",
"windows_x86_64_msvc 0.52.5",
"windows_i686_msvc 0.52.6",
"windows_x86_64_gnu 0.52.6",
"windows_x86_64_gnullvm 0.52.6",
"windows_x86_64_msvc 0.52.6",
]
[[package]]
@ -3277,9 +3290,9 @@ checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8"
[[package]]
name = "windows_aarch64_gnullvm"
version = "0.52.5"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7088eed71e8b8dda258ecc8bac5fb1153c5cffaf2578fc8ff5d61e23578d3263"
checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3"
[[package]]
name = "windows_aarch64_msvc"
@ -3289,9 +3302,9 @@ checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc"
[[package]]
name = "windows_aarch64_msvc"
version = "0.52.5"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9985fd1504e250c615ca5f281c3f7a6da76213ebd5ccc9561496568a2752afb6"
checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469"
[[package]]
name = "windows_i686_gnu"
@ -3301,15 +3314,15 @@ checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e"
[[package]]
name = "windows_i686_gnu"
version = "0.52.5"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "88ba073cf16d5372720ec942a8ccbf61626074c6d4dd2e745299726ce8b89670"
checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b"
[[package]]
name = "windows_i686_gnullvm"
version = "0.52.5"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "87f4261229030a858f36b459e748ae97545d6f1ec60e5e0d6a3d32e0dc232ee9"
checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66"
[[package]]
name = "windows_i686_msvc"
@ -3319,9 +3332,9 @@ checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406"
[[package]]
name = "windows_i686_msvc"
version = "0.52.5"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "db3c2bf3d13d5b658be73463284eaf12830ac9a26a90c717b7f771dfe97487bf"
checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66"
[[package]]
name = "windows_x86_64_gnu"
@ -3331,9 +3344,9 @@ checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e"
[[package]]
name = "windows_x86_64_gnu"
version = "0.52.5"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4e4246f76bdeff09eb48875a0fd3e2af6aada79d409d33011886d3e1581517d9"
checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78"
[[package]]
name = "windows_x86_64_gnullvm"
@ -3343,9 +3356,9 @@ checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc"
[[package]]
name = "windows_x86_64_gnullvm"
version = "0.52.5"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "852298e482cd67c356ddd9570386e2862b5673c85bd5f88df9ab6802b334c596"
checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d"
[[package]]
name = "windows_x86_64_msvc"
@ -3355,15 +3368,15 @@ checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538"
[[package]]
name = "windows_x86_64_msvc"
version = "0.52.5"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bec47e5bfd1bff0eeaf6d8b485cc1074891a197ab4225d504cb7a1ab88b02bf0"
checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec"
[[package]]
name = "winnow"
version = "0.6.13"
version = "0.6.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "59b5e5f6c299a3c7890b876a2a587f3115162487e704907d9b6cd29473052ba1"
checksum = "68a9bda4691f099d435ad181000724da8e5899daa10713c2d432552b9ccd3a6f"
dependencies = [
"memchr",
]

View file

@ -18,6 +18,9 @@ r2d2_sqlite = "0.25.0"
r2d2 = "0.8.10"
rusqlite_migration = {version="1.3.0-alpha-without-tokio.1", features = ["from-directory"]}
include_dir = "0.7.4"
serde = "1.0.209"
toml = "0.8.19"
tempfile = "3.12.0"
[profile.dev.package."*"]
opt-level = 3

View file

@ -73,6 +73,7 @@ impl ImageVec {
// }}}
}
#[derive(Clone)]
pub struct JacketCache {
jackets: Vec<(u32, ImageVec)>,
}

View file

@ -1,4 +1,5 @@
use std::array;
use std::num::NonZeroU64;
use chrono::NaiveDateTime;
use chrono::Utc;
@ -6,9 +7,7 @@ use num::traits::Euclid;
use num::CheckedDiv;
use num::Rational32;
use num::Zero;
use poise::serenity_prelude::{
Attachment, AttachmentId, CreateAttachment, CreateEmbed, CreateEmbedAuthor, Timestamp,
};
use poise::serenity_prelude::{CreateAttachment, CreateEmbed, CreateEmbedAuthor, Timestamp};
use rusqlite::Row;
use crate::arcaea::chart::{Chart, Song};
@ -21,7 +20,7 @@ use super::score::{Score, ScoringSystem};
// {{{ Create play
#[derive(Debug, Clone)]
pub struct CreatePlay {
discord_attachment_id: Option<AttachmentId>,
discord_attachment_id: Option<NonZeroU64>,
// Scoring details
score: Score,
@ -41,8 +40,8 @@ impl CreatePlay {
}
#[inline]
pub fn with_attachment(mut self, attachment: &Attachment) -> Self {
self.discord_attachment_id = Some(attachment.id);
pub fn with_attachment(mut self, attachment_id: NonZeroU64) -> Self {
self.discord_attachment_id = Some(attachment_id);
self
}

View file

@ -110,12 +110,12 @@ async fn info(
/// Show the best score on a given chart
#[poise::command(prefix_command, slash_command, user_cooldown = 1)]
async fn best(
ctx: Context<'_>,
mut ctx: Context<'_>,
#[rest]
#[description = "Name of chart to show (difficulty at the end)"]
name: String,
) -> Result<(), Error> {
let user = get_user!(&ctx);
let user = get_user!(&mut ctx);
let (song, chart) = guess_song_and_chart(&ctx.data(), &name)?;
let play = ctx
@ -164,13 +164,13 @@ async fn best(
/// Show the best score on a given chart
#[poise::command(prefix_command, slash_command, user_cooldown = 10)]
async fn plot(
ctx: Context<'_>,
mut ctx: Context<'_>,
scoring_system: Option<ScoringSystem>,
#[rest]
#[description = "Name of chart to show (difficulty at the end)"]
name: String,
) -> Result<(), Error> {
let user = get_user!(&ctx);
let user = get_user!(&mut ctx);
let scoring_system = scoring_system.unwrap_or_default();
let (song, chart) = guess_song_and_chart(&ctx.data(), &name)?;

210
src/commands/discord.rs Normal file
View file

@ -0,0 +1,210 @@
use std::num::NonZeroU64;
use poise::serenity_prelude::{futures::future::join_all, CreateAttachment, CreateMessage};
use crate::{
context::{Error, UserContext},
timed,
};
// {{{ Trait
pub trait MessageContext {
/// Get the user context held by the message
fn data(&self) -> &UserContext;
fn author_id(&self) -> u64;
/// Reply to the current message
async fn reply(&mut self, text: &str) -> Result<(), Error>;
/// Deliver a message containing references to files.
async fn send_files(
&mut self,
attachments: impl IntoIterator<Item = CreateAttachment>,
message: CreateMessage,
) -> Result<(), Error>;
// {{{ Input attachments
type Attachment;
fn is_image(attachment: &Self::Attachment) -> bool;
fn filename(attachment: &Self::Attachment) -> &str;
fn attachment_id(attachment: &Self::Attachment) -> NonZeroU64;
/// Downloads a single file
async fn download(&self, attachment: &Self::Attachment) -> Result<Vec<u8>, Error>;
/// Downloads every image
async fn download_images<'a>(
&self,
attachments: &'a Vec<Self::Attachment>,
) -> Result<Vec<(&'a Self::Attachment, Vec<u8>)>, Error> {
let download_tasks = attachments
.iter()
.filter(|file| Self::is_image(file))
.map(|file| async move { (file, self.download(file).await) });
let downloaded = timed!("dowload_files", { join_all(download_tasks).await });
downloaded
.into_iter()
.map(|(file, bytes)| Ok((file, bytes?)))
.collect::<Result<_, Error>>()
}
// }}}
}
// }}}
// {{{ Poise implementation
impl<'a, 'b> MessageContext
for poise::Context<'a, UserContext, Box<dyn std::error::Error + Send + Sync + 'b>>
{
type Attachment = poise::serenity_prelude::Attachment;
fn data(&self) -> &UserContext {
Self::data(*self)
}
fn author_id(&self) -> u64 {
self.author().id.get()
}
async fn reply(&mut self, text: &str) -> Result<(), Error> {
Self::reply(*self, text).await?;
Ok(())
}
async fn send_files(
&mut self,
attachments: impl IntoIterator<Item = CreateAttachment>,
message: CreateMessage,
) -> Result<(), Error> {
self.channel_id()
.send_files(self.http(), attachments, message)
.await?;
Ok(())
}
// {{{ Input attachments
fn attachment_id(attachment: &Self::Attachment) -> NonZeroU64 {
NonZeroU64::new(attachment.id.get()).unwrap()
}
fn filename(attachment: &Self::Attachment) -> &str {
&attachment.filename
}
fn is_image(attachment: &Self::Attachment) -> bool {
attachment.dimensions().is_some()
}
async fn download(&self, attachment: &Self::Attachment) -> Result<Vec<u8>, Error> {
let res = poise::serenity_prelude::Attachment::download(attachment).await?;
Ok(res)
}
// }}}
}
// }}}
// {{{ Testing context
pub mod mock {
use std::{env, fs, path::PathBuf};
use super::*;
pub struct MockContext {
pub user_id: u64,
pub data: UserContext,
pub messages: Vec<(CreateMessage, Vec<CreateAttachment>)>,
}
impl MockContext {
pub fn new(data: UserContext) -> Self {
Self {
data,
user_id: 666,
messages: vec![],
}
}
pub fn write_to(&self, path: &PathBuf) -> Result<(), Error> {
if env::var("SHIMMERING_TEST_REGEN").unwrap_or_default() == "1" {
fs::remove_dir_all(path)?;
}
fs::create_dir_all(path)?;
for (i, (message, attachments)) in self.messages.iter().enumerate() {
let dir = path.join(format!("{i}"));
fs::create_dir_all(&dir)?;
let message_file = dir.join("message.toml");
if message_file.exists() {
assert_eq!(
toml::to_string_pretty(message)?,
fs::read_to_string(message_file)?
);
} else {
fs::write(&message_file, toml::to_string_pretty(message)?)?;
}
for attachment in attachments {
let path = dir.join(&attachment.filename);
if path.exists() {
assert_eq!(&attachment.data, &fs::read(path)?);
} else {
fs::write(&path, &attachment.data)?;
}
}
}
Ok(())
}
}
impl MessageContext for MockContext {
fn author_id(&self) -> u64 {
self.user_id
}
fn data(&self) -> &UserContext {
&self.data
}
async fn reply(&mut self, text: &str) -> Result<(), Error> {
self.messages
.push((CreateMessage::new().content(text), Vec::new()));
Ok(())
}
async fn send_files(
&mut self,
attachments: impl IntoIterator<Item = CreateAttachment>,
message: CreateMessage,
) -> Result<(), Error> {
self.messages
.push((message, attachments.into_iter().collect()));
Ok(())
}
// {{{ Input attachments
type Attachment = PathBuf;
fn filename(attachment: &Self::Attachment) -> &str {
attachment.file_name().unwrap().to_str().unwrap()
}
// This is a dumb implementation, but it works for testing...
fn is_image(attachment: &Self::Attachment) -> bool {
let ext = attachment.extension().unwrap();
ext == "png" || ext == "jpg" || ext == "webp"
}
fn attachment_id(_attachment: &Self::Attachment) -> NonZeroU64 {
NonZeroU64::new(666).unwrap()
}
async fn download(&self, attachment: &Self::Attachment) -> Result<Vec<u8>, Error> {
let res = tokio::fs::read(attachment).await?;
Ok(res)
}
// }}}
}
}
// }}}

View file

@ -1,6 +1,7 @@
use crate::context::{Context, Error};
pub mod chart;
pub mod discord;
pub mod score;
pub mod stats;
pub mod utils;

View file

@ -1,15 +1,14 @@
use std::time::Instant;
use crate::arcaea::play::{CreatePlay, Play};
use crate::arcaea::score::Score;
use crate::context::{Context, Error};
use crate::recognition::recognize::{ImageAnalyzer, ScoreKind};
use crate::user::{discord_id_to_discord_user, User};
use crate::{edit_reply, get_user, timed};
use crate::{get_user, timed};
use image::DynamicImage;
use poise::serenity_prelude::futures::future::join_all;
use poise::serenity_prelude as serenity;
use poise::serenity_prelude::CreateMessage;
use poise::{serenity_prelude as serenity, CreateReply};
use super::discord::MessageContext;
// {{{ Score
/// Score management
@ -24,13 +23,13 @@ pub async fn score(_ctx: Context<'_>) -> Result<(), Error> {
}
// }}}
// {{{ Score magic
/// Identify scores from attached images.
#[poise::command(prefix_command, slash_command)]
pub async fn magic(
ctx: Context<'_>,
#[description = "Images containing scores"] files: Vec<serenity::Attachment>,
// {{{ Implementation
async fn magic_impl<C: MessageContext>(
ctx: &mut C,
files: Vec<C::Attachment>,
) -> Result<(), Error> {
let user = get_user!(&ctx);
let user = get_user!(ctx);
let files = ctx.download_images(&files).await?;
if files.len() == 0 {
ctx.reply("No images found attached to message").await?;
@ -39,30 +38,9 @@ pub async fn magic(
let mut embeds = Vec::with_capacity(files.len());
let mut attachments = Vec::with_capacity(files.len());
let handle = ctx
.reply(format!("Processed 0/{} scores", files.len()))
.await?;
let mut analyzer = ImageAnalyzer::default();
// {{{ Download files
let download_tasks = files
.iter()
.filter(|file| file.dimensions().is_some())
.map(|file| async move { (file, file.download().await) });
let downloaded = timed!("dowload_files", { join_all(download_tasks).await });
if downloaded.len() < files.len() {
ctx.reply("One or more of the attached files are not images!")
.await?;
}
// }}}
for (i, (file, bytes)) in downloaded.into_iter().enumerate() {
let bytes = bytes?;
let start = Instant::now();
for (i, (attachment, bytes)) in files.into_iter().enumerate() {
// {{{ Preapare image
let mut image = timed!("decode image", { image::load_from_memory(&bytes)? });
let mut grayscale_image = timed!("grayscale image", {
@ -109,7 +87,14 @@ pub async fn magic(
// edit_reply!(ctx, handle, "Image {}: reading score", i + 1).await?;
let score = timed!("read_score", {
analyzer.read_score(ctx.data(), Some(chart.note_count), &grayscale_image, kind)?
analyzer
.read_score(ctx.data(), Some(chart.note_count), &grayscale_image, kind)
.map_err(|err| {
format!(
"Could not read score for chart {} [{:?}]: {err}",
song.title, chart.difficulty
)
})?
});
// {{{ Build play
@ -117,7 +102,7 @@ pub async fn magic(
Score::resolve_distibution_ambiguities(score, note_distribution, chart.note_count);
let play = CreatePlay::new(score)
.with_attachment(file)
.with_attachment(C::attachment_id(attachment))
.with_fars(maybe_fars)
.with_max_recall(max_recall)
.save(&ctx.data(), &user, &chart)?;
@ -136,41 +121,113 @@ pub async fn magic(
if let Err(err) = result {
analyzer
.send_discord_error(ctx, &image, &file.filename, err)
.send_discord_error(ctx, &image, C::filename(&attachment), err)
.await?;
}
let took = start.elapsed();
edit_reply!(
ctx,
handle,
"Processed {}/{} scores. Last score took {took:?} to process.",
i + 1,
files.len()
)
.await?;
}
handle.delete(ctx).await?;
if embeds.len() > 0 {
ctx.channel_id()
.send_files(ctx.http(), attachments, CreateMessage::new().embeds(embeds))
ctx.send_files(attachments, CreateMessage::new().embeds(embeds))
.await?;
}
Ok(())
}
// }}}
// {{{ Tests
#[cfg(test)]
mod magic_tests {
use std::{path::PathBuf, process::Command, str::FromStr};
use r2d2_sqlite::SqliteConnectionManager;
use crate::{
commands::discord::mock::MockContext,
context::{connect_db, get_shared_context},
};
use super::*;
macro_rules! with_ctx {
($test_path:expr, $f:expr) => {{
let mut data = (*get_shared_context().await).clone();
let dir = tempfile::tempdir()?;
let path = dir.path().join("db.sqlite");
println!("path {path:?}");
data.db = connect_db(SqliteConnectionManager::file(path));
Command::new("scripts/import-charts.py")
.env("SHIMMERING_DATA_DIR", dir.path().to_str().unwrap())
.output()
.unwrap();
let mut ctx = MockContext::new(data);
User::create_from_context(&ctx)?;
let res: Result<(), Error> = $f(&mut ctx).await;
res?;
ctx.write_to(&PathBuf::from_str($test_path)?)?;
Ok(())
}};
}
#[tokio::test]
async fn no_pics() -> Result<(), Error> {
with_ctx!("test/commands/score/magic/no_pics", async |ctx| {
magic_impl(ctx, vec![]).await?;
Ok(())
})
}
#[tokio::test]
async fn basic_pic() -> Result<(), Error> {
with_ctx!("test/commands/score/magic/single_pic", async |ctx| {
magic_impl(
ctx,
vec![PathBuf::from_str("test/screenshots/alter_ego.jpg")?],
)
.await?;
Ok(())
})
}
#[tokio::test]
async fn weird_kerning() -> Result<(), Error> {
with_ctx!("test/commands/score/magic/weird_kerning", async |ctx| {
magic_impl(
ctx,
vec![
PathBuf::from_str("test/screenshots/antithese_74_kerning.jpg")?,
PathBuf::from_str("test/screenshots/genocider_24_kerning.jpg")?,
],
)
.await?;
Ok(())
})
}
}
// }}}
/// Identify scores from attached images.
#[poise::command(prefix_command, slash_command)]
pub async fn magic(
mut ctx: Context<'_>,
#[description = "Images containing scores"] files: Vec<serenity::Attachment>,
) -> Result<(), Error> {
magic_impl(&mut ctx, files).await?;
Ok(())
}
// }}}
// {{{ Score delete
/// Delete scores, given their IDs.
#[poise::command(prefix_command, slash_command)]
pub async fn delete(
ctx: Context<'_>,
mut ctx: Context<'_>,
#[description = "Id of score to delete"] ids: Vec<u32>,
) -> Result<(), Error> {
let user = get_user!(&ctx);
let user = get_user!(&mut ctx);
if ids.len() == 0 {
ctx.reply("Empty ID list provided").await?;

View file

@ -43,7 +43,7 @@ pub async fn stats(_ctx: Context<'_>) -> Result<(), Error> {
// }}}
// {{{ Render best plays
async fn best_plays(
ctx: &Context<'_>,
ctx: &mut Context<'_>,
user: &User,
scoring_system: ScoringSystem,
grid_size: (u32, u32),
@ -403,7 +403,7 @@ async fn best_plays(
ImageBuffer::from_raw(width, height, drawer.canvas.buffer.into_vec()).unwrap(),
);
debug_image_log(&image)?;
debug_image_log(&image);
if image.height() > 4096 {
image = image.resize(4096, 4096, image::imageops::FilterType::Nearest);
@ -426,10 +426,10 @@ async fn best_plays(
// {{{ B30
/// Show the 30 best scores
#[poise::command(prefix_command, slash_command, user_cooldown = 30)]
pub async fn b30(ctx: Context<'_>, scoring_system: Option<ScoringSystem>) -> Result<(), Error> {
let user = get_user!(&ctx);
pub async fn b30(mut ctx: Context<'_>, scoring_system: Option<ScoringSystem>) -> Result<(), Error> {
let user = get_user!(&mut ctx);
best_plays(
&ctx,
&mut ctx,
&user,
scoring_system.unwrap_or_default(),
(5, 6),
@ -440,15 +440,15 @@ pub async fn b30(ctx: Context<'_>, scoring_system: Option<ScoringSystem>) -> Res
#[poise::command(prefix_command, slash_command, hide_in_help, global_cooldown = 5)]
pub async fn bany(
ctx: Context<'_>,
mut ctx: Context<'_>,
scoring_system: Option<ScoringSystem>,
width: u32,
height: u32,
) -> Result<(), Error> {
let user = get_user!(&ctx);
let user = get_user!(&mut ctx);
assert_is_pookie!(ctx, user);
best_plays(
&ctx,
&mut ctx,
&user,
scoring_system.unwrap_or_default(),
(width, height),
@ -460,8 +460,8 @@ pub async fn bany(
// {{{ Meta
/// Show stats about the bot itself.
#[poise::command(prefix_command, slash_command, user_cooldown = 1)]
async fn meta(ctx: Context<'_>) -> Result<(), Error> {
let user = get_user!(&ctx);
async fn meta(mut ctx: Context<'_>) -> Result<(), Error> {
let user = get_user!(&mut ctx);
let conn = ctx.data().db.get()?;
let song_count: usize = conn
.prepare_cached("SELECT count() as count FROM songs")?

View file

@ -35,7 +35,7 @@ macro_rules! reply_errors {
match $value {
Ok(v) => v,
Err(err) => {
$ctx.reply(format!("{err}")).await?;
crate::commands::discord::MessageContext::reply($ctx, &format!("{err}")).await?;
return Ok(());
}
}

View file

@ -19,6 +19,7 @@ pub type Context<'a> = poise::Context<'a, UserContext, Error>;
pub type DbConnection = r2d2::Pool<SqliteConnectionManager>;
// Custom user data passed to all command functions
#[derive(Clone)]
pub struct UserContext {
pub db: DbConnection,
pub song_cache: SongCache,
@ -32,36 +33,34 @@ pub struct UserContext {
pub kazesawa_bold_measurements: CharMeasurements,
}
pub fn connect_db(manager: SqliteConnectionManager) -> DbConnection {
timed!("create_sqlite_pool", {
Pool::new(manager.with_init(|conn| {
static MIGRATIONS_DIR: Dir = include_dir!("$CARGO_MANIFEST_DIR/migrations");
static MIGRATIONS: LazyLock<Migrations> = LazyLock::new(|| {
Migrations::from_directory(&MIGRATIONS_DIR).expect("Could not load migrations")
});
MIGRATIONS
.to_latest(conn)
.expect("Could not run migrations");
Ok(())
}))
.expect("Could not open sqlite database.")
})
}
impl UserContext {
#[inline]
pub async fn new() -> Result<Self, Error> {
timed!("create_context", {
fs::create_dir_all(get_data_dir())?;
// {{{ Connect to database
let db = timed!("create_sqlite_pool", {
Pool::new(
SqliteConnectionManager::file(&format!(
"{}/db.sqlite",
get_data_dir().to_str().unwrap()
))
.with_init(|conn| {
static MIGRATIONS_DIR: Dir = include_dir!("$CARGO_MANIFEST_DIR/migrations");
static MIGRATIONS: LazyLock<Migrations> = LazyLock::new(|| {
Migrations::from_directory(&MIGRATIONS_DIR)
.expect("Could not load migrations")
});
MIGRATIONS
.to_latest(conn)
.expect("Could not run migrations");
Ok(())
}),
)
.expect("Could not open sqlite database.")
});
// }}}
let db = connect_db(SqliteConnectionManager::file(&format!(
"{}/db.sqlite",
get_data_dir().to_str().unwrap()
)));
let mut song_cache = timed!("make_song_cache", { SongCache::new(&db)? });
let jacket_cache = timed!("make_jacket_cache", { JacketCache::new(&mut song_cache)? });
@ -93,3 +92,9 @@ impl UserContext {
})
}
}
pub async fn get_shared_context() -> &'static UserContext {
static CELL: tokio::sync::OnceCell<UserContext> = tokio::sync::OnceCell::const_new();
CELL.get_or_init(async || UserContext::new().await.unwrap())
.await
}

View file

@ -10,7 +10,7 @@ use std::{env, ops::Deref, path::PathBuf, sync::OnceLock, time::Instant};
use image::{DynamicImage, EncodableLayout, ImageBuffer, PixelWithColorType};
use crate::{assets::get_path, context::Error};
use crate::assets::get_path;
#[inline]
fn should_save_debug_images() -> bool {
@ -31,30 +31,30 @@ fn get_startup_time() -> Instant {
}
#[inline]
pub fn debug_image_log(image: &DynamicImage) -> Result<(), Error> {
pub fn debug_image_log(image: &DynamicImage) {
if should_save_debug_images() {
image.save(get_log_dir().join(format!(
"{:0>15}.png",
get_startup_time().elapsed().as_nanos()
)))?;
image
.save(get_log_dir().join(format!(
"{:0>15}.png",
get_startup_time().elapsed().as_nanos()
)))
.unwrap();
}
Ok(())
}
#[inline]
pub fn debug_image_buffer_log<P, C>(image: &ImageBuffer<P, C>) -> Result<(), Error>
pub fn debug_image_buffer_log<P, C>(image: &ImageBuffer<P, C>)
where
P: PixelWithColorType,
[P::Subpixel]: EncodableLayout,
C: Deref<Target = [P::Subpixel]>,
{
if should_save_debug_images() {
image.save(get_log_dir().join(format!(
"./logs/{:0>15}.png",
get_startup_time().elapsed().as_nanos()
)))?;
image
.save(get_log_dir().join(format!(
"{:0>15}.png",
get_startup_time().elapsed().as_nanos()
)))
.unwrap();
}
Ok(())
}

View file

@ -6,6 +6,7 @@
#![feature(try_blocks)]
#![feature(thread_local)]
#![feature(generic_arg_infer)]
#![feature(lazy_cell_consume)]
mod arcaea;
mod assets;

View file

@ -164,7 +164,7 @@ impl ComponentsWithBounds {
binarisation_threshold,
ThresholdType::Binary,
);
debug_image_buffer_log(&image)?;
debug_image_buffer_log(&image);
let background = Luma([u8::MAX]);
let components = connected_components(&image, Connectivity::Eight, background);
@ -223,6 +223,7 @@ impl ComponentsWithBounds {
}
// }}}
// {{{ Char measurements
#[derive(Clone)]
pub struct CharMeasurements {
chars: Vec<(char, ComponentVec)>,
@ -258,7 +259,7 @@ impl CharMeasurements {
.ok_or_else(|| "Failed to turn buffer into canvas")?;
let image = DynamicImage::ImageRgb8(buffer);
debug_image_log(&image)?;
debug_image_log(&image);
let components = ComponentsWithBounds::from_image(&image, 100)?;

View file

@ -10,7 +10,8 @@ use crate::arcaea::chart::{Chart, Difficulty, Song, DIFFICULTY_MENU_PIXEL_COLORS
use crate::arcaea::jacket::IMAGE_VEC_DIM;
use crate::arcaea::score::Score;
use crate::bitmap::{Color, Rect};
use crate::context::{Context, Error, UserContext};
use crate::commands::discord::MessageContext;
use crate::context::{Error, UserContext};
use crate::levenshtein::edit_distance;
use crate::logs::debug_image_log;
use crate::recognition::fuzzy_song_name::guess_chart_name;
@ -61,7 +62,7 @@ impl ImageAnalyzer {
self.last_rect = Some((ui_rect, rect));
let result = self.crop(image, rect);
debug_image_log(&result)?;
debug_image_log(&result);
Ok(result)
}
@ -80,7 +81,7 @@ impl ImageAnalyzer {
let result = self.crop(image, rect);
let result = result.resize(size.0, size.1, FilterType::Nearest);
debug_image_log(&result)?;
debug_image_log(&result);
Ok(result)
}
@ -88,7 +89,7 @@ impl ImageAnalyzer {
// {{{ Error handling
pub async fn send_discord_error(
&mut self,
ctx: Context<'_>,
ctx: &mut impl MessageContext,
image: &DynamicImage,
filename: &str,
err: impl Display,
@ -112,14 +113,12 @@ impl ImageAnalyzer {
));
let msg = CreateMessage::default().embed(embed);
ctx.channel_id()
.send_files(ctx.http(), [error_attachement], msg)
.await?;
ctx.send_files([error_attachement], msg).await?;
} else {
embed = embed.title("An error occurred");
let msg = CreateMessage::default().embed(embed);
ctx.channel_id().send_files(ctx.http(), [], msg).await?;
ctx.send_files([], msg).await?;
}
Ok(())
@ -347,7 +346,8 @@ impl ImageAnalyzer {
out[i] = ctx
.kazesawa_bold_measurements
.recognise(&image, "0123456789", Some(30))?
.parse()?;
.parse()
.unwrap_or(100000); // This will get discarded as making no sense
}
println!("Ditribution {out:?}");

View file

@ -60,7 +60,7 @@ impl UIMeasurementRect {
pub const UI_RECT_COUNT: usize = 15;
// }}}
// {{{ Measurement
#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct UIMeasurement {
dimensions: [u32; 2],
datapoints: [u32; UI_RECT_COUNT * 4],
@ -87,7 +87,7 @@ impl UIMeasurement {
}
// }}}
// {{{ Measurements
#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct UIMeasurements {
pub measurements: Vec<UIMeasurement>,
}

View file

@ -3,7 +3,10 @@ use std::str::FromStr;
use poise::serenity_prelude::UserId;
use rusqlite::Row;
use crate::context::{Context, Error, UserContext};
use crate::{
commands::discord::MessageContext,
context::{Context, Error, UserContext},
};
#[derive(Debug, Clone)]
pub struct User {
@ -22,8 +25,31 @@ impl User {
})
}
pub fn from_context(ctx: &Context<'_>) -> Result<Self, Error> {
let id = ctx.author().id.get().to_string();
pub fn create_from_context(ctx: &impl MessageContext) -> Result<Self, Error> {
let discord_id = ctx.author_id().to_string();
let user_id: u32 = ctx
.data()
.db
.get()?
.prepare_cached(
"
INSERT INTO users(discord_id) VALUES (?)
RETURNING id
",
)?
.query_map([&discord_id], |row| row.get("id"))?
.next()
.ok_or_else(|| "Failed to create user")??;
Ok(Self {
discord_id,
id: user_id,
is_pookie: false,
})
}
pub fn from_context(ctx: &impl MessageContext) -> Result<Self, Error> {
let id = ctx.author_id();
let user = ctx
.data()
.db