1
Fork 0

Jacket recognition

Signed-off-by: prescientmoon <git@moonythm.dev>
This commit is contained in:
prescientmoon 2024-06-22 23:07:11 +02:00
parent a0e3decd7a
commit 5cfeff4e14
Signed by: prescientmoon
SSH key fingerprint: SHA256:UUF9JT2s8Xfyv76b8ZuVL7XrmimH4o49p4b+iexbVH4
14 changed files with 1979 additions and 230 deletions

3
.gitignore vendored
View file

@ -1,4 +1,5 @@
target target
.direnv .direnv
.envrc .envrc
data data/db.sqlite
data/jackets

24
Cargo.lock generated
View file

@ -1292,6 +1292,18 @@ dependencies = [
"wasm-bindgen", "wasm-bindgen",
] ]
[[package]]
name = "kd-tree"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5f89ee4e60e82cf7024e5e94618c646fbf61ce7501dc5898b3d12786442d3682"
dependencies = [
"num-traits",
"ordered-float",
"paste",
"typenum",
]
[[package]] [[package]]
name = "lazy_static" name = "lazy_static"
version = "1.4.0" version = "1.4.0"
@ -1659,6 +1671,15 @@ version = "1.19.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92"
[[package]]
name = "ordered-float"
version = "4.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a76df7075c7d4d01fdcb46c912dd17fba5b60c78ea480b475f2b6ab6f666584e"
dependencies = [
"num-traits",
]
[[package]] [[package]]
name = "parking_lot" name = "parking_lot"
version = "0.12.3" version = "0.12.3"
@ -2399,8 +2420,10 @@ name = "shimmeringmoon"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"chrono", "chrono",
"csv",
"edit-distance", "edit-distance",
"image", "image",
"kd-tree",
"num", "num",
"plotlib", "plotlib",
"poise", "poise",
@ -2408,6 +2431,7 @@ dependencies = [
"sqlx", "sqlx",
"tesseract", "tesseract",
"tokio", "tokio",
"typenum",
] ]
[[package]] [[package]]

View file

@ -5,8 +5,10 @@ edition = "2021"
[dependencies] [dependencies]
chrono = "0.4.38" chrono = "0.4.38"
csv = "1.3.0"
edit-distance = "2.1.0" edit-distance = "2.1.0"
image = "0.25.1" image = "0.25.1"
kd-tree = "0.6.0"
num = "0.4.3" num = "0.4.3"
plotlib = "0.5.1" plotlib = "0.5.1"
poise = "0.6.1" poise = "0.6.1"
@ -14,6 +16,7 @@ prettytable-rs = "0.10.0"
sqlx = { version = "0.7.4", features = ["sqlite", "runtime-tokio", "chrono"] } sqlx = { version = "0.7.4", features = ["sqlite", "runtime-tokio", "chrono"] }
tesseract = "0.15.1" tesseract = "0.15.1"
tokio = {version="1.38.0", features=["rt-multi-thread"]} tokio = {version="1.38.0", features=["rt-multi-thread"]}
typenum = "1.17.0"
[profile.dev.package.sqlx-macros] [profile.dev.package.sqlx-macros]
opt-level = 3 opt-level = 3

1241
data/charts.csv Normal file

File diff suppressed because it is too large Load diff

4
data/jackets.csv Normal file
View file

@ -0,0 +1,4 @@
filename,song_id
grievous-lady,7
einherjar-joker,14
einherjar-joker-byd,14
1 filename song_id
2 grievous-lady 7
3 einherjar-joker 14
4 einherjar-joker-byd 14

View file

@ -27,6 +27,7 @@
]) ])
rust-analyzer-nightly rust-analyzer-nightly
ruff ruff
imagemagick
clang clang
llvmPackages.clang llvmPackages.clang

View file

