1
Fork 0

Better error handling

This commit is contained in:
prescientmoon 2024-09-23 19:46:53 +02:00
parent d7ac4094b2
commit 74f554e058
Signed by: prescientmoon
SSH key fingerprint: SHA256:WFp/cO76nbarETAoQcQXuV+0h7XJsEsOCI0UsyPIy6U
25 changed files with 978 additions and 681 deletions

464
Cargo.lock generated

File diff suppressed because it is too large Load diff

View file

@ -2,6 +2,19 @@
name = "shimmeringmoon" name = "shimmeringmoon"
version = "0.1.0" version = "0.1.0"
edition = "2021" edition = "2021"
autobins = false
[lib]
name = "shimmeringmoon"
path = "src/lib.rs"
[[bin]]
name = "shimmeringmoon-discord-bot"
path = "src/bin/discord-bot/main.rs"
[[bin]]
name = "shimmeringmoon-cli"
path = "src/bin/cli/main.rs"
[dependencies] [dependencies]
chrono = "0.4.38" chrono = "0.4.38"
@ -25,6 +38,8 @@ clap = { version = "4.5.17", features = ["derive"] }
postcard = { version = "1.0.10", features = ["use-std"], default-features = false } postcard = { version = "1.0.10", features = ["use-std"], default-features = false }
serde_with = "3.9.0" serde_with = "3.9.0"
anyhow = "1.0.87" anyhow = "1.0.87"
sha2 = "0.10.8"
base16ct = { version = "0.2.0", features = ["alloc"] }
[profile.dev.package."*"] # [profile.dev.package."*"]
opt-level = 3 # opt-level = 3

View file