@ -1,13 +1,13 @@
# {{{ users # {{{ users
create table IF NOT EXISTS users ( create table IF NOT EXISTS users (
id INTEGER PRIMARY KEY, id INTEGER NOT NULL PRIMARY KEY,
discord_id TEXT UNIQUE NOT NULL, discord_id TEXT UNIQUE NOT NULL,
nickname TEXT UNIQUE nickname TEXT UNIQUE
); );
# }}} # }}}
# {{{ songs # {{{ songs
CREATE TABLE IF NOT EXISTS songs ( CREATE TABLE IF NOT EXISTS songs (
id INTEGER PRIMARY KEY, id INTEGER NOT NULL PRIMARY KEY,
title TEXT NOT NULL, title TEXT NOT NULL,
ocr_alias TEXT, ocr_alias TEXT,
artist TEXT, artist TEXT,
@ -17,7 +17,7 @@ CREATE TABLE IF NOT EXISTS songs (
# }}} # }}}
# {{{ charts # {{{ charts
CREATE TABLE IF NOT EXISTS charts ( CREATE TABLE IF NOT EXISTS charts (
id INTEGER PRIMARY KEY, id INTEGER NOT NULL PRIMARY KEY,
song_id INTEGER NOT NULL, song_id INTEGER NOT NULL,
difficulty TEXT NOT NULL CHECK (difficulty IN ('PST','PRS','FTR','ETR','BYD')), difficulty TEXT NOT NULL CHECK (difficulty IN ('PST','PRS','FTR','ETR','BYD')),
@ -32,7 +32,7 @@ CREATE TABLE IF NOT EXISTS charts (
# }}} # }}}
# {{{ plays # {{{ plays
CREATE TABLE IF NOT EXISTS plays ( CREATE TABLE IF NOT EXISTS plays (
id INTEGER PRIMARY KEY, id INTEGER NOT NULL PRIMARY KEY,
chart_id INTEGER NOT NULL, chart_id INTEGER NOT NULL,
user_id INTEGER NOT NULL, user_id INTEGER NOT NULL,
discord_attachment_id TEXT, discord_attachment_id TEXT,

View file

@ -1,7 +1,8 @@
use sqlx::prelude::FromRow; use sqlx::{prelude::FromRow, SqlitePool};
use crate::context::{Error, UserContext}; use crate::context::Error;
// {{{ Difficuly
#[derive(Debug, Clone, Copy, sqlx::Type)] #[derive(Debug, Clone, Copy, sqlx::Type)]
pub enum Difficulty { pub enum Difficulty {
PST, PST,
@ -12,12 +13,32 @@ pub enum Difficulty {
} }
impl Difficulty { impl Difficulty {
pub const DIFFICULTIES: [Difficulty; 5] =
[Self::PST, Self::PRS, Self::FTR, Self::ETR, Self::BYD];
pub const DIFFICULTY_STRINGS: [&'static str; 5] = ["PST", "PRS", "FTR", "ETR", "BYD"];
#[inline] #[inline]
pub fn to_index(self) -> usize { pub fn to_index(self) -> usize {
self as usize self as usize
} }
} }
impl TryFrom<String> for Difficulty {
type Error = String;
fn try_from(value: String) -> Result<Self, Self::Error> {
for (i, s) in Self::DIFFICULTY_STRINGS.iter().enumerate() {
if value == **s {
return Ok(Self::DIFFICULTIES[i]);
}
}
Err(format!("Cannot convert {} to difficulty", value))
}
}
// }}}
// {{{ Song
#[derive(Debug, Clone, FromRow)] #[derive(Debug, Clone, FromRow)]
pub struct Song { pub struct Song {
pub id: u32, pub id: u32,
@ -25,29 +46,39 @@ pub struct Song {
pub ocr_alias: Option<String>, pub ocr_alias: Option<String>,
pub artist: Option<String>, pub artist: Option<String>,
} }
// }}}
#[derive(Debug, Clone, Copy, FromRow)] // {{{ Chart
#[derive(Debug, Clone, FromRow)]
pub struct Chart { pub struct Chart {
pub id: u32, pub id: u32,
pub song_id: u32, pub song_id: u32,
pub difficulty: Difficulty, pub difficulty: Difficulty,
pub level: u32, pub level: String, // TODO: this could become an enum
pub note_count: u32, pub note_count: u32,
pub chart_constant: u32, pub chart_constant: u32,
} }
// }}}
// {{{ Cache
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct CachedSong { pub struct CachedSong {
song: Song, pub song: Song,
charts: [Option<Chart>; 5], charts: [Option<Chart>; 5],
} }
impl CachedSong { impl CachedSong {
#[inline]
pub fn new(song: Song, charts: [Option<Chart>; 5]) -> Self { pub fn new(song: Song, charts: [Option<Chart>; 5]) -> Self {
Self { song, charts } Self { song, charts }
} }
#[inline]
pub fn lookup(&self, difficulty: Difficulty) -> Option<&Chart> {
self.charts
.get(difficulty.to_index())
.and_then(|c| c.as_ref())
}
} }
#[derive(Debug, Clone, Default)] #[derive(Debug, Clone, Default)]
@ -56,28 +87,48 @@ pub struct SongCache {
} }
impl SongCache { impl SongCache {
pub async fn new(ctx: &UserContext) -> Result<Self, Error> { #[inline]
pub fn lookup(&self, id: u32) -> Option<&CachedSong> {
self.songs.get(id as usize).and_then(|i| i.as_ref())
}
// {{{ Populate cache
pub async fn new(pool: &SqlitePool) -> Result<Self, Error> {
let mut result = Self::default(); let mut result = Self::default();
let songs: Vec<Song> = sqlx::query_as("SELECT * FROM songs") let songs = sqlx::query!("SELECT * FROM songs").fetch_all(pool).await?;
.fetch_all(&ctx.db)
.await?;
for song in songs { for song in songs {
let song = Song {
id: song.id as u32,
title: song.title,
ocr_alias: song.ocr_alias,
artist: song.artist,
};
let song_id = song.id as usize; let song_id = song.id as usize;
if song_id >= result.songs.len() { if song_id >= result.songs.len() {
result.songs.resize(song_id, None); result.songs.resize(song_id + 1, None);
} }
let charts: Vec<Chart> = sqlx::query_as("SELECT * FROM charts WHERE song_id=?") let charts = sqlx::query!("SELECT * FROM charts WHERE song_id=?", song.id)
.bind(song.id) .fetch_all(pool)
.fetch_all(&ctx.db)
.await?; .await?;
let mut chart_cache = [None; 5]; let mut chart_cache: [Option<_>; 5] = Default::default();
for chart in charts { for chart in charts {
chart_cache[chart.difficulty.to_index()] = Some(chart); let chart = Chart {
id: chart.id as u32,
song_id: chart.song_id as u32,
difficulty: Difficulty::try_from(chart.difficulty)?,
level: chart.level,
chart_constant: chart.chart_constant as u32,
note_count: chart.note_count as u32,
};
let index = chart.difficulty.to_index();
chart_cache[index] = Some(chart);
} }
result.songs[song_id] = Some(CachedSong::new(song, chart_cache)); result.songs[song_id] = Some(CachedSong::new(song, chart_cache));
@ -85,4 +136,6 @@ impl SongCache {
Ok(result) Ok(result)
} }
// }}}
} }
// }}}

View file

@ -1,13 +1,11 @@
use std::fmt::Display;
use crate::context::{Context, Error}; use crate::context::{Context, Error};
use crate::score::ImageCropper; use crate::score::{CreatePlay, ImageCropper};
use crate::user::User; use crate::user::User;
use image::imageops::FilterType; use image::imageops::FilterType;
use poise::serenity_prelude::{ use poise::serenity_prelude::{CreateAttachment, CreateEmbed, CreateMessage};
CreateAttachment, CreateEmbed, CreateEmbedAuthor, CreateMessage, Timestamp,
};
use poise::{serenity_prelude as serenity, CreateReply}; use poise::{serenity_prelude as serenity, CreateReply};
use prettytable::format::{FormatBuilder, LinePosition, LineSeparator};
use prettytable::{row, Table};
/// Show this help menu /// Show this help menu
#[poise::command(prefix_command, track_edits, slash_command)] #[poise::command(prefix_command, track_edits, slash_command)]
@ -40,19 +38,52 @@ pub async fn score(_ctx: Context<'_>) -> Result<(), Error> {
Ok(()) Ok(())
} }
// {{{ Send error embed with image
async fn error_with_image(
ctx: Context<'_>,
bytes: &[u8],
filename: &str,
message: &str,
err: impl Display,
) -> Result<(), Error> {
let error_attachement = CreateAttachment::bytes(bytes, filename);
let msg = CreateMessage::default().embed(
CreateEmbed::default()
.title(message)
.attachment(filename)
.description(format!("{}", err)),
);
ctx.channel_id()
.send_files(ctx.http(), [error_attachement], msg)
.await?;
Ok(())
}
// }}}
/// Identify scores from attached images. /// Identify scores from attached images.
#[poise::command(prefix_command, slash_command)] #[poise::command(prefix_command, slash_command)]
pub async fn magic( pub async fn magic(
ctx: Context<'_>, ctx: Context<'_>,
#[description = "Images containing scores"] files: Vec<serenity::Attachment>, #[description = "Images containing scores"] files: Vec<serenity::Attachment>,
) -> Result<(), Error> { ) -> Result<(), Error> {
println!("{:?}", User::from_context(&ctx).await); let user = match User::from_context(&ctx).await {
Ok(user) => user,
Err(_) => {
ctx.say("You are not an user in my database, sorry!")
.await?;
return Ok(());
}
};
if files.len() == 0 { if files.len() == 0 {
ctx.reply("No images found attached to message").await?; ctx.reply("No images found attached to message").await?;
} else { } else {
let mut embeds: Vec<CreateEmbed> = vec![];
let mut attachements: Vec<CreateAttachment> = vec![];
let handle = ctx let handle = ctx
.reply(format!("Processing: 0/{} images", files.len())) .reply(format!("Processed 0/{} scores", files.len()))
.await?; .await?;
for (i, file) in files.iter().enumerate() { for (i, file) in files.iter().enumerate() {
@ -62,76 +93,110 @@ pub async fn magic(
let format = image::guess_format(&bytes)?; let format = image::guess_format(&bytes)?;
// Image pre-processing // Image pre-processing
let mut image = image::load_from_memory_with_format(&bytes, format)? let image = image::load_from_memory_with_format(&bytes, format)?.resize(
.resize(1024, 1024, FilterType::Nearest) 1024,
.grayscale() 1024,
.blur(1.); FilterType::Nearest,
image.invert();
// {{{ Table experiment
let table_format = FormatBuilder::new()
.separators(
&[LinePosition::Title],
LineSeparator::new('─', '┬', '┌', '┐'),
)
.padding(1, 1)
.build();
let mut table = Table::new();
table.set_format(table_format);
table.set_titles(row!["Chart", "Level", "Score", "Rating"]);
table.add_row(row!["Quon", "BYD 10", "10000807", "12.3 (-132)"]);
table.add_row(row!["Monochrome princess", "FTR 9+", " 9380807", "10.2"]);
table.add_row(row!["Grievous lady", "FTR 11", " 9286787", "11.2"]);
table.add_row(row!["Fracture ray", "FTR 11", " 8990891", "11.0"]);
table.add_row(row!["Shades of Light", "FTR 9+", "10000976", " 9.3 (-13)"]);
ctx.say(format!("```\n{}\n```", table.to_string())).await?;
// }}}
let icon_attachement = CreateAttachment::file(
&tokio::fs::File::open("./data/jackets/grievous.png").await?,
"grievous.png",
)
.await?;
let msg = CreateMessage::default().embed(
CreateEmbed::default()
.title("Grievous lady [FTR 11]")
.thumbnail("attachment://grievous.png")
.field("Score", "998302 (+8973)", true)
.field("Rating", "12.2 (+.6)", true)
.field("Grade", "EX+", true)
.field("ζ-Score", "982108 (+347)", true)
.field("ζ-Rating", "11.5 (+.45)", true)
.field("ζ-Grade", "EX", true)
.field("Status", "FR (-243F)", true)
.field("Max recall", "308/1073", true)
.field("Breakdown", "894/342/243/23", true),
); );
ctx.channel_id() // // {{{ Table experiment
.send_files(ctx.http(), [icon_attachement], msg) // let table_format = FormatBuilder::new()
.await?; // .separators(
// &[LinePosition::Title],
// LineSeparator::new('─', '┬', '┌', '┐'),
// )
// .padding(1, 1)
// .build();
// let mut table = Table::new();
// table.set_format(table_format);
// table.set_titles(row!["Chart", "Level", "Score", "Rating"]);
// table.add_row(row!["Quon", "BYD 10", "10000807", "12.3 (-132)"]);
// table.add_row(row!["Monochrome princess", "FTR 9+", " 9380807", "10.2"]);
// table.add_row(row!["Grievous lady", "FTR 11", " 9286787", "11.2"]);
// table.add_row(row!["Fracture ray", "FTR 11", " 8990891", "11.0"]);
// table.add_row(row!["Shades of Light", "FTR 9+", "10000976", " 9.3 (-13)"]);
// ctx.say(format!("```\n{}\n```", table.to_string())).await?;
// // }}}
// // {{{ Embed experiment
// let icon_attachement = CreateAttachment::file(
// &tokio::fs::File::open("./data/jackets/grievous.png").await?,
// "grievous.png",
// )
// .await?;
// let msg = CreateMessage::default().embed(
// CreateEmbed::default()
// .title("Grievous lady [FTR 11]")
// .thumbnail("attachment://grievous.png")
// .field("Score", "998302 (+8973)", true)
// .field("Rating", "12.2 (+.6)", true)
// .field("Grade", "EX+", true)
// .field("ζ-Score", "982108 (+347)", true)
// .field("ζ-Rating", "11.5 (+.45)", true)
// .field("ζ-Grade", "EX", true)
// .field("Status", "FR (-243F)", true)
// .field("Max recall", "308/1073", true)
// .field("Breakdown", "894/342/243/23", true),
// );
//
// ctx.channel_id()
// .send_files(ctx.http(), [icon_attachement], msg)
// .await?;
// // }}}
// Create cropper and run OCR // Create cropper and run OCR
let mut cropper = ImageCropper::default(); let mut cropper = ImageCropper::default();
let score_readout = match cropper.read_score(&image) {
let (jacket, cached_song) = match cropper.read_jacket(ctx.data(), &image) {
// {{{ Jacket recognition error handling
Err(err) => {
error_with_image(
ctx,
&cropper.bytes,
&file.filename,
"Error while detecting jacket",
err,
)
.await?;
continue;
}
// }}}
Ok(j) => j,
};
let mut image = image.grayscale().blur(1.);
let difficulty = match cropper.read_difficulty(&image) {
// {{{ OCR error handling // {{{ OCR error handling
Err(err) => { Err(err) => {
let error_attachement = error_with_image(
CreateAttachment::bytes(cropper.bytes, &file.filename); ctx,
let msg = CreateMessage::default().embed( &cropper.bytes,
CreateEmbed::default() &file.filename,
.title("Could not read score from picture") "Could not read score from picture",
.attachment(&file.filename) &err,
.description(format!("{}", err)) )
.author( .await?;
CreateEmbedAuthor::new(&ctx.author().name)
.icon_url(ctx.author().face()), continue;
) }
.timestamp(Timestamp::now()), // }}}
); Ok(d) => d,
ctx.channel_id() };
.send_files(ctx.http(), [error_attachement], msg)
.await?; image.invert();
let score = match cropper.read_score(&image) {
// {{{ OCR error handling
Err(err) => {
error_with_image(
ctx,
&cropper.bytes,
&file.filename,
"Could not read score from picture",
&err,
)
.await?;
continue; continue;
} }
@ -139,31 +204,44 @@ pub async fn magic(
Ok(score) => score, Ok(score) => score,
}; };
// Reply with attachement & readout let song = &cached_song.song;
let attachement = CreateAttachment::bytes(cropper.bytes, &file.filename); let chart = cached_song.lookup(difficulty).ok_or_else(|| {
let reply = CreateReply::default() format!(
.attachment(attachement) "Could not find difficulty {:?} for song {}",
.content(format!("Score: {:?}", score_readout)) difficulty, song.title
.reply(true); )
ctx.send(reply).await?; })?;
// Edit progress reply let play = CreatePlay::new(score, chart, &user)
let progress_reply = CreateReply::default() .with_attachment(file)
.content(format!("Processing: {}/{} images", i + 1, files.len())) .save(&ctx.data())
.reply(true); .await?;
handle.edit(ctx, progress_reply).await?;
let (embed, attachement) = play.to_embed(&song, &chart, &jacket).await?;
embeds.push(embed);
attachements.push(attachement);
} else { } else {
ctx.reply("One of the attached files is not an image!") ctx.reply("One of the attached files is not an image!")
.await?; .await?;
continue; continue;
} }
let edited = CreateReply::default().reply(true).content(format!(
"Processed {}/{} scores",
i + 1,
files.len()
));
handle.edit(ctx, edited).await?;
} }
// Finish off progress reply handle.delete(ctx).await?;
let progress_reply = CreateReply::default()
.content(format!("All images have been processed!")) let msg = CreateMessage::new().embeds(embeds);
.reply(true);
handle.edit(ctx, progress_reply).await?; ctx.channel_id()
.send_files(ctx.http(), attachements, msg)
.await?;
} }
Ok(()) Ok(())

View file

@ -1,6 +1,8 @@
use std::path::PathBuf;
use sqlx::SqlitePool; use sqlx::SqlitePool;
use crate::chart::SongCache; use crate::{chart::SongCache, jacket::JacketCache};
// Types used by all command functions // Types used by all command functions
pub type Error = Box<dyn std::error::Error + Send + Sync>; pub type Error = Box<dyn std::error::Error + Send + Sync>;
@ -8,16 +10,22 @@ pub type Context<'a> = poise::Context<'a, UserContext, Error>;
// Custom user data passed to all command functions // Custom user data passed to all command functions
pub struct UserContext { pub struct UserContext {
pub data_dir: PathBuf,
pub db: SqlitePool, pub db: SqlitePool,
pub song_cache: SongCache, pub song_cache: SongCache,
pub jacket_cache: JacketCache,
} }
impl UserContext { impl UserContext {
#[inline] #[inline]
pub fn new(db: SqlitePool) -> Self { pub async fn new(data_dir: PathBuf, db: SqlitePool) -> Result<Self, Error> {
Self { let song_cache = SongCache::new(&db).await?;
let jacket_cache = JacketCache::new(&data_dir)?;
Ok(Self {
data_dir,
db, db,
song_cache: SongCache::default(), song_cache,
} jacket_cache,
})
} }
} }

122
src/jacket.rs Normal file
View file

@ -0,0 +1,122 @@
use std::path::PathBuf;
use image::{GenericImageView, Rgba};
use kd_tree::{KdMap, KdPoint};
use num::Integer;
use crate::context::Error;
/// How many sub-segments to split each side into
const SPLIT_FACTOR: u32 = 5;
const IMAGE_VEC_DIM: usize = (SPLIT_FACTOR * SPLIT_FACTOR * 3) as usize;
#[derive(Debug, Clone)]
pub struct ImageVec {
pub colors: [f32; IMAGE_VEC_DIM],
}
#[derive(Debug, Clone)]
pub struct Jacket {
pub song_id: u32,
pub path: PathBuf,
}
impl ImageVec {
// {{{ (Image => vector) encoding
fn from_image(image: &impl GenericImageView<Pixel = Rgba<u8>>) -> ImageVec {
let mut colors = [0.0; IMAGE_VEC_DIM];
let chunk_width = image.width() / SPLIT_FACTOR;
let chunk_height = image.height() / SPLIT_FACTOR;
for i in 0..(SPLIT_FACTOR * SPLIT_FACTOR) {
let (iy, ix) = i.div_rem(&SPLIT_FACTOR);
let cropped = image.view(
chunk_width * ix,
chunk_height * iy,
chunk_width,
chunk_height,
);
let mut r = 0;
let mut g = 0;
let mut b = 0;
let mut count = 0;
for (_, _, pixel) in cropped.pixels() {
r += pixel.0[0] as u64;
g += pixel.0[1] as u64;
b += pixel.0[2] as u64;
count += 1;
}
let count = count as f64;
let r = r as f64 / count;
let g = g as f64 / count;
let b = b as f64 / count;
colors[i as usize * 3 + 0] = r as f32;
colors[i as usize * 3 + 1] = g as f32;
colors[i as usize * 3 + 2] = b as f32;
}
Self { colors }
}
// }}}
}
impl KdPoint for ImageVec {
type Dim = typenum::U75;
type Scalar = f32;
fn dim() -> usize {
IMAGE_VEC_DIM
}
fn at(&self, i: usize) -> Self::Scalar {
self.colors[i]
}
}
pub struct JacketCache {
tree: KdMap<ImageVec, Jacket>,
}
impl JacketCache {
// {{{ Generate tree
pub fn new(data_dir: &PathBuf) -> Result<Self, Error> {
let jacket_csv_path = data_dir.join("jackets.csv");
let mut reader = csv::Reader::from_path(jacket_csv_path)?;
let mut entries = vec![];
for record in reader.records() {
let record = record?;
let filename = &record[0];
let song_id = u32::from_str_radix(&record[1], 10)?;
let image_path = data_dir.join(format!("jackets/{}.png", filename));
let image = image::io::Reader::open(&image_path)?.decode()?;
let jacket = Jacket {
song_id,
path: image_path,
};
entries.push((ImageVec::from_image(&image), jacket))
}
let result = Self {
tree: KdMap::build_by_ordered_float(entries),
};
Ok(result)
}
// }}}
// {{{ Recognise
#[inline]
pub fn recognise(
&self,
image: &impl GenericImageView<Pixel = Rgba<u8>>,
) -> Option<(f32, &Jacket)> {
self.tree
.nearest(&ImageVec::from_image(image))
.map(|p| (p.squared_distance.sqrt(), &p.item.1))
}
// }}}
}

View file

@ -4,15 +4,14 @@
mod chart; mod chart;
mod commands; mod commands;
mod context; mod context;
mod jacket;
mod score; mod score;
mod user; mod user;
use chart::SongCache;
use context::{Error, UserContext}; use context::{Error, UserContext};
use poise::serenity_prelude as serenity; use poise::serenity_prelude as serenity;
use score::score_to_zeta_score;
use sqlx::sqlite::SqlitePoolOptions; use sqlx::sqlite::SqlitePoolOptions;
use std::{env::var, sync::Arc, time::Duration}; use std::{env::var, path::PathBuf, str::FromStr, sync::Arc, time::Duration};
// {{{ Error handler // {{{ Error handler
async fn on_error(error: poise::FrameworkError<'_, UserContext, Error>) { async fn on_error(error: poise::FrameworkError<'_, UserContext, Error>) {
@ -40,9 +39,6 @@ async fn main() {
.await .await
.unwrap(); .unwrap();
println!("{:?}", score_to_zeta_score(9966677, 1303));
println!("{:?}", score_to_zeta_score(9970525, 1303));
// {{{ Poise options // {{{ Poise options
let options = poise::FrameworkOptions { let options = poise::FrameworkOptions {
commands: vec![commands::help(), commands::score()], commands: vec![commands::help(), commands::score()],
@ -64,8 +60,7 @@ async fn main() {
Box::pin(async move { Box::pin(async move {
println!("Logged in as {}", _ready.user.name); println!("Logged in as {}", _ready.user.name);
poise::builtins::register_globally(ctx, &framework.options().commands).await?; poise::builtins::register_globally(ctx, &framework.options().commands).await?;
let mut ctx = UserContext::new(pool); let ctx = UserContext::new(PathBuf::from_str(&data_dir)?, pool).await?;
ctx.song_cache = SongCache::new(&ctx).await?;
Ok(ctx) Ok(ctx)
}) })
}) })

View file

@ -1,13 +1,15 @@
#![allow(dead_code)] #![allow(dead_code)]
use std::{io::Cursor, sync::OnceLock, time::Instant}; use std::{fmt::Display, io::Cursor, sync::OnceLock};
use image::DynamicImage; use image::{DynamicImage, GenericImageView};
use num::Rational64; use num::Rational64;
use poise::serenity_prelude::{Attachment, AttachmentId, CreateAttachment, CreateEmbed};
use tesseract::{PageSegMode, Tesseract}; use tesseract::{PageSegMode, Tesseract};
use crate::{ use crate::{
chart::{Chart, Difficulty}, chart::{CachedSong, Chart, Difficulty, Song},
context::{Error, UserContext}, context::{Error, UserContext},
jacket::Jacket,
user::User, user::User,
}; };
@ -125,9 +127,9 @@ impl RelativeRect {
let p = (aspect_ratio - low_ratio) / (high_ratio - low_ratio); let p = (aspect_ratio - low_ratio) / (high_ratio - low_ratio);
return Some(Self::new( return Some(Self::new(
lerp(p, low.x, high.x), lerp(p, low.x, high.x),
lerp(p, low.y, high.y) - 0.005, lerp(p, low.y, high.y),
lerp(p, low.width, high.width), lerp(p, low.width, high.width),
lerp(p, low.height, high.height) + 2. * 0.005, lerp(p, low.height, high.height),
dimensions, dimensions,
)); ));
} }
@ -138,6 +140,34 @@ impl RelativeRect {
} }
// }}} // }}}
// {{{ Data points // {{{ Data points
fn process_datapoints(rects: &mut Vec<RelativeRect>) {
rects.sort_by_key(|r| (r.dimensions.aspect_ratio() * 1000.0).floor() as u32);
// Filter datapoints that are close together
let mut i = 0;
while i < rects.len() - 1 {
let low = rects[i];
let high = rects[i + 1];
if (low.dimensions.aspect_ratio() - high.dimensions.aspect_ratio()).abs() < 0.001 {
// TODO: we could interpolate here but oh well
rects.remove(i + 1);
}
i += 1;
}
}
fn widen_by(rects: &mut Vec<RelativeRect>, x: f32, y: f32) {
for rect in rects {
rect.x -= x;
rect.y -= y;
rect.width += 2. * x;
rect.height += 2. * y;
}
}
// {{{ Score
fn score_rects() -> &'static [RelativeRect] { fn score_rects() -> &'static [RelativeRect] {
static CELL: OnceLock<Vec<RelativeRect>> = OnceLock::new(); static CELL: OnceLock<Vec<RelativeRect>> = OnceLock::new();
CELL.get_or_init(|| { CELL.get_or_init(|| {
@ -152,74 +182,138 @@ fn score_rects() -> &'static [RelativeRect] {
AbsoluteRect::new(1069, 868, 636, 112, ImageDimensions::new(2732, 2048)).to_relative(), AbsoluteRect::new(1069, 868, 636, 112, ImageDimensions::new(2732, 2048)).to_relative(),
AbsoluteRect::new(1125, 510, 534, 93, ImageDimensions::new(2778, 1284)).to_relative(), AbsoluteRect::new(1125, 510, 534, 93, ImageDimensions::new(2778, 1284)).to_relative(),
]; ];
rects.sort_by_key(|r| (r.dimensions.aspect_ratio() * 1000.0).floor() as u32); process_datapoints(&mut rects);
widen_by(&mut rects, 0.0, 0.005);
// Filter datapoints that are close together
let mut i = 0;
while i < rects.len() - 1 {
let low = rects[i];
let high = rects[i + 1];
if (low.dimensions.aspect_ratio() - high.dimensions.aspect_ratio()).abs() < 0.001 {
// TODO: we could interpolate here but oh well
rects.remove(i + 1);
}
i += 1;
}
rects
})
}
fn difficulty_rects() -> &'static [RelativeRect] {
static CELL: OnceLock<Vec<RelativeRect>> = OnceLock::new();
CELL.get_or_init(|| {
let mut rects: Vec<RelativeRect> = vec![
AbsoluteRect::new(642, 287, 284, 51, ImageDimensions::new(1560, 720)).to_relative(),
AbsoluteRect::new(651, 285, 305, 55, ImageDimensions::new(1600, 720)).to_relative(),
AbsoluteRect::new(748, 485, 503, 82, ImageDimensions::new(2000, 1200)).to_relative(),
AbsoluteRect::new(841, 683, 500, 92, ImageDimensions::new(2160, 1620)).to_relative(),
AbsoluteRect::new(851, 707, 532, 91, ImageDimensions::new(2224, 1668)).to_relative(),
AbsoluteRect::new(1037, 462, 476, 89, ImageDimensions::new(2532, 1170)).to_relative(),
AbsoluteRect::new(973, 653, 620, 105, ImageDimensions::new(2560, 1600)).to_relative(),
AbsoluteRect::new(1069, 868, 636, 112, ImageDimensions::new(2732, 2048)).to_relative(),
AbsoluteRect::new(1125, 510, 534, 93, ImageDimensions::new(2778, 1284)).to_relative(),
];
rects.sort_by_key(|r| (r.dimensions.aspect_ratio() * 1000.0).floor() as u32);
rects rects
}) })
} }
// }}} // }}}
// {{{ Plays // {{{ Difficulty
/// Returns the zeta score and the number of shinies fn difficulty_rects() -> &'static [RelativeRect] {
pub fn score_to_zeta_score(score: u32, note_count: u32) -> (u32, u32) { static CELL: OnceLock<Vec<RelativeRect>> = OnceLock::new();
// Smallest possible difference between (zeta-)scores CELL.get_or_init(|| {
let increment = Rational64::new_raw(5000000, note_count as i64).reduced(); let mut rects: Vec<RelativeRect> = vec![
let zeta_increment = Rational64::new_raw(2000000, note_count as i64).reduced(); AbsoluteRect::new(232, 203, 104, 23, ImageDimensions::new(1560, 720)).to_relative(),
AbsoluteRect::new(252, 204, 99, 21, ImageDimensions::new(1600, 720)).to_relative(),
AbsoluteRect::new(146, 356, 155, 34, ImageDimensions::new(2000, 1200)).to_relative(),
AbsoluteRect::new(155, 546, 167, 38, ImageDimensions::new(2160, 1620)).to_relative(),
AbsoluteRect::new(163, 562, 175, 38, ImageDimensions::new(2224, 1668)).to_relative(),
AbsoluteRect::new(378, 332, 161, 34, ImageDimensions::new(2532, 1170)).to_relative(),
AbsoluteRect::new(183, 487, 197, 44, ImageDimensions::new(2560, 1600)).to_relative(),
AbsoluteRect::new(198, 692, 219, 46, ImageDimensions::new(2732, 2048)).to_relative(),
AbsoluteRect::new(414, 364, 177, 38, ImageDimensions::new(2778, 1284)).to_relative(),
];
process_datapoints(&mut rects);
rects
})
}
// }}}
// {{{ Jacket
fn jacket_rects() -> &'static [RelativeRect] {
static CELL: OnceLock<Vec<RelativeRect>> = OnceLock::new();
CELL.get_or_init(|| {
let mut rects: Vec<RelativeRect> = vec![
AbsoluteRect::new(171, 268, 375, 376, ImageDimensions::new(1560, 720)).to_relative(),
AbsoluteRect::new(190, 267, 376, 377, ImageDimensions::new(1600, 720)).to_relative(),
AbsoluteRect::new(46, 456, 590, 585, ImageDimensions::new(2000, 1200)).to_relative(),
AbsoluteRect::new(51, 655, 633, 632, ImageDimensions::new(2160, 1620)).to_relative(),
AbsoluteRect::new(53, 675, 654, 653, ImageDimensions::new(2224, 1668)).to_relative(),
AbsoluteRect::new(274, 434, 614, 611, ImageDimensions::new(2532, 1170)).to_relative(),
AbsoluteRect::new(58, 617, 753, 750, ImageDimensions::new(2560, 1600)).to_relative(),
AbsoluteRect::new(65, 829, 799, 800, ImageDimensions::new(2732, 2048)).to_relative(),
AbsoluteRect::new(300, 497, 670, 670, ImageDimensions::new(2778, 1284)).to_relative(),
];
process_datapoints(&mut rects);
rects
})
}
// }}}
// }}}
// {{{ Score
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Score(pub u32);
let score = Rational64::from_integer(score as i64); impl Score {
let score_units = (score / increment).floor(); // {{{ Score => ζ-Score
/// Returns the zeta score and the number of shinies
pub fn to_zeta(self, note_count: u32) -> (Score, u32) {
// Smallest possible difference between (zeta-)scores
let increment = Rational64::new_raw(5000000, note_count as i64).reduced();
let zeta_increment = Rational64::new_raw(2000000, note_count as i64).reduced();
let non_shiny_score = (score_units * increment).floor(); let score = Rational64::from_integer(self.0 as i64);
let shinies = score - non_shiny_score; let score_units = (score / increment).floor();
let zeta_score_units = Rational64::from_integer(2) * score_units + shinies; let non_shiny_score = (score_units * increment).floor();
let zeta_score = (zeta_increment * zeta_score_units).floor().to_integer() as u32; let shinies = score - non_shiny_score;
(zeta_score, shinies.to_integer() as u32) let zeta_score_units = Rational64::from_integer(2) * score_units + shinies;
let zeta_score = Score((zeta_increment * zeta_score_units).floor().to_integer() as u32);
(zeta_score, shinies.to_integer() as u32)
}
// }}}
// {{{ Score => Play rating
#[inline]
pub fn play_rating(self, chart_constant: u32) -> i32 {
chart_constant as i32
+ if self.0 >= 10000000 {
200
} else if self.0 >= 9800000 {
100 + (self.0 as i32 - 9_800_000) / 20_000
} else {
(self.0 as i32 - 9_500_000) / 10_000
}
}
// }}}
// {{{ Score => grade
#[inline]
// TODO: Perhaps make an enum for this
pub fn grade(self) -> &'static str {
let score = self.0;
if score > 9900000 {
"EX+"
} else if score > 9800000 {
"EX"
} else if score > 9500000 {
"AA"
} else if score > 9200000 {
"A"
} else if score > 8900000 {
"B"
} else if score > 8600000 {
"C"
} else {
"D"
}
}
// }}}
} }
impl Display for Score {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let score = self.0;
write!(
f,
"{}'{}'{}",
score / 1000000,
(score / 1000) % 1000,
score % 1000
)
}
}
// }}}
// {{{ Plays
// {{{ Create play // {{{ Create play
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct CreatePlay { pub struct CreatePlay {
chart_id: u32, chart_id: u32,
user_id: u32, user_id: u32,
discord_attachment_id: Option<String>, discord_attachment_id: Option<AttachmentId>,
// Actual score data // Actual score data
score: u32, score: Score,
zeta_score: Option<u32>, zeta_score: Score,
// Optional score details // Optional score details
max_recall: Option<u32>, max_recall: Option<u32>,
@ -232,13 +326,13 @@ pub struct CreatePlay {
impl CreatePlay { impl CreatePlay {
#[inline] #[inline]
pub fn new(score: u32, chart: Chart, user: User) -> Self { pub fn new(score: Score, chart: &Chart, user: &User) -> Self {
Self { Self {
chart_id: chart.id, chart_id: chart.id,
user_id: user.id, user_id: user.id,
discord_attachment_id: None, discord_attachment_id: None,
score, score,
zeta_score: Some(score_to_zeta_score(score, chart.note_count).0), zeta_score: score.to_zeta(chart.note_count as u32).0,
max_recall: None, max_recall: None,
far_notes: None, far_notes: None,
// TODO: populate these // TODO: populate these
@ -247,52 +341,121 @@ impl CreatePlay {
} }
} }
#[inline]
pub fn with_attachment(mut self, attachment: &Attachment) -> Self {
self.discord_attachment_id = Some(attachment.id);
self
}
// {{{ Save
pub async fn save(self, ctx: &UserContext) -> Result<Play, Error> { pub async fn save(self, ctx: &UserContext) -> Result<Play, Error> {
let play = sqlx::query_as!( let attachment_id = self.discord_attachment_id.map(|i| i.get() as i64);
Play, let play = sqlx::query!(
" "
INSERT INTO plays( INSERT INTO plays(
user_id,chart_id,discord_attachment_id, user_id,chart_id,discord_attachment_id,
score,zeta_score,max_recall,far_notes score,zeta_score,max_recall,far_notes
) )
VALUES(?,?,?,?,?,?,?) VALUES(?,?,?,?,?,?,?)
RETURNING * RETURNING id, created_at
", ",
self.user_id, self.user_id,
self.chart_id, self.chart_id,
self.discord_attachment_id, attachment_id,
self.score, self.score.0,
self.zeta_score, self.zeta_score.0,
self.max_recall, self.max_recall,
self.far_notes self.far_notes
) )
.fetch_one(&ctx.db) .fetch_one(&ctx.db)
.await?; .await?;
Ok(play) Ok(Play {
id: play.id as u32,
created_at: play.created_at,
chart_id: self.chart_id,
user_id: self.user_id,
discord_attachment_id: self.discord_attachment_id,
score: self.score,
zeta_score: self.zeta_score,
max_recall: self.max_recall,
far_notes: self.far_notes,
creation_ptt: self.creation_ptt,
creation_zeta_ptt: self.creation_zeta_ptt,
})
} }
// }}}
} }
// }}} // }}}
// {{{ Play // {{{ Play
#[derive(Debug, Clone, sqlx::FromRow)] #[derive(Debug, Clone, sqlx::FromRow)]
pub struct Play { pub struct Play {
id: i64, id: u32,
chart_id: i64, chart_id: u32,
user_id: i64, user_id: u32,
discord_attachment_id: Option<String>, discord_attachment_id: Option<AttachmentId>,
// Actual score data // Actual score data
score: i64, score: Score,
zeta_score: Option<i64>, zeta_score: Score,
// Optional score details // Optional score details
max_recall: Option<i64>, max_recall: Option<u32>,
far_notes: Option<i64>, far_notes: Option<u32>,
// Creation data // Creation data
created_at: chrono::NaiveDateTime, created_at: chrono::NaiveDateTime,
creation_ptt: Option<i64>, creation_ptt: Option<u32>,
creation_zeta_ptt: Option<i64>, creation_zeta_ptt: Option<u32>,
}
impl Play {
// {{{ Play to embed
pub async fn to_embed(
&self,
song: &Song,
chart: &Chart,
jacket: &Jacket,
) -> Result<(CreateEmbed, CreateAttachment), Error> {
let (_, shiny_count) = self.score.to_zeta(chart.note_count);
let attachement_name = format!("{:?}-{:?}.png", song.id, self.score.0);
let icon_attachement = CreateAttachment::file(
&tokio::fs::File::open(&jacket.path).await?,
&attachement_name,
)
.await?;
let embed = CreateEmbed::default()
.title(&song.title)
.thumbnail(format!("attachment://{}", &attachement_name))
.field("Score", format!("{} (+?)", self.score), true)
.field(
"Rating",
format!(
"{:.2} (+?)",
(self.score.play_rating(chart.chart_constant)) as f32 / 100.
),
true,
)
.field("Grade", self.score.grade(), true)
.field("ζ-Score", format!("{} (+?)", self.zeta_score), true)
.field(
"ζ-Rating",
format!(
"{:.2} (+?)",
(self.zeta_score.play_rating(chart.chart_constant)) as f32 / 100.
),
true,
)
.field("ζ-Grade", self.zeta_score.grade(), true)
.field("Status", "?", true)
.field("Max recall", "?", true)
.field("Breakdown", format!("{}/?/?/?", shiny_count), true);
Ok((embed, icon_attachement))
}
// }}}
} }
// }}} // }}}
// {{{ Tests // {{{ Tests
@ -305,12 +468,13 @@ mod score_tests {
// note counts // note counts
for note_count in 200..=2000 { for note_count in 200..=2000 {
for shiny_count in 0..=note_count { for shiny_count in 0..=note_count {
let score = 10000000 + shiny_count; let score = Score(10000000 + shiny_count);
let zeta_score_units = 4 * (note_count - shiny_count) + 5 * shiny_count; let zeta_score_units = 4 * (note_count - shiny_count) + 5 * shiny_count;
let (zeta_score, computed_shiny_count) = score.to_zeta(note_count);
let expected_zeta_score = Rational64::from_integer(zeta_score_units as i64) let expected_zeta_score = Rational64::from_integer(zeta_score_units as i64)
* Rational64::new_raw(2000000, note_count as i64).reduced(); * Rational64::new_raw(2000000, note_count as i64).reduced();
let (zeta_score, computed_shiny_count) = score_to_zeta_score(score, note_count);
assert_eq!(zeta_score, expected_zeta_score.to_integer() as u32); assert_eq!(zeta_score, Score(expected_zeta_score.to_integer() as u32));
assert_eq!(computed_shiny_count, shiny_count); assert_eq!(computed_shiny_count, shiny_count);
} }
} }
@ -318,19 +482,6 @@ mod score_tests {
} }
// }}} // }}}
// }}} // }}}
// {{{ Ocr types
#[derive(Debug, Clone, Copy)]
pub struct ScoreReadout {
pub score: u32,
pub difficulty: Difficulty,
}
impl ScoreReadout {
pub fn new(score: u32, difficulty: Difficulty) -> Self {
Self { score, difficulty }
}
}
// }}}
// {{{ Run OCR // {{{ Run OCR
/// Caches a byte vector in order to prevent reallocation /// Caches a byte vector in order to prevent reallocation
#[derive(Debug, Clone, Default)] #[derive(Debug, Clone, Default)]
@ -352,18 +503,19 @@ impl ImageCropper {
Ok(()) Ok(())
} }
pub fn read_score(&mut self, image: &DynamicImage) -> Result<ScoreReadout, Error> { // {{{ Read score
let rect = pub fn read_score(&mut self, image: &DynamicImage) -> Result<Score, Error> {
self.crop_image_to_bytes(
&image,
RelativeRect::from_aspect_ratio(ImageDimensions::from_image(image), score_rects()) RelativeRect::from_aspect_ratio(ImageDimensions::from_image(image), score_rects())
.ok_or_else(|| "Could not find score area in picture")? .ok_or_else(|| "Could not find score area in picture")?
.to_absolute(); .to_absolute(),
self.crop_image_to_bytes(&image, rect)?; )?;
let mut t = Tesseract::new(None, Some("eng"))? let mut t = Tesseract::new(None, Some("eng"))?
// .set_variable("classify_bln_numeric_mode", "1'")? // .set_variable("classify_bln_numeric_mode", "1'")?
.set_variable("tessedit_char_whitelist", "0123456789'")? .set_variable("tessedit_char_whitelist", "0123456789'")?
.set_image_from_mem(&self.bytes)?; .set_image_from_mem(&self.bytes)?;
t.set_page_seg_mode(PageSegMode::PsmRawLine); t.set_page_seg_mode(PageSegMode::PsmRawLine);
t = t.recognize()?; t = t.recognize()?;
@ -378,8 +530,72 @@ impl ImageCropper {
.filter(|char| *char != ' ' && *char != '\'') .filter(|char| *char != ' ' && *char != '\'')
.collect(); .collect();
let int = u32::from_str_radix(&text, 10)?; let score = u32::from_str_radix(&text, 10)?;
Ok(ScoreReadout::new(int, Difficulty::FTR)) Ok(Score(score))
} }
// }}}
// {{{ Read difficulty
pub fn read_difficulty(&mut self, image: &DynamicImage) -> Result<Difficulty, Error> {
self.crop_image_to_bytes(
&image,
RelativeRect::from_aspect_ratio(ImageDimensions::from_image(image), difficulty_rects())
.ok_or_else(|| "Could not find difficulty area in picture")?
.to_absolute(),
)?;
let mut t = Tesseract::new(None, Some("eng"))?.set_image_from_mem(&self.bytes)?;
t.set_page_seg_mode(PageSegMode::PsmRawLine);
t = t.recognize()?;
if t.mean_text_conf() < 10 {
Err("Difficulty text is not readable.")?;
}
let text: &str = &t.get_text()?;
let text = text.trim();
let difficulty = Difficulty::DIFFICULTIES
.iter()
.zip(Difficulty::DIFFICULTY_STRINGS)
.min_by_key(|(_, difficulty_string)| {
edit_distance::edit_distance(difficulty_string, text)
})
.map(|(difficulty, _)| *difficulty)
.ok_or_else(|| format!("Unrecognised difficulty '{}'", text))?;
Ok(difficulty)
}
// }}}
// {{{ Read jacket
pub fn read_jacket<'a>(
&mut self,
ctx: &'a UserContext,
image: &DynamicImage,
) -> Result<(&'a Jacket, &'a CachedSong), Error> {
let rect =
RelativeRect::from_aspect_ratio(ImageDimensions::from_image(image), jacket_rects())
.ok_or_else(|| "Could not find jacket area in picture")?
.to_absolute();
let cropped = image.view(rect.x, rect.y, rect.width, rect.height);
let (distance, jacket) = ctx
.jacket_cache
.recognise(&*cropped)
.ok_or_else(|| "Could not recognise jacket")?;
if distance > 100.0 {
// Save image to be sent to discord
self.crop_image_to_bytes(&image, rect)?;
Err("No known jacket looks like this")?;
}
let song = ctx
.song_cache
.lookup(jacket.song_id)
.ok_or_else(|| format!("Could not find song with id {}", jacket.song_id))?;
Ok((jacket, song))
}
// }}}
} }
// }}} // }}}

View file

@ -1,6 +1,6 @@
use crate::context::{Context, Error}; use crate::context::{Context, Error};
#[derive(Debug, Clone, sqlx::FromRow)] #[derive(Debug, Clone)]
pub struct User { pub struct User {
pub id: u32, pub id: u32,
pub discord_id: String, pub discord_id: String,
@ -10,11 +10,14 @@ pub struct User {
impl User { impl User {
pub async fn from_context(ctx: &Context<'_>) -> Result<Self, Error> { pub async fn from_context(ctx: &Context<'_>) -> Result<Self, Error> {
let id = ctx.author().id.get().to_string(); let id = ctx.author().id.get().to_string();
let user = sqlx::query_as("SELECT * FROM users WHERE discord_id = ?") let user = sqlx::query!("SELECT * FROM users WHERE discord_id = ?", id)
.bind(id)
.fetch_one(&ctx.data().db) .fetch_one(&ctx.data().db)
.await?; .await?;
Ok(user) Ok(User {
id: user.id as u32,
discord_id: user.discord_id,
nickname: user.nickname,
})
} }
} }