@ -8,11 +8,11 @@
"rust-analyzer-src": "rust-analyzer-src" "rust-analyzer-src": "rust-analyzer-src"
}, },
"locked": { "locked": {
"lastModified": 1717827974, "lastModified": 1727073227,
"narHash": "sha256-ixopuTeTouxqTxfMuzs6IaRttbT8JqRW5C9Q/57WxQw=", "narHash": "sha256-1kmkEQmFfGVuPBasqSZrNThqyMDV1SzTalQdRZxtDRs=",
"owner": "nix-community", "owner": "nix-community",
"repo": "fenix", "repo": "fenix",
"rev": "ab655c627777ab5f9964652fe23bbb1dfbd687a8", "rev": "88cc292eb3c689073c784d6aecc0edbd47e12881",
"type": "github" "type": "github"
}, },
"original": { "original": {
@ -41,16 +41,16 @@
}, },
"nixpkgs": { "nixpkgs": {
"locked": { "locked": {
"lastModified": 1718000748, "lastModified": 1726755586,
"narHash": "sha256-zliqz7ovpxYdKIK+GlWJZxifXsT9A1CHNQhLxV0G1Hc=", "narHash": "sha256-PmUr/2GQGvFTIJ6/Tvsins7Q43KTMvMFhvG6oaYK+Wk=",
"owner": "nixos", "owner": "nixos",
"repo": "nixpkgs", "repo": "nixpkgs",
"rev": "869cab745a802b693b45d193b460c9184da671f3", "rev": "c04d5652cfa9742b1d519688f65d1bbccea9eb7e",
"type": "github" "type": "github"
}, },
"original": { "original": {
"owner": "nixos", "owner": "nixos",
"ref": "release-24.05", "ref": "nixos-unstable",
"repo": "nixpkgs", "repo": "nixpkgs",
"type": "github" "type": "github"
} }
@ -59,17 +59,18 @@
"inputs": { "inputs": {
"fenix": "fenix", "fenix": "fenix",
"flake-utils": "flake-utils", "flake-utils": "flake-utils",
"nixpkgs": "nixpkgs" "nixpkgs": "nixpkgs",
"rust-overlay": "rust-overlay"
} }
}, },
"rust-analyzer-src": { "rust-analyzer-src": {
"flake": false, "flake": false,
"locked": { "locked": {
"lastModified": 1717583671, "lastModified": 1726443025,
"narHash": "sha256-+lRAmz92CNUxorqWusgJbL9VE1eKCnQQojglRemzwkw=", "narHash": "sha256-nCmG4NJpwI0IoIlYlwtDwVA49yuspA2E6OhfCOmiArQ=",
"owner": "rust-lang", "owner": "rust-lang",
"repo": "rust-analyzer", "repo": "rust-analyzer",
"rev": "48bbdd6a74f3176987d5c809894ac33957000d19", "rev": "94b526fc86eaa0e90fb4d54a5ba6313aa1e9b269",
"type": "github" "type": "github"
}, },
"original": { "original": {
@ -79,6 +80,26 @@
"type": "github" "type": "github"
} }
}, },
"rust-overlay": {
"inputs": {
"nixpkgs": [
"nixpkgs"
]
},
"locked": {
"lastModified": 1727058553,
"narHash": "sha256-tY/UU3Qk5gP/J0uUM4DZ6wo4arNLGAVqLKBotILykfQ=",
"owner": "oxalica",
"repo": "rust-overlay",
"rev": "edc5b0f896170f07bd39ad59d6186fcc7859bbb2",
"type": "github"
},
"original": {
"owner": "oxalica",
"repo": "rust-overlay",
"type": "github"
}
},
"systems": { "systems": {
"locked": { "locked": {
"lastModified": 1681028828, "lastModified": 1681028828,

View file

@ -1,17 +1,22 @@
{ {
inputs = { inputs = {
nixpkgs.url = "github:nixos/nixpkgs/release-24.05"; nixpkgs.url = "github:nixos/nixpkgs/nixos-unstable";
flake-utils.url = "github:numtide/flake-utils"; flake-utils.url = "github:numtide/flake-utils";
fenix.url = "github:nix-community/fenix"; fenix.url = "github:nix-community/fenix";
fenix.inputs.nixpkgs.follows = "nixpkgs"; fenix.inputs.nixpkgs.follows = "nixpkgs";
rust-overlay.url = "github:oxalica/rust-overlay";
rust-overlay.inputs.nixpkgs.follows = "nixpkgs";
}; };
outputs = outputs =
{ ... }@inputs: inputs:
inputs.flake-utils.lib.eachSystem (with inputs.flake-utils.lib.system; [ x86_64-linux ]) ( inputs.flake-utils.lib.eachSystem (with inputs.flake-utils.lib.system; [ x86_64-linux ]) (
system: system:
let let
pkgs = inputs.nixpkgs.legacyPackages.${system}.extend inputs.fenix.overlays.default; pkgs = inputs.nixpkgs.legacyPackages.${system}.extend (import inputs.rust-overlay);
# toolchain = pkgs.rust-bin.selectLatestNightlyWith (toolchain: toolchain.default);
# toolchain = pkgs.rust-bin.stable.latest.default;
toolchain = inputs.fenix.packages.${system}.complete.toolchain;
inherit (pkgs) lib; inherit (pkgs) lib;
in in
{ {
@ -29,41 +34,35 @@
}; };
}; };
}; };
devShell = pkgs.mkShell rec { devShell = pkgs.mkShell {
packages = with pkgs; [ nativeBuildInputs = with pkgs; [
(fenix.complete.withComponents [ toolchain
"cargo" # ruff
"clippy" # imagemagick
"rust-src"
"rustc"
"rustfmt"
])
rust-analyzer-nightly
ruff
imagemagick
fontconfig
freetype
clang
llvmPackages.clang
pkg-config pkg-config
# clang
# llvmPackages.clang
];
buildInputs = with pkgs; [
toolchain
freetype
fontconfig
leptonica leptonica
tesseract tesseract
openssl # openssl
sqlite sqlite
]; ];
LD_LIBRARY_PATH = lib.makeLibraryPath packages; # LD_LIBRARY_PATH = lib.makeLibraryPath buildInputs;
# compilation of -sys packages requires manually setting LIBCLANG_PATH # compilation of -sys packages requires manually setting LIBCLANG_PATH
LIBCLANG_PATH = "${pkgs.llvmPackages.libclang.lib}/lib"; # LIBCLANG_PATH = "${pkgs.llvmPackages.libclang.lib}/lib";
}; };
} }
); );
# {{{ Caching and whatnot # {{{ Caching and whatnot
# TODO: persist trusted substituters file
nixConfig = { nixConfig = {
extra-substituters = [ "https://nix-community.cachix.org" ]; extra-substituters = [ "https://nix-community.cachix.org" ];

View file

@ -3,7 +3,7 @@ use anyhow::anyhow;
use image::RgbaImage; use image::RgbaImage;
use crate::assets::get_data_dir; use crate::assets::get_data_dir;
use crate::context::{Error, UserContext}; use crate::context::{ErrorKind, TagError, TaggedError, UserContext};
use crate::user::User; use crate::user::User;
use super::chart::{Difficulty, Level}; use super::chart::{Difficulty, Level};
@ -119,9 +119,8 @@ impl GoalStats {
ctx: &UserContext, ctx: &UserContext,
user: &User, user: &User,
scoring_system: ScoringSystem, scoring_system: ScoringSystem,
) -> Result<Self, Error> { ) -> Result<Self, TaggedError> {
let plays = get_best_plays(ctx, user.id, scoring_system, 0, usize::MAX, None)? let plays = get_best_plays(ctx, user.id, scoring_system, 0, usize::MAX, None)?;
.map_err(|s| anyhow!("{s}"))?;
let conn = ctx.db.get()?; let conn = ctx.db.get()?;
// {{{ PM count // {{{ PM count
@ -141,14 +140,14 @@ impl GoalStats {
let peak_ptt = conn let peak_ptt = conn
.prepare_cached( .prepare_cached(
" "
SELECT s.creation_ptt SELECT s.creation_ptt
FROM plays p FROM plays p
JOIN scores s ON s.play_id = p.id JOIN scores s ON s.play_id = p.id
WHERE user_id = ? WHERE user_id = ?
AND scoring_system = ? AND scoring_system = ?
ORDER BY s.creation_ptt DESC ORDER BY s.creation_ptt DESC
LIMIT 1 LIMIT 1
", ",
)? )?
.query_row( .query_row(
( (
@ -157,7 +156,7 @@ impl GoalStats {
), ),
|row| row.get(0), |row| row.get(0),
) )
.map_err(|_| anyhow!("No ptt history data found"))?; .map_err(|_| anyhow!("No ptt history data found").tag(ErrorKind::User))?;
// }}} // }}}
// {{{ Peak PM relay // {{{ Peak PM relay
let peak_pm_relay = { let peak_pm_relay = {
@ -309,6 +308,7 @@ impl Default for AchievementTowers {
]); ]);
// }}} // }}}
// {{{ PTT tower // {{{ PTT tower
#[allow(clippy::zero_prefixed_literal)]
let ptt_tower = AchievementTower::new(vec![ let ptt_tower = AchievementTower::new(vec![
Achievement::new(PTT(0800)), Achievement::new(PTT(0800)),
Achievement::new(PTT(0900)), Achievement::new(PTT(0900)),

View file

@ -2,6 +2,7 @@
use std::array; use std::array;
use std::num::NonZeroU64; use std::num::NonZeroU64;
use anyhow::anyhow;
use anyhow::Context; use anyhow::Context;
use chrono::NaiveDateTime; use chrono::NaiveDateTime;
use chrono::Utc; use chrono::Utc;
@ -13,6 +14,9 @@ use poise::serenity_prelude::{CreateAttachment, CreateEmbed, CreateEmbedAuthor,
use rusqlite::Row; use rusqlite::Row;
use crate::arcaea::chart::{Chart, Song}; use crate::arcaea::chart::{Chart, Song};
use crate::context::ErrorKind;
use crate::context::TagError;
use crate::context::TaggedError;
use crate::context::{Error, UserContext}; use crate::context::{Error, UserContext};
use crate::user::User; use crate::user::User;
@ -61,7 +65,7 @@ impl CreatePlay {
} }
// {{{ Save // {{{ Save
pub fn save(self, ctx: &UserContext, user: &User, chart: &Chart) -> Result<Play, Error> { pub fn save(self, ctx: &UserContext, user: &User, chart: &Chart) -> Result<Play, TaggedError> {
let conn = ctx.db.get()?; let conn = ctx.db.get()?;
let attachment_id = self.discord_attachment_id.map(|i| i.get() as i64); let attachment_id = self.discord_attachment_id.map(|i| i.get() as i64);
@ -104,9 +108,7 @@ impl CreatePlay {
for system in ScoringSystem::SCORING_SYSTEMS { for system in ScoringSystem::SCORING_SYSTEMS {
let i = system.to_index(); let i = system.to_index();
let plays = get_best_plays(ctx, user.id, system, 30, 30, None)?.ok(); let creation_ptt = try_compute_ptt(ctx, user.id, system, None)?;
let creation_ptt: Option<_> = try { rating_as_fixed(compute_b30_ptt(system, &plays?)) };
conn.prepare_cached( conn.prepare_cached(
" "
@ -321,10 +323,9 @@ impl Play {
self.score(ScoringSystem::Standard).0, self.score(ScoringSystem::Standard).0,
index index
); );
let icon_attachement = match chart.cached_jacket.as_ref() { let icon_attachement = chart
Some(jacket) => Some(CreateAttachment::bytes(jacket.raw, &attachement_name)), .cached_jacket
None => None, .map(|jacket| CreateAttachment::bytes(jacket.raw, &attachement_name));
};
let mut embed = CreateEmbed::default() let mut embed = CreateEmbed::default()
.title(format!( .title(format!(
@ -378,7 +379,7 @@ impl Play {
if let Some(max_recall) = self.max_recall { if let Some(max_recall) = self.max_recall {
format!("{}", max_recall) format!("{}", max_recall)
} else { } else {
format!("-") "-".to_string()
}, },
true, true,
) )
@ -409,14 +410,14 @@ impl Play {
// {{{ General functions // {{{ General functions
pub type PlayCollection<'a> = Vec<(Play, &'a Song, &'a Chart)>; pub type PlayCollection<'a> = Vec<(Play, &'a Song, &'a Chart)>;
pub fn get_best_plays<'a>( pub fn get_best_plays(
ctx: &'a UserContext, ctx: &UserContext,
user_id: u32, user_id: u32,
scoring_system: ScoringSystem, scoring_system: ScoringSystem,
min_amount: usize, min_amount: usize,
max_amount: usize, max_amount: usize,
before: Option<NaiveDateTime>, before: Option<NaiveDateTime>,
) -> Result<Result<PlayCollection<'a>, String>, Error> { ) -> Result<PlayCollection<'_>, TaggedError> {
let conn = ctx.db.get()?; let conn = ctx.db.get()?;
// {{{ DB data fetching // {{{ DB data fetching
let mut plays = conn let mut plays = conn
@ -453,10 +454,11 @@ pub fn get_best_plays<'a>(
// }}} // }}}
if plays.len() < min_amount { if plays.len() < min_amount {
return Ok(Err(format!( return Err(anyhow!(
"Not enough plays found ({} out of a minimum of {min_amount})", "Not enough plays found ({} out of a minimum of {min_amount})",
plays.len() plays.len()
))); )
.tag(crate::context::ErrorKind::User));
} }
// {{{ B30 computation // {{{ B30 computation
@ -464,7 +466,27 @@ pub fn get_best_plays<'a>(
plays.truncate(max_amount); plays.truncate(max_amount);
// }}} // }}}
Ok(Ok(plays)) Ok(plays)
}
/// Compute the current ptt of a given user.
///
/// This is similar to directly calling [get_best_plays] and then passing the
/// result into [compute_b30_ptt], except any user errors (i.e.: not enough
/// plays available) get turned into [None] values.
pub fn try_compute_ptt(
ctx: &UserContext,
user_id: u32,
system: ScoringSystem,
before: Option<NaiveDateTime>,
) -> Result<Option<i32>, Error> {
match get_best_plays(ctx, user_id, system, 30, 30, before) {
Err(err) => match err.kind {
ErrorKind::User => Ok(None),
ErrorKind::Internal => Err(err.error),
},
Ok(plays) => Ok(Some(rating_as_fixed(compute_b30_ptt(system, &plays)))),
}
} }
#[inline] #[inline]
@ -478,7 +500,7 @@ pub fn compute_b30_ptt(scoring_system: ScoringSystem, plays: &PlayCollection<'_>
} }
// }}} // }}}
// {{{ Maintenance functions // {{{ Maintenance functions
pub async fn generate_missing_scores(ctx: &UserContext) -> Result<(), Error> { pub async fn generate_missing_scores(ctx: &UserContext) -> Result<(), TaggedError> {
let conn = ctx.db.get()?; let conn = ctx.db.get()?;
let mut query = conn.prepare_cached( let mut query = conn.prepare_cached(
" "
@ -504,10 +526,8 @@ pub async fn generate_missing_scores(ctx: &UserContext) -> Result<(), Error> {
let play = play?; let play = play?;
for system in ScoringSystem::SCORING_SYSTEMS { for system in ScoringSystem::SCORING_SYSTEMS {
let i = system.to_index(); let i = system.to_index();
let plays = let creation_ptt = try_compute_ptt(ctx, play.user_id, system, Some(play.created_at))?;
get_best_plays(&ctx, play.user_id, system, 30, 30, Some(play.created_at))?.ok();
let creation_ptt: Option<_> = try { rating_as_fixed(compute_b30_ptt(system, &plays?)) };
let raw_score = play.scores.0[i].0; let raw_score = play.scores.0[i].0;
conn.prepare_cached( conn.prepare_cached(

View file

@ -1,7 +1,3 @@
pub mod analyse;
pub mod context;
pub mod prepare_jackets;
#[derive(clap::Parser)] #[derive(clap::Parser)]
#[command(author, version, about, long_about = None)] #[command(author, version, about, long_about = None)]
pub struct Cli { pub struct Cli {
@ -11,8 +7,6 @@ pub struct Cli {
#[derive(clap::Subcommand)] #[derive(clap::Subcommand)]
pub enum Command { pub enum Command {
/// Start the discord bot
Discord {},
PrepareJackets {}, PrepareJackets {},
Analyse(analyse::Args), Analyse(crate::commands::analyse::Args),
} }

View file

@ -1,9 +1,9 @@
// {{{ Imports // {{{ Imports
use std::path::PathBuf; use std::path::PathBuf;
use crate::cli::context::CliContext; use crate::context::CliContext;
use crate::commands::score::magic_impl; use shimmeringmoon::commands::score::magic_impl;
use crate::context::{Error, UserContext}; use shimmeringmoon::context::{Error, UserContext};
// }}} // }}}
#[derive(clap::Args)] #[derive(clap::Args)]

View file

@ -0,0 +1,2 @@
pub mod analyse;
pub mod prepare_jackets;

View file

@ -5,11 +5,11 @@ use std::io::{stdout, Write};
use anyhow::{anyhow, bail, Context}; use anyhow::{anyhow, bail, Context};
use image::imageops::FilterType; use image::imageops::FilterType;
use crate::arcaea::chart::{Difficulty, SongCache}; use shimmeringmoon::arcaea::chart::{Difficulty, SongCache};
use crate::arcaea::jacket::{ImageVec, BITMAP_IMAGE_SIZE}; use shimmeringmoon::arcaea::jacket::{ImageVec, BITMAP_IMAGE_SIZE};
use crate::assets::{get_asset_dir, get_data_dir}; use shimmeringmoon::assets::{get_asset_dir, get_data_dir};
use crate::context::{connect_db, Error}; use shimmeringmoon::context::{connect_db, Error};
use crate::recognition::fuzzy_song_name::guess_chart_name; use shimmeringmoon::recognition::fuzzy_song_name::guess_chart_name;
// }}} // }}}
/// Hacky function which clears the current line of the standard output. /// Hacky function which clears the current line of the standard output.

View file

@ -5,9 +5,10 @@ use std::str::FromStr;
use poise::serenity_prelude::{CreateAttachment, CreateMessage}; use poise::serenity_prelude::{CreateAttachment, CreateMessage};
use crate::assets::get_var; extern crate shimmeringmoon;
use crate::context::Error; use shimmeringmoon::assets::get_var;
use crate::{commands::discord::MessageContext, context::UserContext}; use shimmeringmoon::context::Error;
use shimmeringmoon::{commands::discord::MessageContext, context::UserContext};
// }}} // }}}
/// Similar in scope to [crate::commands::discord::mock::MockContext], /// Similar in scope to [crate::commands::discord::mock::MockContext],

22
src/bin/cli/main.rs Normal file
View file

@ -0,0 +1,22 @@
use clap::Parser;
use command::{Cli, Command};
use shimmeringmoon::context::{Error, UserContext};
mod command;
mod commands;
mod context;
#[tokio::main]
async fn main() -> Result<(), Error> {
let cli = Cli::parse();
match cli.command {
Command::PrepareJackets {} => {
commands::prepare_jackets::run()?;
}
Command::Analyse(args) => {
commands::analyse::run(args).await?;
}
}
Ok(())
}

View file

@ -0,0 +1,88 @@
use poise::serenity_prelude::{self as serenity};
extern crate shimmeringmoon;
use shimmeringmoon::arcaea::play::generate_missing_scores;
use shimmeringmoon::context::{Error, UserContext};
use shimmeringmoon::{commands, timed};
use std::{env::var, sync::Arc, time::Duration};
// {{{ Error handler
async fn on_error(error: poise::FrameworkError<'_, UserContext, Error>) {
match error {
error => {
if let Err(e) = poise::builtins::on_error(error).await {
println!("Error while handling error: {}", e)
}
}
}
}
// }}}
#[tokio::main]
async fn main() {
// {{{ Poise options
let options = poise::FrameworkOptions {
commands: vec![
commands::help(),
commands::score::score(),
commands::stats::stats(),
commands::chart::chart(),
],
prefix_options: poise::PrefixFrameworkOptions {
stripped_dynamic_prefix: Some(|_ctx, message, _user_ctx| {
Box::pin(async {
if message.author.bot || Into::<u64>::into(message.author.id) == 1 {
Ok(None)
} else if message.content.starts_with("!") {
Ok(Some(message.content.split_at(1)))
} else if message.guild_id.is_none() {
if message.content.trim().len() == 0 {
Ok(Some(("", "score magic")))
} else {
Ok(Some(("", &message.content[..])))
}
} else {
Ok(None)
}
})
}),
edit_tracker: Some(Arc::new(poise::EditTracker::for_timespan(
Duration::from_secs(3600),
))),
..Default::default()
},
on_error: |error| Box::pin(on_error(error)),
..Default::default()
};
// }}}
// {{{ Start poise
let framework = poise::Framework::builder()
.setup(move |ctx, _ready, framework| {
Box::pin(async move {
println!("Logged in as {}", _ready.user.name);
poise::builtins::register_globally(ctx, &framework.options().commands).await?;
let ctx = UserContext::new().await?;
if var("SHIMMERING_REGEN_SCORES").unwrap_or_default() == "1" {
timed!("generate_missing_scores", {
generate_missing_scores(&ctx).await?;
});
}
Ok(ctx)
})
})
.options(options)
.build();
let token =
var("SHIMMERING_DISCORD_TOKEN").expect("Missing `SHIMMERING_DISCORD_TOKEN` env var");
let intents =
serenity::GatewayIntents::non_privileged() | serenity::GatewayIntents::MESSAGE_CONTENT;
let client = serenity::ClientBuilder::new(token, intents)
.framework(framework)
.await;
client.unwrap().start().await.unwrap()
// }}}
}

View file

@ -1,11 +1,11 @@
// {{{ Imports
use anyhow::anyhow; use anyhow::anyhow;
use poise::serenity_prelude::{CreateAttachment, CreateEmbed, CreateMessage}; // {{{ Imports
use poise::serenity_prelude::{CreateAttachment, CreateEmbed};
use crate::arcaea::{chart::Side, play::Play}; use crate::arcaea::{chart::Side, play::Play};
use crate::context::{Context, Error}; use crate::context::{Context, Error, ErrorKind, TagError, TaggedError};
use crate::get_user;
use crate::recognition::fuzzy_song_name::guess_song_and_chart; use crate::recognition::fuzzy_song_name::guess_song_and_chart;
use crate::user::User;
use std::io::Cursor; use std::io::Cursor;
use chrono::DateTime; use chrono::DateTime;
@ -20,7 +20,7 @@ use poise::CreateReply;
use crate::arcaea::score::{Score, ScoringSystem}; use crate::arcaea::score::{Score, ScoringSystem};
use super::discord::MessageContext; use super::discord::{CreateReplyExtra, MessageContext};
// }}} // }}}
// {{{ Top command // {{{ Top command
@ -37,14 +37,13 @@ pub async fn chart(_ctx: Context<'_>) -> Result<(), Error> {
// }}} // }}}
// {{{ Info // {{{ Info
// {{{ Implementation // {{{ Implementation
async fn info_impl(ctx: &mut impl MessageContext, name: &str) -> Result<(), Error> { async fn info_impl(ctx: &mut impl MessageContext, name: &str) -> Result<(), TaggedError> {
let (song, chart) = guess_song_and_chart(&ctx.data(), name)?; let (song, chart) = guess_song_and_chart(ctx.data(), name)?;
let attachement_name = "chart.png"; let attachement_name = "chart.png";
let icon_attachement = match chart.cached_jacket.as_ref() { let icon_attachement = chart
Some(jacket) => Some(CreateAttachment::bytes(jacket.raw, attachement_name)), .cached_jacket
None => None, .map(|jacket| CreateAttachment::bytes(jacket.raw, attachement_name));
};
let play_count: usize = ctx let play_count: usize = ctx
.data() .data()
@ -57,7 +56,8 @@ async fn info_impl(ctx: &mut impl MessageContext, name: &str) -> Result<(), Erro
WHERE chart_id=? WHERE chart_id=?
", ",
)? )?
.query_row([chart.id], |row| row.get(0))?; .query_row([chart.id], |row| row.get(0))
.unwrap_or(0);
let mut embed = CreateEmbed::default() let mut embed = CreateEmbed::default()
.title(format!( .title(format!(
@ -87,8 +87,13 @@ async fn info_impl(ctx: &mut impl MessageContext, name: &str) -> Result<(), Erro
embed = embed.thumbnail(format!("attachment://{}", &attachement_name)); embed = embed.thumbnail(format!("attachment://{}", &attachement_name));
} }
ctx.send_files(icon_attachement, CreateMessage::new().embed(embed)) ctx.send(
.await?; CreateReply::default()
.reply(true)
.embed(embed)
.attachments(icon_attachement),
)
.await?;
Ok(()) Ok(())
} }
@ -138,24 +143,19 @@ async fn info(
#[description = "Name of chart to show (difficulty at the end)"] #[description = "Name of chart to show (difficulty at the end)"]
name: String, name: String,
) -> Result<(), Error> { ) -> Result<(), Error> {
info_impl(&mut ctx, &name).await?; let res = info_impl(&mut ctx, &name).await;
ctx.handle_error(res).await?;
Ok(()) Ok(())
} }
// }}} // }}}
// }}} // }}}
// {{{ Best score // {{{ Best score
/// Show the best score on a given chart // {{{ Implementation
#[poise::command(prefix_command, slash_command, user_cooldown = 1)] async fn best_impl<C: MessageContext>(ctx: &mut C, name: &str) -> Result<Play, TaggedError> {
async fn best( let user = User::from_context(ctx)?;
mut ctx: Context<'_>,
#[rest]
#[description = "Name of chart to show (difficulty at the end)"]
name: String,
) -> Result<(), Error> {
let user = get_user!(&mut ctx);
let (song, chart) = guess_song_and_chart(&ctx.data(), &name)?; let (song, chart) = guess_song_and_chart(ctx.data(), name)?;
let play = ctx let play = ctx
.data() .data()
.db .db
@ -181,6 +181,7 @@ async fn best(
song.title, song.title,
chart.difficulty chart.difficulty
) )
.tag(ErrorKind::User)
})?; })?;
let (embed, attachment) = play.to_embed( let (embed, attachment) = play.to_embed(
@ -192,27 +193,91 @@ async fn best(
Some(&ctx.fetch_user(&user.discord_id).await?), Some(&ctx.fetch_user(&user.discord_id).await?),
)?; )?;
ctx.channel_id() ctx.send(
.send_files(ctx.http(), attachment, CreateMessage::new().embed(embed)) CreateReply::default()
.await?; .reply(true)
.embed(embed)
.attachments(attachment),
)
.await?;
Ok(()) Ok(play)
} }
// }}} // }}}
// {{{ Score plot // {{{ Tests
// {{{ Tests
#[cfg(test)]
mod best_tests {
use std::path::PathBuf;
use crate::{
commands::{discord::mock::MockContext, score::magic_impl},
with_test_ctx,
};
use super::*;
#[tokio::test]
async fn no_scores() -> Result<(), Error> {
with_test_ctx!("test/commands/chart/best/specify_difficulty", async |ctx| {
best_impl(ctx, "Pentiment").await?;
Ok(())
})
}
#[tokio::test]
async fn pick_correct_score() -> Result<(), Error> {
with_test_ctx!(
"test/commands/chart/best/last_byd",
async |ctx: &mut MockContext| {
magic_impl(
ctx,
&[
PathBuf::from_str("test/screenshots/fracture_ray_ex.jpg")?,
// Make sure we aren't considering higher scores from other stuff
PathBuf::from_str("test/screenshots/antithese_74_kerning.jpg")?,
PathBuf::from_str("test/screenshots/fracture_ray_missed_ex.jpg")?,
],
)
.await?;
let play = best_impl(ctx, "Fracture ray").await?;
assert_eq!(play.score(ScoringSystem::Standard).0, 9_805_651);
Ok(())
}
)
}
}
// }}}
// }}}
// {{{ Discord wrapper
/// Show the best score on a given chart /// Show the best score on a given chart
#[poise::command(prefix_command, slash_command, user_cooldown = 10)] #[poise::command(prefix_command, slash_command, user_cooldown = 1)]
async fn plot( async fn best(
mut ctx: Context<'_>, mut ctx: Context<'_>,
scoring_system: Option<ScoringSystem>,
#[rest] #[rest]
#[description = "Name of chart to show (difficulty at the end)"] #[description = "Name of chart to show (difficulty at the end)"]
name: String, name: String,
) -> Result<(), Error> { ) -> Result<(), Error> {
let user = get_user!(&mut ctx); let res = best_impl(&mut ctx, &name).await;
ctx.handle_error(res).await?;
Ok(())
}
// }}}
// }}}
// {{{ Score plot
// {{{ Implementation
async fn plot_impl<C: MessageContext>(
ctx: &mut C,
scoring_system: Option<ScoringSystem>,
name: String,
) -> Result<(), TaggedError> {
let user = User::from_context(ctx)?;
let scoring_system = scoring_system.unwrap_or_default(); let scoring_system = scoring_system.unwrap_or_default();
let (song, chart) = guess_song_and_chart(&ctx.data(), &name)?; let (song, chart) = guess_song_and_chart(ctx.data(), &name)?;
// SAFETY: we limit the amount of plotted plays to 1000. // SAFETY: we limit the amount of plotted plays to 1000.
let plays = ctx let plays = ctx
@ -236,13 +301,11 @@ async fn plot(
.query_map((user.id, chart.id), |row| Play::from_sql(chart, row))? .query_map((user.id, chart.id), |row| Play::from_sql(chart, row))?
.collect::<Result<Vec<_>, _>>()?; .collect::<Result<Vec<_>, _>>()?;
if plays.len() == 0 { if plays.is_empty() {
ctx.reply(format!( return Err(
"No plays found on {} [{:?}]", anyhow!("No plays found on {} [{:?}]", song.title, chart.difficulty)
song.title, chart.difficulty .tag(ErrorKind::User),
)) );
.await?;
return Ok(());
} }
let min_time = plays.iter().map(|p| p.created_at).min().unwrap(); let min_time = plays.iter().map(|p| p.created_at).min().unwrap();
@ -255,7 +318,7 @@ async fn plot(
.0 as i64; .0 as i64;
if min_score > 9_900_000 { if min_score > 9_900_000 {
min_score = 9_800_000; min_score = 9_900_000;
} else if min_score > 9_800_000 { } else if min_score > 9_800_000 {
min_score = 9_800_000; min_score = 9_800_000;
} else if min_score > 9_500_000 { } else if min_score > 9_500_000 {
@ -331,9 +394,28 @@ async fn plot(
let mut cursor = Cursor::new(&mut buffer); let mut cursor = Cursor::new(&mut buffer);
image.write_to(&mut cursor, image::ImageFormat::Png)?; image.write_to(&mut cursor, image::ImageFormat::Png)?;
let reply = CreateReply::default().attachment(CreateAttachment::bytes(buffer, "plot.png")); let reply = CreateReply::default()
.reply(true)
.attachment(CreateAttachment::bytes(buffer, "plot.png"));
ctx.send(reply).await?; ctx.send(reply).await?;
Ok(()) Ok(())
} }
// }}} // }}}
// {{{ Discord wrapper
/// Show the best score on a given chart
#[poise::command(prefix_command, slash_command, user_cooldown = 10)]
async fn plot(
mut ctx: Context<'_>,
scoring_system: Option<ScoringSystem>,
#[rest]
#[description = "Name of chart to show (difficulty at the end)"]
name: String,
) -> Result<(), Error> {
let res = plot_impl(&mut ctx, scoring_system, name).await;
ctx.handle_error(res).await?;
Ok(())
}
// }}}
// }}}

View file

@ -3,10 +3,11 @@ use std::num::NonZeroU64;
use std::str::FromStr; use std::str::FromStr;
use poise::serenity_prelude::futures::future::join_all; use poise::serenity_prelude::futures::future::join_all;
use poise::serenity_prelude::{CreateAttachment, CreateMessage}; use poise::serenity_prelude::{CreateAttachment, CreateEmbed};
use poise::CreateReply;
use crate::arcaea::play::Play; use crate::arcaea::play::Play;
use crate::context::{Error, UserContext}; use crate::context::{Error, ErrorKind, TaggedError, UserContext};
use crate::timed; use crate::timed;
// }}} // }}}
@ -22,17 +23,8 @@ pub trait MessageContext {
/// Reply to the current message /// Reply to the current message
async fn reply(&mut self, text: &str) -> Result<(), Error>; async fn reply(&mut self, text: &str) -> Result<(), Error>;
/// Deliver a message containing references to files.
async fn send_files(
&mut self,
attachments: impl IntoIterator<Item = CreateAttachment>,
message: CreateMessage,
) -> Result<(), Error>;
/// Deliver a message /// Deliver a message
async fn send(&mut self, message: CreateMessage) -> Result<(), Error> { async fn send(&mut self, message: CreateReply) -> Result<(), Error>;
self.send_files([], message).await
}
// {{{ Input attachments // {{{ Input attachments
type Attachment; type Attachment;
@ -61,6 +53,20 @@ pub trait MessageContext {
.collect::<Result<_, Error>>() .collect::<Result<_, Error>>()
} }
// }}} // }}}
// {{{ Erorr handling
async fn handle_error<V>(&mut self, res: Result<V, TaggedError>) -> Result<Option<V>, Error> {
match res {
Ok(v) => Ok(Some(v)),
Err(e) => match e.kind {
ErrorKind::Internal => Err(e.error),
ErrorKind::User => {
self.reply(&format!("{}", e.error)).await?;
Ok(None)
}
},
}
}
// }}}
} }
// }}} // }}}
// {{{ Poise implementation // {{{ Poise implementation
@ -87,14 +93,8 @@ impl<'a> MessageContext for poise::Context<'a, UserContext, Error> {
Ok(()) Ok(())
} }
async fn send_files( async fn send(&mut self, message: CreateReply) -> Result<(), Error> {
&mut self, poise::send_reply(*self, message).await?;
attachments: impl IntoIterator<Item = CreateAttachment>,
message: CreateMessage,
) -> Result<(), Error> {
self.channel_id()
.send_files(self.http(), attachments, message)
.await?;
Ok(()) Ok(())
} }
@ -122,6 +122,10 @@ impl<'a> MessageContext for poise::Context<'a, UserContext, Error> {
pub mod mock { pub mod mock {
use std::{env, fs, path::PathBuf}; use std::{env, fs, path::PathBuf};
use poise::serenity_prelude::CreateEmbed;
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use super::*; use super::*;
/// A mock context usable for testing. Messages and attachments are /// A mock context usable for testing. Messages and attachments are
@ -130,7 +134,26 @@ pub mod mock {
pub struct MockContext { pub struct MockContext {
pub user_id: u64, pub user_id: u64,
pub data: UserContext, pub data: UserContext,
pub messages: Vec<(CreateMessage, Vec<CreateAttachment>)>, messages: Vec<ReplyEssence>,
}
/// Holds test-relevant data about an attachment.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct AttachmentEssence {
filename: String,
description: Option<String>,
/// SHA-256 hash of the file
hash: String,
}
/// Holds test-relevant data about a reply.
#[derive(Debug, Clone, Serialize)]
struct ReplyEssence {
reply: bool,
ephermal: Option<bool>,
content: Option<String>,
embeds: Vec<CreateEmbed>,
attachments: Vec<AttachmentEssence>,
} }
impl MockContext { impl MockContext {
@ -157,10 +180,8 @@ pub mod mock {
} }
fs::create_dir_all(path)?; fs::create_dir_all(path)?;
for (i, (message, attachments)) in self.messages.iter().enumerate() { for (i, message) in self.messages.iter().enumerate() {
let dir = path.join(format!("{i}")); let message_file = path.join(format!("{i}.toml"));
fs::create_dir_all(&dir)?;
let message_file = dir.join("message.toml");
if message_file.exists() { if message_file.exists() {
assert_eq!( assert_eq!(
@ -170,28 +191,6 @@ pub mod mock {
} else { } else {
fs::write(&message_file, toml::to_string_pretty(message)?)?; fs::write(&message_file, toml::to_string_pretty(message)?)?;
} }
for attachment in attachments {
let path = dir.join(&attachment.filename);
if path.exists() {
if &attachment.data != &fs::read(&path)? {
panic!("Attachment differs from {path:?}");
}
} else {
fs::write(&path, &attachment.data)?;
}
}
// Ensure there's no extra attachments on disk
let file_count = fs::read_dir(dir)?.count();
if file_count != attachments.len() + 1 {
panic!(
"Only {} attachments found instead of {}",
attachments.len(),
file_count - 1
);
}
} }
Ok(()) Ok(())
@ -219,18 +218,33 @@ pub mod mock {
} }
async fn reply(&mut self, text: &str) -> Result<(), Error> { async fn reply(&mut self, text: &str) -> Result<(), Error> {
self.messages self.send(CreateReply::default().content(text).reply(true))
.push((CreateMessage::new().content(text), Vec::new())); .await
Ok(())
} }
async fn send_files( async fn send(&mut self, message: CreateReply) -> Result<(), Error> {
&mut self, self.messages.push(ReplyEssence {
attachments: impl IntoIterator<Item = CreateAttachment>, reply: message.reply,
message: CreateMessage, ephermal: message.ephemeral,
) -> Result<(), Error> { content: message.content,
self.messages embeds: message.embeds,
.push((message, attachments.into_iter().collect())); attachments: message
.attachments
.into_iter()
.map(|attachment| AttachmentEssence {
filename: attachment.filename,
description: attachment.description,
hash: {
let hash = Sha256::digest(&attachment.data);
let string = base16ct::lower::encode_string(&hash);
// We allocate twice, but it's only at the end of tests,
// so it should be fineeeeeeee
format!("sha256_{string}")
},
})
.collect(),
});
Ok(()) Ok(())
} }
@ -265,4 +279,27 @@ pub mod mock {
pub fn play_song_title<'a>(ctx: &'a impl MessageContext, play: &'a Play) -> Result<&'a str, Error> { pub fn play_song_title<'a>(ctx: &'a impl MessageContext, play: &'a Play) -> Result<&'a str, Error> {
Ok(&ctx.data().song_cache.lookup_chart(play.chart_id)?.0.title) Ok(&ctx.data().song_cache.lookup_chart(play.chart_id)?.0.title)
} }
pub trait CreateReplyExtra {
fn attachments(self, attachments: impl IntoIterator<Item = CreateAttachment>) -> Self;
fn embeds(self, embeds: impl IntoIterator<Item = CreateEmbed>) -> Self;
}
impl CreateReplyExtra for CreateReply {
fn attachments(mut self, attachments: impl IntoIterator<Item = CreateAttachment>) -> Self {
for attachment in attachments.into_iter() {
self = self.attachment(attachment);
}
self
}
fn embeds(mut self, embeds: impl IntoIterator<Item = CreateEmbed>) -> Self {
for embed in embeds.into_iter() {
self = self.embed(embed);
}
self
}
}
// }}} // }}}

View file

@ -33,7 +33,7 @@ pub async fn help(
/// Explains the different scoring systems /// Explains the different scoring systems
#[poise::command(prefix_command, slash_command)] #[poise::command(prefix_command, slash_command)]
async fn scoring(ctx: Context<'_>) -> Result<(), Error> { async fn scoring(ctx: Context<'_>) -> Result<(), Error> {
static CONTENT: &'static str = " static CONTENT: &str = "
## 1. Standard scoring (`standard`): ## 1. Standard scoring (`standard`):
This is the base-game Arcaea scoring system we all know and love! Points are awarded for each note, with a `2:1` pure:far ratio. The score is then scaled up such that `10_000_000` is the maximum. Last but not least, the number of max pures is added to the total. This is the base-game Arcaea scoring system we all know and love! Points are awarded for each note, with a `2:1` pure:far ratio. The score is then scaled up such that `10_000_000` is the maximum. Last but not least, the number of max pures is added to the total.
@ -58,7 +58,7 @@ Most commands take an optional parameter specifying what scoring system to use.
/// Explains the different scoring systems using gen-z slang /// Explains the different scoring systems using gen-z slang
#[poise::command(prefix_command, slash_command)] #[poise::command(prefix_command, slash_command)]
async fn scoringz(ctx: Context<'_>) -> Result<(), Error> { async fn scoringz(ctx: Context<'_>) -> Result<(), Error> {
static CONTENT: &'static str = " static CONTENT: &str = "
## 1. Standard scoring (`standard`): ## 1. Standard scoring (`standard`):
Alright, fam, this is the OG Arcaea scoring setup that everyone vibes with! You hit notes, you get points easy clap. The ratio is straight up `2:1` pure:far. The score then gets a glow-up, maxing out at `10 milly`. And hold up, you even get bonus points for those max pures at the end. No cap, this is the classic way to flex your skills. Alright, fam, this is the OG Arcaea scoring setup that everyone vibes with! You hit notes, you get points easy clap. The ratio is straight up `2:1` pure:far. The score then gets a glow-up, maxing out at `10 milly`. And hold up, you even get bonus points for those max pures at the end. No cap, this is the classic way to flex your skills.

View file

@ -1,16 +1,15 @@
// {{{ Imports // {{{ Imports
use crate::arcaea::play::{CreatePlay, Play}; use crate::arcaea::play::{CreatePlay, Play};
use crate::arcaea::score::Score; use crate::arcaea::score::Score;
use crate::context::{Context, Error}; use crate::context::{Context, Error, ErrorKind, TagError, TaggedError};
use crate::recognition::recognize::{ImageAnalyzer, ScoreKind}; use crate::recognition::recognize::{ImageAnalyzer, ScoreKind};
use crate::user::User; use crate::user::User;
use crate::{get_user, timed}; use crate::{get_user_error, timed};
use anyhow::anyhow; use anyhow::anyhow;
use image::DynamicImage; use image::DynamicImage;
use poise::serenity_prelude as serenity; use poise::{serenity_prelude as serenity, CreateReply};
use poise::serenity_prelude::CreateMessage;
use super::discord::MessageContext; use super::discord::{CreateReplyExtra, MessageContext};
// }}} // }}}
// {{{ Score // {{{ Score
@ -30,13 +29,12 @@ pub async fn score(_ctx: Context<'_>) -> Result<(), Error> {
pub async fn magic_impl<C: MessageContext>( pub async fn magic_impl<C: MessageContext>(
ctx: &mut C, ctx: &mut C,
files: &[C::Attachment], files: &[C::Attachment],
) -> Result<Vec<Play>, Error> { ) -> Result<Vec<Play>, TaggedError> {
let user = get_user!(ctx); let user = User::from_context(ctx)?;
let files = ctx.download_images(&files).await?; let files = ctx.download_images(files).await?;
if files.len() == 0 { if files.is_empty() {
ctx.reply("No images found attached to message").await?; return Err(anyhow!("No images found attached to message").tag(ErrorKind::User));
return Ok(vec![]);
} }
let mut embeds = Vec::with_capacity(files.len()); let mut embeds = Vec::with_capacity(files.len());
@ -50,7 +48,7 @@ pub async fn magic_impl<C: MessageContext>(
let mut grayscale_image = DynamicImage::ImageLuma8(image.to_luma8()); let mut grayscale_image = DynamicImage::ImageLuma8(image.to_luma8());
// }}} // }}}
let result: Result<(), Error> = try { let result: Result<(), TaggedError> = try {
// {{{ Detection // {{{ Detection
let kind = timed!("read_score_kind", { let kind = timed!("read_score_kind", {
@ -102,13 +100,13 @@ pub async fn magic_impl<C: MessageContext>(
.with_attachment(C::attachment_id(attachment)) .with_attachment(C::attachment_id(attachment))
.with_fars(maybe_fars) .with_fars(maybe_fars)
.with_max_recall(max_recall) .with_max_recall(max_recall)
.save(&ctx.data(), &user, &chart)?; .save(ctx.data(), &user, chart)?;
// }}} // }}}
// }}} // }}}
// {{{ Deliver embed // {{{ Deliver embed
let (embed, attachment) = timed!("to embed", { let (embed, attachment) = timed!("to embed", {
play.to_embed(ctx.data(), &user, &song, &chart, i, None)? play.to_embed(ctx.data(), &user, song, chart, i, None)?
}); });
plays.push(play); plays.push(play);
@ -118,15 +116,21 @@ pub async fn magic_impl<C: MessageContext>(
}; };
if let Err(err) = result { if let Err(err) = result {
let user_err = get_user_error!(err);
analyzer analyzer
.send_discord_error(ctx, &image, C::filename(&attachment), err) .send_discord_error(ctx, &image, C::filename(attachment), user_err)
.await?; .await?;
} }
} }
if embeds.len() > 0 { if !embeds.is_empty() {
ctx.send_files(attachments, CreateMessage::new().embeds(embeds)) ctx.send(
.await?; CreateReply::default()
.reply(true)
.embeds(embeds)
.attachments(attachments),
)
.await?;
} }
Ok(plays) Ok(plays)
@ -203,7 +207,8 @@ pub async fn magic(
mut ctx: Context<'_>, mut ctx: Context<'_>,
#[description = "Images containing scores"] files: Vec<serenity::Attachment>, #[description = "Images containing scores"] files: Vec<serenity::Attachment>,
) -> Result<(), Error> { ) -> Result<(), Error> {
magic_impl(&mut ctx, &files).await?; let res = magic_impl(&mut ctx, &files).await;
ctx.handle_error(res).await?;
Ok(()) Ok(())
} }
@ -211,10 +216,12 @@ pub async fn magic(
// }}} // }}}
// {{{ Score show // {{{ Score show
// {{{ Implementation // {{{ Implementation
pub async fn show_impl<C: MessageContext>(ctx: &mut C, ids: &[u32]) -> Result<Vec<Play>, Error> { pub async fn show_impl<C: MessageContext>(
if ids.len() == 0 { ctx: &mut C,
ctx.reply("Empty ID list provided").await?; ids: &[u32],
return Ok(vec![]); ) -> Result<Vec<Play>, TaggedError> {
if ids.is_empty() {
return Err(anyhow!("Empty ID list provided").tag(ErrorKind::User));
} }
let mut embeds = Vec::with_capacity(ids.len()); let mut embeds = Vec::with_capacity(ids.len());
@ -250,7 +257,7 @@ pub async fn show_impl<C: MessageContext>(ctx: &mut C, ids: &[u32]) -> Result<Ve
let (song, chart, play, discord_id) = match result { let (song, chart, play, discord_id) = match result {
None => { None => {
ctx.send( ctx.send(
CreateMessage::new().content(format!("Could not find play with id {}", id)), CreateReply::default().content(format!("Could not find play with id {}", id)),
) )
.await?; .await?;
continue; continue;
@ -269,9 +276,14 @@ pub async fn show_impl<C: MessageContext>(ctx: &mut C, ids: &[u32]) -> Result<Ve
plays.push(play); plays.push(play);
} }
if embeds.len() > 0 { if !embeds.is_empty() {
ctx.send_files(attachments, CreateMessage::new().embeds(embeds)) ctx.send(
.await?; CreateReply::default()
.reply(true)
.embeds(embeds)
.attachments(attachments),
)
.await?;
} }
Ok(plays) Ok(plays)
@ -333,7 +345,8 @@ pub async fn show(
mut ctx: Context<'_>, mut ctx: Context<'_>,
#[description = "Ids of score to show"] ids: Vec<u32>, #[description = "Ids of score to show"] ids: Vec<u32>,
) -> Result<(), Error> { ) -> Result<(), Error> {
show_impl(&mut ctx, &ids).await?; let res = show_impl(&mut ctx, &ids).await;
ctx.handle_error(res).await?;
Ok(()) Ok(())
} }
@ -341,12 +354,11 @@ pub async fn show(
// }}} // }}}
// {{{ Score delete // {{{ Score delete
// {{{ Implementation // {{{ Implementation
pub async fn delete_impl<C: MessageContext>(ctx: &mut C, ids: &[u32]) -> Result<(), Error> { pub async fn delete_impl<C: MessageContext>(ctx: &mut C, ids: &[u32]) -> Result<(), TaggedError> {
let user = get_user!(ctx); let user = User::from_context(ctx)?;
if ids.len() == 0 { if ids.is_empty() {
ctx.reply("Empty ID list provided").await?; return Err(anyhow!("Empty ID list provided").tag(ErrorKind::User));
return Ok(());
} }
let mut count = 0; let mut count = 0;
@ -472,7 +484,8 @@ pub async fn delete(
mut ctx: Context<'_>, mut ctx: Context<'_>,
#[description = "Id of score to delete"] ids: Vec<u32>, #[description = "Id of score to delete"] ids: Vec<u32>,
) -> Result<(), Error> { ) -> Result<(), Error> {
delete_impl(&mut ctx, &ids).await?; let res = delete_impl(&mut ctx, &ids).await;
ctx.handle_error(res).await?;
Ok(()) Ok(())
} }

View file

@ -18,10 +18,11 @@ use crate::assets::{
TOP_BACKGROUND, TOP_BACKGROUND,
}; };
use crate::bitmap::{Align, BitmapCanvas, Color, LayoutDrawer, LayoutManager, Rect}; use crate::bitmap::{Align, BitmapCanvas, Color, LayoutDrawer, LayoutManager, Rect};
use crate::context::{Context, Error}; use crate::context::{Context, Error, TaggedError};
use crate::logs::debug_image_log; use crate::logs::debug_image_log;
use crate::user::User; use crate::user::User;
use crate::{assert_is_pookie, get_user, reply_errors, timed};
use super::discord::MessageContext;
// }}} // }}}
// {{{ Stats // {{{ Stats
@ -37,31 +38,26 @@ pub async fn stats(_ctx: Context<'_>) -> Result<(), Error> {
} }
// }}} // }}}
// {{{ Render best plays // {{{ Render best plays
async fn best_plays( async fn best_plays<C: MessageContext>(
ctx: &mut Context<'_>, ctx: &mut C,
user: &User, user: &User,
scoring_system: ScoringSystem, scoring_system: ScoringSystem,
grid_size: (u32, u32), grid_size: (u32, u32),
require_full: bool, require_full: bool,
) -> Result<(), Error> { ) -> Result<(), TaggedError> {
let user_ctx = ctx.data(); let user_ctx = ctx.data();
let plays = reply_errors!( let plays = get_best_plays(
ctx, user_ctx,
timed!("get_best_plays", { user.id,
get_best_plays( scoring_system,
user_ctx, if require_full {
user.id, grid_size.0 * grid_size.1
scoring_system, } else {
if require_full { grid_size.0 * (grid_size.1.max(1) - 1) + 1
grid_size.0 * grid_size.1 } as usize,
} else { (grid_size.0 * grid_size.1) as usize,
grid_size.0 * (grid_size.1.max(1) - 1) + 1 None,
} as usize, )?;
(grid_size.0 * grid_size.1) as usize,
None,
)?
})
);
// {{{ Layout // {{{ Layout
let mut layout = LayoutManager::default(); let mut layout = LayoutManager::default();
@ -132,7 +128,7 @@ async fn best_plays(
let bg_center = Rect::from_image(bg).center(); let bg_center = Rect::from_image(bg).center();
// Draw background // Draw background
drawer.blit_rbga(item_area, (-8, jacket_margin as i32), bg); drawer.blit_rbga(item_area, (-8, jacket_margin), bg);
with_font(&EXO_FONT, |faces| { with_font(&EXO_FONT, |faces| {
drawer.text( drawer.text(
item_area, item_area,
@ -420,20 +416,49 @@ async fn best_plays(
} }
// }}} // }}}
// {{{ B30 // {{{ B30
// {{{ Implementation
pub async fn b30_impl<C: MessageContext>(
ctx: &mut C,
scoring_system: Option<ScoringSystem>,
) -> Result<(), TaggedError> {
let user = User::from_context(ctx)?;
best_plays(ctx, &user, scoring_system.unwrap_or_default(), (5, 6), true).await?;
Ok(())
}
// }}}
// {{{ Discord wrapper
/// Show the 30 best scores /// Show the 30 best scores
#[poise::command(prefix_command, slash_command, user_cooldown = 30)] #[poise::command(prefix_command, slash_command, user_cooldown = 30)]
pub async fn b30(mut ctx: Context<'_>, scoring_system: Option<ScoringSystem>) -> Result<(), Error> { pub async fn b30(mut ctx: Context<'_>, scoring_system: Option<ScoringSystem>) -> Result<(), Error> {
let user = get_user!(&mut ctx); let res = b30_impl(&mut ctx, scoring_system).await;
ctx.handle_error(res).await?;
Ok(())
}
// }}}
// }}}
// {{{ B-any
// {{{ Implementation
async fn bany_impl<C: MessageContext>(
ctx: &mut C,
scoring_system: Option<ScoringSystem>,
width: u32,
height: u32,
) -> Result<(), TaggedError> {
let user = User::from_context(ctx)?;
user.assert_is_pookie()?;
best_plays( best_plays(
&mut ctx, ctx,
&user, &user,
scoring_system.unwrap_or_default(), scoring_system.unwrap_or_default(),
(5, 6), (width, height),
true, false,
) )
.await .await?;
}
Ok(())
}
// }}}
// {{{ Discord wrapper
#[poise::command(prefix_command, slash_command, hide_in_help, global_cooldown = 5)] #[poise::command(prefix_command, slash_command, hide_in_help, global_cooldown = 5)]
pub async fn bany( pub async fn bany(
mut ctx: Context<'_>, mut ctx: Context<'_>,
@ -441,23 +466,16 @@ pub async fn bany(
width: u32, width: u32,
height: u32, height: u32,
) -> Result<(), Error> { ) -> Result<(), Error> {
let user = get_user!(&mut ctx); let res = bany_impl(&mut ctx, scoring_system, width, height).await;
assert_is_pookie!(ctx, user); ctx.handle_error(res).await?;
best_plays( Ok(())
&mut ctx,
&user,
scoring_system.unwrap_or_default(),
(width, height),
false,
)
.await
} }
// }}} // }}}
// }}}
// {{{ Meta // {{{ Meta
/// Show stats about the bot itself. // {{{ Implementation
#[poise::command(prefix_command, slash_command, user_cooldown = 1)] async fn meta_impl<C: MessageContext>(ctx: &mut C) -> Result<(), TaggedError> {
async fn meta(mut ctx: Context<'_>) -> Result<(), Error> { let user = User::from_context(ctx)?;
let user = get_user!(&mut ctx);
let conn = ctx.data().db.get()?; let conn = ctx.data().db.get()?;
let song_count: usize = conn let song_count: usize = conn
.prepare_cached("SELECT count() as count FROM songs")? .prepare_cached("SELECT count() as count FROM songs")?
@ -504,8 +522,10 @@ async fn meta(mut ctx: Context<'_>) -> Result<(), Error> {
.field("Plays", format!("{play_count}"), true) .field("Plays", format!("{play_count}"), true)
.field("Your plays", format!("{your_play_count}"), true); .field("Your plays", format!("{your_play_count}"), true);
ctx.send(CreateReply::default().embed(embed)).await?; ctx.send(CreateReply::default().reply(true).embed(embed))
.await?;
// TODO: remove once achivement system is implemented
println!( println!(
"{:?}", "{:?}",
GoalStats::make(ctx.data(), &user, ScoringSystem::Standard).await? GoalStats::make(ctx.data(), &user, ScoringSystem::Standard).await?
@ -514,3 +534,14 @@ async fn meta(mut ctx: Context<'_>) -> Result<(), Error> {
Ok(()) Ok(())
} }
// }}} // }}}
// {{{ Discord wrapper
/// Show stats about the bot itself.
#[poise::command(prefix_command, slash_command, user_cooldown = 1)]
async fn meta(mut ctx: Context<'_>) -> Result<(), Error> {
let res = meta_impl(&mut ctx).await;
ctx.handle_error(res).await?;
Ok(())
}
// }}}
// }}}

View file

@ -10,37 +10,3 @@ macro_rules! edit_reply {
$handle.edit($ctx, edited) $handle.edit($ctx, edited)
}}; }};
} }
#[macro_export]
macro_rules! get_user {
($ctx:expr) => {{
crate::reply_errors!($ctx, crate::user::User::from_context($ctx))
}};
}
#[macro_export]
macro_rules! assert_is_pookie {
($ctx:expr, $user:expr) => {{
if !$user.is_pookie {
$ctx.reply("This feature is reserved for my pookies. Sowwy :3")
.await?;
return Ok(());
}
}};
}
#[macro_export]
macro_rules! reply_errors {
($ctx:expr, $default:expr, $value:expr) => {
match $value {
Ok(v) => v,
Err(err) => {
crate::commands::discord::MessageContext::reply($ctx, &format!("{err}")).await?;
return Ok($default);
}
}
};
($ctx:expr, $value:expr) => {
crate::reply_errors!($ctx, Default::default(), $value)
};
}

View file

@ -9,6 +9,7 @@ use std::sync::LazyLock;
use crate::arcaea::{chart::SongCache, jacket::JacketCache}; use crate::arcaea::{chart::SongCache, jacket::JacketCache};
use crate::assets::{get_data_dir, EXO_FONT, GEOSANS_FONT, KAZESAWA_BOLD_FONT, KAZESAWA_FONT}; use crate::assets::{get_data_dir, EXO_FONT, GEOSANS_FONT, KAZESAWA_BOLD_FONT, KAZESAWA_FONT};
use crate::commands::discord::MessageContext;
use crate::recognition::{hyperglass::CharMeasurements, ui::UIMeasurements}; use crate::recognition::{hyperglass::CharMeasurements, ui::UIMeasurements};
use crate::timed; use crate::timed;
// }}} // }}}
@ -17,6 +18,70 @@ use crate::timed;
pub type Error = anyhow::Error; pub type Error = anyhow::Error;
pub type Context<'a> = poise::Context<'a, UserContext, Error>; pub type Context<'a> = poise::Context<'a, UserContext, Error>;
// }}} // }}}
// {{{ Error handling
#[derive(Debug, Clone, Copy)]
pub enum ErrorKind {
User,
Internal,
}
#[derive(Debug)]
pub struct TaggedError {
pub kind: ErrorKind,
pub error: Error,
}
impl TaggedError {
#[inline]
pub fn new(kind: ErrorKind, error: Error) -> Self {
Self { kind, error }
}
}
#[macro_export]
macro_rules! get_user_error {
($err:expr) => {{
match $err.kind {
$crate::context::ErrorKind::User => $err.error,
$crate::context::ErrorKind::Internal => Err($err.error)?,
}
}};
}
/// Handles a [TaggedError], showing user errors to the user,
/// and throwing away anything else.
pub async fn discord_error_handler<V>(
ctx: &mut impl MessageContext,
res: Result<V, TaggedError>,
) -> Result<Option<V>, Error> {
match res {
Ok(v) => Ok(Some(v)),
Err(e) => match e.kind {
ErrorKind::Internal => Err(e.error),
ErrorKind::User => {
ctx.reply(&format!("{}", e.error)).await?;
Ok(None)
}
},
}
}
impl<E: Into<Error>> From<E> for TaggedError {
fn from(value: E) -> Self {
Self::new(ErrorKind::Internal, value.into())
}
}
pub trait TagError {
fn tag(self, tag: ErrorKind) -> TaggedError;
}
impl TagError for Error {
fn tag(self, tag: ErrorKind) -> TaggedError {
TaggedError::new(tag, self)
}
}
// }}}
// {{{ DB connection // {{{ DB connection
pub type DbConnection = r2d2::Pool<SqliteConnectionManager>; pub type DbConnection = r2d2::Pool<SqliteConnectionManager>;
@ -106,7 +171,7 @@ pub mod testing {
.await .await
} }
pub fn import_songs_and_jackets_from(to: &Path) -> () { pub fn import_songs_and_jackets_from(to: &Path) {
let out = std::process::Command::new("scripts/copy-chart-info.sh") let out = std::process::Command::new("scripts/copy-chart-info.sh")
.arg(get_data_dir()) .arg(get_data_dir())
.arg(to) .arg(to)
@ -124,16 +189,17 @@ pub mod testing {
($test_path:expr, $f:expr) => {{ ($test_path:expr, $f:expr) => {{
use std::str::FromStr; use std::str::FromStr;
let mut data = (*crate::context::testing::get_shared_context().await).clone(); let mut data = (*$crate::context::testing::get_shared_context().await).clone();
let dir = tempfile::tempdir()?; let dir = tempfile::tempdir()?;
data.db = crate::context::connect_db(dir.path()); data.db = $crate::context::connect_db(dir.path());
crate::context::testing::import_songs_and_jackets_from(dir.path()); $crate::context::testing::import_songs_and_jackets_from(dir.path());
let mut ctx = crate::commands::discord::mock::MockContext::new(data); let mut ctx = $crate::commands::discord::mock::MockContext::new(data);
crate::user::User::create_from_context(&ctx)?; let res = $crate::user::User::create_from_context(&ctx);
ctx.handle_error(res).await?;
let res: Result<(), Error> = $f(&mut ctx).await; let res: Result<(), $crate::context::TaggedError> = $f(&mut ctx).await;
res?; ctx.handle_error(res).await?;
ctx.golden(&std::path::PathBuf::from_str($test_path)?)?; ctx.golden(&std::path::PathBuf::from_str($test_path)?)?;
Ok(()) Ok(())

21
src/lib.rs Normal file
View file

@ -0,0 +1,21 @@
#![allow(async_fn_in_trait)]
#![feature(iter_map_windows)]
#![feature(let_chains)]
#![feature(array_try_map)]
#![feature(async_closure)]
#![feature(try_blocks)]
#![feature(thread_local)]
#![feature(generic_arg_infer)]
#![feature(iter_collect_into)]
pub mod arcaea;
pub mod assets;
pub mod bitmap;
pub mod commands;
pub mod context;
pub mod levenshtein;
pub mod logs;
pub mod recognition;
pub mod time;
pub mod transform;
pub mod user;

View file

@ -1,126 +0,0 @@
#![warn(clippy::str_to_string)]
#![feature(iter_map_windows)]
#![feature(let_chains)]
#![feature(array_try_map)]
#![feature(async_closure)]
#![feature(try_blocks)]
#![feature(thread_local)]
#![feature(generic_arg_infer)]
#![feature(lazy_cell_consume)]
#![feature(iter_collect_into)]
mod arcaea;
mod assets;
mod bitmap;
mod cli;
mod commands;
mod context;
mod levenshtein;
mod logs;
mod recognition;
mod time;
mod transform;
mod user;
use arcaea::play::generate_missing_scores;
use clap::Parser;
use cli::{Cli, Command};
use context::{Error, UserContext};
use poise::serenity_prelude::{self as serenity};
use std::{env::var, sync::Arc, time::Duration};
// {{{ Error handler
async fn on_error(error: poise::FrameworkError<'_, UserContext, Error>) {
match error {
error => {
if let Err(e) = poise::builtins::on_error(error).await {
println!("Error while handling error: {}", e)
}
}
}
}
// }}}
#[tokio::main]
async fn main() {
let cli = Cli::parse();
match cli.command {
Command::Discord {} => {
// {{{ Poise options
let options = poise::FrameworkOptions {
commands: vec![
commands::help(),
commands::score::score(),
commands::stats::stats(),
commands::chart::chart(),
],
prefix_options: poise::PrefixFrameworkOptions {
stripped_dynamic_prefix: Some(|_ctx, message, _user_ctx| {
Box::pin(async {
if message.author.bot || Into::<u64>::into(message.author.id) == 1 {
Ok(None)
} else if message.content.starts_with("!") {
Ok(Some(message.content.split_at(1)))
} else if message.guild_id.is_none() {
if message.content.trim().len() == 0 {
Ok(Some(("", "score magic")))
} else {
Ok(Some(("", &message.content[..])))
}
} else {
Ok(None)
}
})
}),
edit_tracker: Some(Arc::new(poise::EditTracker::for_timespan(
Duration::from_secs(3600),
))),
..Default::default()
},
on_error: |error| Box::pin(on_error(error)),
..Default::default()
};
// }}}
// {{{ Start poise
let framework = poise::Framework::builder()
.setup(move |ctx, _ready, framework| {
Box::pin(async move {
println!("Logged in as {}", _ready.user.name);
poise::builtins::register_globally(ctx, &framework.options().commands)
.await?;
let ctx = UserContext::new().await?;
if var("SHIMMERING_REGEN_SCORES").unwrap_or_default() == "1" {
timed!("generate_missing_scores", {
generate_missing_scores(&ctx).await?;
});
}
Ok(ctx)
})
})
.options(options)
.build();
let token = var("SHIMMERING_DISCORD_TOKEN")
.expect("Missing `SHIMMERING_DISCORD_TOKEN` env var");
let intents = serenity::GatewayIntents::non_privileged()
| serenity::GatewayIntents::MESSAGE_CONTENT;
let client = serenity::ClientBuilder::new(token, intents)
.framework(framework)
.await;
client.unwrap().start().await.unwrap()
// }}}
}
Command::PrepareJackets {} => {
cli::prepare_jackets::run().expect("Could not prepare jackets");
}
Command::Analyse(args) => {
cli::analyse::run(args)
.await
.expect("Could not analyse screenshot");
}
}
}

View file

@ -45,7 +45,7 @@ pub fn guess_song_and_chart<'a>(
.or_else(|| strip_case_insensitive_suffix(name, "[ETR]").zip(Some(Difficulty::ETR))) .or_else(|| strip_case_insensitive_suffix(name, "[ETR]").zip(Some(Difficulty::ETR)))
.or_else(|| strip_case_insensitive_suffix(name, "BYD").zip(Some(Difficulty::BYD))) .or_else(|| strip_case_insensitive_suffix(name, "BYD").zip(Some(Difficulty::BYD)))
.or_else(|| strip_case_insensitive_suffix(name, "[BYD]").zip(Some(Difficulty::BYD))) .or_else(|| strip_case_insensitive_suffix(name, "[BYD]").zip(Some(Difficulty::BYD)))
.unwrap_or((&name, Difficulty::FTR)); .unwrap_or((name, Difficulty::FTR));
guess_chart_name(name, &ctx.song_cache, Some(difficulty), true) guess_chart_name(name, &ctx.song_cache, Some(difficulty), true)
} }
@ -85,7 +85,7 @@ pub fn guess_chart_name<'a>(
distance_vec.clear(); distance_vec.clear();
// Apply raw distance // Apply raw distance
let base_distance = edit_distance_with(&text, &song_title, &mut levenshtein_vec); let base_distance = edit_distance_with(text, song_title, &mut levenshtein_vec);
if base_distance <= song.title.len() / 3 { if base_distance <= song.title.len() / 3 {
distance_vec.push(base_distance * 10 + 2); distance_vec.push(base_distance * 10 + 2);
} }
@ -95,7 +95,7 @@ pub fn guess_chart_name<'a>(
if let Some(sliced) = &song_title.get(..shortest_len) if let Some(sliced) = &song_title.get(..shortest_len)
&& (text.len() >= 6 || unsafe_heuristics) && (text.len() >= 6 || unsafe_heuristics)
{ {
let slice_distance = edit_distance_with(&text, sliced, &mut levenshtein_vec); let slice_distance = edit_distance_with(text, sliced, &mut levenshtein_vec);
if slice_distance == 0 { if slice_distance == 0 {
distance_vec.push(3); distance_vec.push(3);
} }
@ -105,7 +105,7 @@ pub fn guess_chart_name<'a>(
if let Some(shorthand) = &chart.shorthand if let Some(shorthand) = &chart.shorthand
&& unsafe_heuristics && unsafe_heuristics
{ {
let short_distance = edit_distance_with(&text, shorthand, &mut levenshtein_vec); let short_distance = edit_distance_with(text, shorthand, &mut levenshtein_vec);
if short_distance <= shorthand.len() / 3 { if short_distance <= shorthand.len() / 3 {
distance_vec.push(short_distance * 10 + 1); distance_vec.push(short_distance * 10 + 1);
} }
@ -121,7 +121,7 @@ pub fn guess_chart_name<'a>(
close_enough.sort_by_key(|(song, _, _)| song.id); close_enough.sort_by_key(|(song, _, _)| song.id);
close_enough.dedup_by_key(|(song, _, _)| song.id); close_enough.dedup_by_key(|(song, _, _)| song.id);
if close_enough.len() == 0 { if close_enough.is_empty() {
if text.len() <= 1 { if text.len() <= 1 {
bail!( bail!(
"Could not find match for chart name '{}' [{:?}]", "Could not find match for chart name '{}' [{:?}]",
@ -133,13 +133,11 @@ pub fn guess_chart_name<'a>(
} }
} else if close_enough.len() == 1 { } else if close_enough.len() == 1 {
break (close_enough[0].0, close_enough[0].1); break (close_enough[0].0, close_enough[0].1);
} else if unsafe_heuristics {
close_enough.sort_by_key(|(_, _, distance)| *distance);
break (close_enough[0].0, close_enough[0].1);
} else { } else {
if unsafe_heuristics { bail!("Name '{}' is too vague to choose a match", raw_text);
close_enough.sort_by_key(|(_, _, distance)| *distance);
break (close_enough[0].0, close_enough[0].1);
} else {
bail!("Name '{}' is too vague to choose a match", raw_text);
};
}; };
}; };

View file

@ -6,7 +6,8 @@ use hypertesseract::{PageSegMode, Tesseract};
use image::imageops::FilterType; use image::imageops::FilterType;
use image::{DynamicImage, GenericImageView}; use image::{DynamicImage, GenericImageView};
use num::integer::Roots; use num::integer::Roots;
use poise::serenity_prelude::{CreateAttachment, CreateEmbed, CreateMessage}; use poise::serenity_prelude::{CreateAttachment, CreateEmbed};
use poise::CreateReply;
use crate::arcaea::chart::{Chart, Difficulty, Song, DIFFICULTY_MENU_PIXEL_COLORS}; use crate::arcaea::chart::{Chart, Difficulty, Song, DIFFICULTY_MENU_PIXEL_COLORS};
use crate::arcaea::jacket::IMAGE_VEC_DIM; use crate::arcaea::jacket::IMAGE_VEC_DIM;
@ -114,13 +115,16 @@ impl ImageAnalyzer {
"An error occurred, around the time I was extracting data for {ui_rect:?}" "An error occurred, around the time I was extracting data for {ui_rect:?}"
)); ));
let msg = CreateMessage::default().embed(embed); ctx.send(
ctx.send_files([error_attachement], msg).await?; CreateReply::default()
.embed(embed)
.attachment(error_attachement),
)
.await?;
} else { } else {
embed = embed.title("An error occurred"); embed = embed.title("An error occurred");
let msg = CreateMessage::default().embed(embed); ctx.send(CreateReply::default().embed(embed)).await?;
ctx.send_files([], msg).await?;
} }
Ok(()) Ok(())
@ -355,9 +359,9 @@ impl ImageAnalyzer {
} }
// }}} // }}}
// {{{ Read max recall // {{{ Read max recall
pub fn read_max_recall<'a>( pub fn read_max_recall(
&mut self, &mut self,
ctx: &'a UserContext, ctx: &UserContext,
image: &DynamicImage, image: &DynamicImage,
) -> Result<u32, Error> { ) -> Result<u32, Error> {
let image = self.interp_crop(ctx, image, ScoreScreen(ScoreScreenRect::MaxRecall))?; let image = self.interp_crop(ctx, image, ScoreScreen(ScoreScreenRect::MaxRecall))?;

View file

@ -2,7 +2,7 @@ use anyhow::anyhow;
use rusqlite::Row; use rusqlite::Row;
use crate::commands::discord::MessageContext; use crate::commands::discord::MessageContext;
use crate::context::{Error, UserContext}; use crate::context::{ErrorKind, TagError, TaggedError, UserContext};
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct User { pub struct User {
@ -13,7 +13,7 @@ pub struct User {
impl User { impl User {
#[inline] #[inline]
fn from_row<'a, 'b>(row: &'a Row<'b>) -> Result<Self, rusqlite::Error> { fn from_row(row: &Row<'_>) -> Result<Self, rusqlite::Error> {
Ok(Self { Ok(Self {
id: row.get("id")?, id: row.get("id")?,
discord_id: row.get("discord_id")?, discord_id: row.get("discord_id")?,
@ -21,7 +21,7 @@ impl User {
}) })
} }
pub fn create_from_context(ctx: &impl MessageContext) -> Result<Self, Error> { pub fn create_from_context(ctx: &impl MessageContext) -> Result<Self, TaggedError> {
let discord_id = ctx.author_id().to_string(); let discord_id = ctx.author_id().to_string();
let user_id: u32 = ctx let user_id: u32 = ctx
.data() .data()
@ -35,7 +35,7 @@ impl User {
)? )?
.query_map([&discord_id], |row| row.get("id"))? .query_map([&discord_id], |row| row.get("id"))?
.next() .next()
.ok_or_else(|| anyhow!("Failed to create user"))??; .ok_or_else(|| anyhow!("No id returned from user creation"))??;
Ok(Self { Ok(Self {
discord_id, discord_id,
@ -44,7 +44,7 @@ impl User {
}) })
} }
pub fn from_context(ctx: &impl MessageContext) -> Result<Self, Error> { pub fn from_context(ctx: &impl MessageContext) -> Result<Self, TaggedError> {
let id = ctx.author_id(); let id = ctx.author_id();
let user = ctx let user = ctx
.data() .data()
@ -53,20 +53,35 @@ impl User {
.prepare_cached("SELECT * FROM users WHERE discord_id = ?")? .prepare_cached("SELECT * FROM users WHERE discord_id = ?")?
.query_map([id], Self::from_row)? .query_map([id], Self::from_row)?
.next() .next()
.ok_or_else(|| anyhow!("You are not an user in my database, sowwy ^~^"))??; .ok_or_else(|| {
anyhow!("You are not an user in my database, sowwy ^~^").tag(ErrorKind::User)
})??;
Ok(user) Ok(user)
} }
pub fn by_id(ctx: &UserContext, id: u32) -> Result<Self, Error> { pub fn by_id(ctx: &UserContext, id: u32) -> Result<Self, TaggedError> {
let user = ctx let user = ctx
.db .db
.get()? .get()?
.prepare_cached("SELECT * FROM users WHERE id = ?")? .prepare_cached("SELECT * FROM users WHERE id = ?")?
.query_map([id], Self::from_row)? .query_map([id], Self::from_row)?
.next() .next()
.ok_or_else(|| anyhow!("You are not an user in my database, sowwy ^~^"))??; .ok_or_else(|| {
anyhow!("You are not an user in my database, sowwy ^~^").tag(ErrorKind::User)
})??;
Ok(user) Ok(user)
} }
#[inline]
pub fn assert_is_pookie(&self) -> Result<(), TaggedError> {
if !self.is_pookie {
return Err(
anyhow!("This feature is reserved for my pookies. Sowwy :3").tag(ErrorKind::User)
);
}
Ok(())
}
} }