1
Fork 0

Implement migrations, and switch from sqlx to rusqlite

This commit is contained in:
prescientmoon 2024-08-22 22:11:21 +02:00
parent 7cdc3a2755
commit fee7fe77f8
Signed by: prescientmoon
SSH key fingerprint: SHA256:UUF9JT2s8Xfyv76b8ZuVL7XrmimH4o49p4b+iexbVH4
17 changed files with 629 additions and 1044 deletions

View file

@ -1,5 +1,4 @@
use image::RgbaImage;
use sqlx::query;
use crate::{
assets::get_data_dir,
@ -122,16 +121,8 @@ impl GoalStats {
user: &User,
scoring_system: ScoringSystem,
) -> Result<Self, Error> {
let plays = get_best_plays(
&ctx.db,
&ctx.song_cache,
user.id,
scoring_system,
0,
usize::MAX,
None,
)
.await??;
let plays = get_best_plays(ctx, user.id, scoring_system, 0, usize::MAX, None)??;
let conn = ctx.db.get()?;
// {{{ PM count
let pm_count = plays
@ -142,30 +133,31 @@ impl GoalStats {
.count();
// }}}
// {{{ Play count
let play_count = query!(
"SELECT count() as count FROM plays WHERE user_id=?",
user.id
)
.fetch_one(&ctx.db)
.await?
.count as usize;
let play_count = conn
.prepare_cached("SELECT count() as count FROM plays WHERE user_id=?")?
.query_row([user.id], |row| row.get(0))?;
// }}}
// {{{ Peak ptt
let peak_ptt = query!(
"
let peak_ptt = conn
.prepare_cached(
"
SELECT s.creation_ptt
FROM plays p
JOIN scores s ON s.play_id = p.id
WHERE user_id = ?
AND scoring_system = ?
ORDER BY s.creation_ptt DESC
LIMIT 1
",
user.id,
ScoringSystem::SCORING_SYSTEM_DB_STRINGS[scoring_system.to_index()]
)
.fetch_one(&ctx.db)
.await?
.creation_ptt
.ok_or_else(|| "No ptt history data found")? as u32;
)?
.query_row(
(
user.id,
ScoringSystem::SCORING_SYSTEM_DB_STRINGS[scoring_system.to_index()],
),
|row| row.get(0),
)
.map_err(|_| "No ptt history data found")?;
// }}}
// {{{ Peak PM relay
let peak_pm_relay = {

View file

@ -1,12 +1,15 @@
use std::{fmt::Display, num::NonZeroU16, path::PathBuf};
use image::{ImageBuffer, Rgb};
use sqlx::SqlitePool;
use rusqlite::types::{FromSql, FromSqlError, FromSqlResult, ValueRef};
use crate::{bitmap::Color, context::Error};
use crate::{
bitmap::Color,
context::{DbConnection, Error},
};
// {{{ Difficuly
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, sqlx::Type)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum Difficulty {
PST,
PRS,
@ -29,17 +32,19 @@ impl Difficulty {
}
}
impl TryFrom<String> for Difficulty {
type Error = String;
impl FromSql for Difficulty {
fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> {
let str: String = rusqlite::types::FromSql::column_result(value)?;
fn try_from(value: String) -> Result<Self, Self::Error> {
for (i, s) in Self::DIFFICULTY_SHORTHANDS.iter().enumerate() {
if value == **s {
if str == **s {
return Ok(Self::DIFFICULTIES[i]);
}
}
Err(format!("Cannot convert {} to difficulty", value))
FromSqlResult::Err(FromSqlError::Other(
format!("Cannot convert {} to difficulty", str).into(),
))
}
}
@ -120,17 +125,19 @@ impl Display for Level {
}
}
impl TryFrom<String> for Level {
type Error = String;
impl FromSql for Level {
fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> {
let str: String = rusqlite::types::FromSql::column_result(value)?;
fn try_from(value: String) -> Result<Self, Self::Error> {
for (i, s) in Self::LEVEL_STRINGS.iter().enumerate() {
if value == **s {
if str == **s {
return Ok(Self::LEVELS[i]);
}
}
Err(format!("Cannot convert {} to a level", value))
FromSqlResult::Err(FromSqlError::Other(
format!("Cannot convert {} to level", str).into(),
))
}
}
// }}}
@ -152,17 +159,19 @@ impl Side {
}
}
impl TryFrom<String> for Side {
type Error = String;
impl FromSql for Side {
fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> {
let str: String = rusqlite::types::FromSql::column_result(value)?;
fn try_from(value: String) -> Result<Self, Self::Error> {
for (i, s) in Self::SIDE_STRINGS.iter().enumerate() {
if value == **s {
if str == **s {
return Ok(Self::SIDES[i]);
}
}
Err(format!("Cannot convert {} to difficulty", value))
FromSqlResult::Err(FromSqlError::Other(
format!("Cannot convert {} to side", str).into(),
))
}
}
// }}}
@ -304,44 +313,53 @@ impl SongCache {
}
// {{{ Populate cache
pub async fn new(pool: &SqlitePool) -> Result<Self, Error> {
pub fn new(conn: &DbConnection) -> Result<Self, Error> {
let conn = conn.get()?;
let mut result = Self::default();
// {{{ Songs
let songs = sqlx::query!("SELECT * FROM songs").fetch_all(pool).await?;
for song in songs {
let song = Song {
id: song.id as u32,
lowercase_title: song.title.to_lowercase(),
title: song.title,
artist: song.artist,
pack: song.pack,
bpm: song.bpm,
side: Side::try_from(song.side)?,
};
let mut query = conn.prepare_cached("SELECT * FROM songs")?;
let songs = query.query_map((), |row| {
Ok(Song {
id: row.get("id")?,
lowercase_title: row.get::<_, String>("title")?.to_lowercase(),
title: row.get("title")?,
artist: row.get("artist")?,
pack: row.get("pack")?,
bpm: row.get("bpm")?,
side: row.get("side")?,
})
})?;
for song in songs {
let song = song?;
let song_id = song.id as usize;
if song_id >= result.songs.len() {
result.songs.resize(song_id + 1, None);
}
result.songs[song_id] = Some(CachedSong::new(song));
}
// }}}
// {{{ Charts
let charts = sqlx::query!("SELECT * FROM charts").fetch_all(pool).await?;
for chart in charts {
let chart = Chart {
id: chart.id as u32,
song_id: chart.song_id as u32,
shorthand: chart.shorthand,
difficulty: Difficulty::try_from(chart.difficulty)?,
level: Level::try_from(chart.level)?,
chart_constant: chart.chart_constant as u32,
note_count: chart.note_count as u32,
let mut query = conn.prepare_cached("SELECT * FROM charts")?;
let charts = query.query_map((), |row| {
Ok(Chart {
id: row.get("id")?,
song_id: row.get("song_id")?,
shorthand: row.get("shorthand")?,
difficulty: row.get("difficulty")?,
level: row.get("level")?,
chart_constant: row.get("chart_constant")?,
note_count: row.get("note_count")?,
note_design: row.get("note_design")?,
cached_jacket: None,
note_design: chart.note_design,
};
})
})?;
for chart in charts {
let chart = chart?;
// {{{ Tie chart to song
{

View file

@ -9,14 +9,12 @@ use num::Zero;
use poise::serenity_prelude::{
Attachment, AttachmentId, CreateAttachment, CreateEmbed, CreateEmbedAuthor, Timestamp,
};
use sqlx::query_as;
use sqlx::{query, SqlitePool};
use rusqlite::Row;
use crate::arcaea::chart::{Chart, Song};
use crate::context::{Error, UserContext};
use crate::user::User;
use super::chart::SongCache;
use super::rating::{rating_as_fixed, rating_as_float};
use super::score::{Score, ScoringSystem};
@ -61,12 +59,14 @@ impl CreatePlay {
}
// {{{ Save
pub async fn save(self, ctx: &UserContext, user: &User, chart: &Chart) -> Result<Play, Error> {
pub fn save(self, ctx: &UserContext, user: &User, chart: &Chart) -> Result<Play, Error> {
let conn = ctx.db.get()?;
let attachment_id = self.discord_attachment_id.map(|i| i.get() as i64);
// {{{ Save current data to play
let play = sqlx::query!(
"
let (id, created_at) = conn
.prepare_cached(
"
INSERT INTO plays(
user_id,chart_id,discord_attachment_id,
max_recall,far_notes
@ -74,88 +74,55 @@ impl CreatePlay {
VALUES(?,?,?,?,?)
RETURNING id, created_at
",
user.id,
chart.id,
attachment_id,
self.max_recall,
self.far_notes
)
.fetch_one(&ctx.db)
.await?;
)?
.query_row(
(
user.id,
chart.id,
attachment_id,
self.max_recall,
self.far_notes,
),
|row| Ok((row.get("id")?, row.get("created_at")?)),
)?;
// }}}
// {{{ Update creation ptt data
let scores = ScoreCollection::from_standard_score(self.score, chart);
for system in ScoringSystem::SCORING_SYSTEMS {
let i = system.to_index();
let plays = get_best_plays(&ctx.db, &ctx.song_cache, user.id, system, 30, 30, None)
.await?
.ok();
let plays = get_best_plays(ctx, user.id, system, 30, 30, None)?.ok();
let creation_ptt: Option<_> = try { rating_as_fixed(compute_b30_ptt(system, &plays?)) };
query!(
conn.prepare_cached(
"
INSERT INTO scores(play_id, score, creation_ptt, scoring_system)
VALUES (?,?,?,?)
",
play.id,
)?
.execute((
id,
scores.0[i].0,
creation_ptt,
ScoringSystem::SCORING_SYSTEM_DB_STRINGS[i]
)
.execute(&ctx.db)
.await?;
ScoringSystem::SCORING_SYSTEM_DB_STRINGS[i],
))?;
}
// }}}
Ok(Play {
id: play.id as u32,
created_at: play.created_at,
id,
created_at,
scores,
chart_id: chart.id,
user_id: user.id,
scores,
max_recall: self.max_recall,
far_notes: self.far_notes,
})
}
// }}}
}
// }}}
// {{{ DbPlay
/// Construct a `Play` from a sqlite return record.
#[macro_export]
macro_rules! play_from_db_record {
($chart:expr, $record:expr) => {{
use crate::arcaea::play::{Play, ScoreCollection};
use crate::arcaea::score::Score;
Play {
id: $record.id as u32,
chart_id: $record.chart_id as u32,
user_id: $record.user_id as u32,
scores: ScoreCollection::from_standard_score(Score($record.score as u32), $chart),
max_recall: $record.max_recall.map(|r| r as u32),
far_notes: $record.far_notes.map(|r| r as u32),
created_at: $record.created_at,
}
}};
}
/// Typed version of the input to the macro above.
/// Useful when using the non-macro version of the sqlx functions.
#[derive(Debug, sqlx::FromRow)]
pub struct DbPlay {
pub id: i64,
pub chart_id: i64,
pub user_id: i64,
pub created_at: chrono::NaiveDateTime,
// Score details
pub max_recall: Option<i64>,
pub far_notes: Option<i64>,
pub score: i64,
}
// }}}
// {{{ Score data
#[derive(Debug, Clone, Copy)]
@ -185,6 +152,20 @@ pub struct Play {
}
impl Play {
// {{{ Row parsing
#[inline]
pub fn from_sql(chart: &Chart, row: &Row) -> Result<Self, rusqlite::Error> {
Ok(Play {
id: row.get("id")?,
chart_id: row.get("chart_id")?,
user_id: row.get("user_id")?,
created_at: row.get("created_at")?,
max_recall: row.get("max_recall")?,
far_notes: row.get("far_notes")?,
scores: ScoreCollection::from_standard_score(Score(row.get("score")?), chart),
})
}
// }}}
// {{{ Query the underlying score
#[inline]
pub fn score(&self, system: ScoringSystem) -> Score {
@ -272,9 +253,9 @@ impl Play {
/// Creates a discord embed for this play.
///
/// The `index` variable is only used to create distinct filenames.
pub async fn to_embed(
pub fn to_embed(
&self,
db: &SqlitePool,
ctx: &UserContext,
user: &User,
song: &Song,
chart: &Chart,
@ -282,33 +263,28 @@ impl Play {
author: Option<&poise::serenity_prelude::User>,
) -> Result<(CreateEmbed, Option<CreateAttachment>), Error> {
// {{{ Get previously best score
let prev_play = query!(
"
SELECT
p.id, p.chart_id, p.user_id, p.created_at,
p.max_recall, p.far_notes, s.score
FROM plays p
JOIN scores s ON s.play_id = p.id
WHERE s.scoring_system='standard'
AND p.user_id=?
AND p.chart_id=?
AND p.created_at<?
ORDER BY s.score DESC
LIMIT 1
",
user.id,
chart.id,
self.created_at
)
.fetch_optional(db)
.await
.map_err(|_| {
format!(
"Could not find any scores for {} [{:?}]",
song.title, chart.difficulty
)
})?
.map(|p| play_from_db_record!(chart, p));
let prev_play = ctx
.db
.get()?
.prepare_cached(
"
SELECT
p.id, p.chart_id, p.user_id, p.created_at,
p.max_recall, p.far_notes, s.score
FROM plays p
JOIN scores s ON s.play_id = p.id
WHERE s.scoring_system='standard'
AND p.user_id=?
AND p.chart_id=?
AND p.created_at<?
ORDER BY s.score DESC
LIMIT 1
",
)?
.query_row((user.id, chart.id, self.created_at), |row| {
Self::from_sql(chart, row)
})
.ok();
let prev_score = prev_play.as_ref().map(|p| p.score(ScoringSystem::Standard));
let prev_zeta_score = prev_play.as_ref().map(|p| p.score(ScoringSystem::EX));
@ -408,38 +384,47 @@ impl Play {
// {{{ General functions
pub type PlayCollection<'a> = Vec<(Play, &'a Song, &'a Chart)>;
pub async fn get_best_plays<'a>(
db: &SqlitePool,
song_cache: &'a SongCache,
pub fn get_best_plays<'a>(
ctx: &'a UserContext,
user_id: u32,
scoring_system: ScoringSystem,
min_amount: usize,
max_amount: usize,
before: Option<NaiveDateTime>,
) -> Result<Result<PlayCollection<'a>, String>, Error> {
let conn = ctx.db.get()?;
// {{{ DB data fetching
let plays: Vec<DbPlay> = query_as(
"
SELECT
p.id, p.chart_id, p.user_id, p.created_at,
p.max_recall, p.far_notes, s.score,
MAX(s.score) as _cscore
-- ^ This is only here to make sqlite pick the correct row for the bare columns
FROM plays p
JOIN scores s ON s.play_id = p.id
JOIN scores cs ON cs.play_id = p.id
WHERE s.scoring_system='standard'
AND cs.scoring_system=?
AND p.user_id=?
AND p.created_at<=?
GROUP BY p.chart_id
",
)
.bind(ScoringSystem::SCORING_SYSTEM_DB_STRINGS[scoring_system.to_index()])
.bind(user_id)
.bind(before.unwrap_or_else(|| Utc::now().naive_utc()))
.fetch_all(db)
.await?;
let mut plays = conn
.prepare_cached(
"
SELECT
p.id, p.chart_id, p.user_id, p.created_at,
p.max_recall, p.far_notes, s.score,
MAX(cs.score) as _cscore
-- ^ This is only here to make sqlite pick the correct row for the bare columns
FROM plays p
JOIN scores s ON s.play_id = p.id
JOIN scores cs ON cs.play_id = p.id
WHERE s.scoring_system='standard'
AND cs.scoring_system=?
AND p.user_id=?
AND p.created_at<=?
GROUP BY p.chart_id
",
)?
.query_and_then(
(
ScoringSystem::SCORING_SYSTEM_DB_STRINGS[scoring_system.to_index()],
user_id,
before.unwrap_or_else(|| Utc::now().naive_utc()),
),
|row| {
let (song, chart) = ctx.song_cache.lookup_chart(row.get("chart_id")?)?;
let play = Play::from_sql(chart, row)?;
Ok((play, song, chart))
},
)?
.collect::<Result<Vec<_>, Error>>()?;
// }}}
if plays.len() < min_amount {
@ -450,17 +435,6 @@ pub async fn get_best_plays<'a>(
}
// {{{ B30 computation
// NOTE: we reallocate here, although we do not have much of a choice,
// unless we want to be lazy about things
let mut plays: Vec<(Play, &Song, &Chart)> = plays
.into_iter()
.map(|play| {
let (song, chart) = song_cache.lookup_chart(play.chart_id as u32)?;
let play = play_from_db_record!(chart, play);
Ok((play, song, chart))
})
.collect::<Result<Vec<_>, Error>>()?;
plays.sort_by_key(|(play, _, chart)| -play.play_rating(scoring_system, chart.chart_constant));
plays.truncate(max_amount);
// }}}
@ -480,7 +454,8 @@ pub fn compute_b30_ptt(scoring_system: ScoringSystem, plays: &PlayCollection<'_>
// }}}
// {{{ Maintenance functions
pub async fn generate_missing_scores(ctx: &UserContext) -> Result<(), Error> {
let plays = query!(
let conn = ctx.db.get()?;
let mut query = conn.prepare_cached(
"
SELECT
p.id, p.chart_id, p.user_id, p.created_at,
@ -489,53 +464,44 @@ pub async fn generate_missing_scores(ctx: &UserContext) -> Result<(), Error> {
JOIN scores s ON s.play_id = p.id
WHERE s.scoring_system='standard'
ORDER BY p.created_at ASC
"
)
// Can't use the stream based version because of db locking...
.fetch_all(&ctx.db)
.await?;
",
)?;
let plays = query.query_and_then((), |row| -> Result<_, Error> {
let (_, chart) = ctx.song_cache.lookup_chart(row.get("chart_id")?)?;
let play = Play::from_sql(chart, row)?;
Ok(play)
})?;
let mut i = 0;
for play in plays {
let (_, chart) = ctx.song_cache.lookup_chart(play.chart_id as u32)?;
let play = play_from_db_record!(chart, play);
let play = play?;
for system in ScoringSystem::SCORING_SYSTEMS {
let i = system.to_index();
let plays = get_best_plays(
&ctx.db,
&ctx.song_cache,
play.user_id,
system,
30,
30,
Some(play.created_at),
)
.await?
.ok();
let plays =
get_best_plays(&ctx, play.user_id, system, 30, 30, Some(play.created_at))?.ok();
let creation_ptt: Option<_> = try { rating_as_fixed(compute_b30_ptt(system, &plays?)) };
let raw_score = play.scores.0[i].0;
query!(
conn.prepare_cached(
"
INSERT INTO scores(play_id, score, creation_ptt, scoring_system)
VALUES ($1, $2, $3, $4)
ON CONFLICT(play_id, scoring_system)
DO UPDATE SET
score=$2, creation_ptt=$3
WHERE play_id = $1
AND scoring_system = $4
INSERT INTO scores(play_id, score, creation_ptt, scoring_system)
VALUES ($1, $2, $3, $4)
ON CONFLICT(play_id, scoring_system)
DO UPDATE SET
score=$2, creation_ptt=$3
WHERE play_id = $1
AND scoring_system = $4
",
)?
.execute((
play.id,
raw_score,
creation_ptt,
ScoringSystem::SCORING_SYSTEM_DB_STRINGS[i],
)
.execute(&ctx.db)
.await?;
))?;
}
i += 1;

View file

@ -1,10 +1,9 @@
use poise::serenity_prelude::{CreateAttachment, CreateEmbed, CreateMessage};
use sqlx::query;
use crate::{
arcaea::chart::Side,
arcaea::{chart::Side, play::Play},
context::{Context, Error},
get_user, play_from_db_record,
get_user,
recognition::fuzzy_song_name::guess_song_and_chart,
};
use std::io::Cursor;
@ -23,7 +22,7 @@ use poise::CreateReply;
use crate::{
arcaea::score::{Score, ScoringSystem},
user::discord_it_to_discord_user,
user::discord_id_to_discord_user,
};
// {{{ Top command
@ -55,16 +54,18 @@ async fn info(
None => None,
};
let play_count = query!(
"
let play_count: usize = ctx
.data()
.db
.get()?
.prepare_cached(
"
SELECT COUNT(*) as count
FROM plays
WHERE chart_id=?
",
chart.id
)
.fetch_one(&ctx.data().db)
.await?;
",
)?
.query_row([chart.id], |row| row.get(0))?;
let mut embed = CreateEmbed::default()
.title(format!(
@ -77,7 +78,7 @@ async fn info(
format!("{:.1}", chart.chart_constant as f32 / 100.0),
true,
)
.field("Total plays", format!("{}", play_count.count), true)
.field("Total plays", format!("{play_count}"), true)
.field("BPM", &song.bpm, true)
.field("Side", Side::SIDE_STRINGS[song.side.to_index()], true)
.field("Artist", &song.title, true);
@ -117,42 +118,40 @@ async fn best(
let user = get_user!(&ctx);
let (song, chart) = guess_song_and_chart(&ctx.data(), &name)?;
let play = query!(
"
SELECT
let play = ctx
.data()
.db
.get()?
.prepare_cached(
"
SELECT
p.id, p.chart_id, p.user_id, p.created_at,
p.max_recall, p.far_notes, s.score
FROM plays p
JOIN scores s ON s.play_id = p.id
WHERE s.scoring_system='standard'
AND p.user_id=?
AND p.chart_id=?
ORDER BY s.score DESC
LIMIT 1
",
user.id,
chart.id
)
.fetch_one(&ctx.data().db)
.await
.map_err(|_| {
format!(
"Could not find any scores for {} [{:?}]",
song.title, chart.difficulty
)
})?;
let play = play_from_db_record!(chart, play);
FROM plays p
JOIN scores s ON s.play_id = p.id
WHERE s.scoring_system='standard'
AND p.user_id=?
AND p.chart_id=?
ORDER BY s.score DESC
LIMIT 1
",
)?
.query_row((user.id, chart.id), |row| Play::from_sql(chart, row))
.map_err(|_| {
format!(
"Could not find any scores for {} [{:?}]",
song.title, chart.difficulty
)
})?;
let (embed, attachment) = play
.to_embed(
&ctx.data().db,
&user,
&song,
&chart,
0,
Some(&discord_it_to_discord_user(&ctx, &user.discord_id).await?),
)
.await?;
let (embed, attachment) = play.to_embed(
ctx.data(),
&user,
song,
chart,
0,
Some(&discord_id_to_discord_user(&ctx, &user.discord_id).await?),
)?;
ctx.channel_id()
.send_files(ctx.http(), attachment, CreateMessage::new().embed(embed))
@ -177,8 +176,12 @@ async fn plot(
let (song, chart) = guess_song_and_chart(&ctx.data(), &name)?;
// SAFETY: we limit the amount of plotted plays to 1000.
let plays = query!(
"
let plays = ctx
.data()
.db
.get()?
.prepare_cached(
"
SELECT
p.id, p.chart_id, p.user_id, p.created_at,
p.max_recall, p.far_notes, s.score
@ -190,11 +193,9 @@ async fn plot(
ORDER BY s.score DESC
LIMIT 1000
",
user.id,
chart.id
)
.fetch_all(&ctx.data().db)
.await?;
)?
.query_map((user.id, chart.id), |row| Play::from_sql(chart, row))?
.collect::<Result<Vec<_>, _>>()?;
if plays.len() == 0 {
ctx.reply(format!(
@ -209,7 +210,7 @@ async fn plot(
let max_time = plays.iter().map(|p| p.created_at).max().unwrap();
let mut min_score = plays
.iter()
.map(|p| play_from_db_record!(chart, p).score(scoring_system))
.map(|p| p.score(scoring_system))
.min()
.unwrap()
.0 as i64;
@ -266,7 +267,7 @@ async fn plot(
.map(|play| {
(
play.created_at.and_utc().timestamp_millis(),
play_from_db_record!(chart, play).score(scoring_system),
play.score(scoring_system),
)
})
.collect();

View file

@ -1,16 +1,15 @@
use std::time::Instant;
use crate::arcaea::play::CreatePlay;
use crate::arcaea::play::{CreatePlay, Play};
use crate::arcaea::score::Score;
use crate::context::{Context, Error};
use crate::recognition::recognize::{ImageAnalyzer, ScoreKind};
use crate::user::{discord_it_to_discord_user, User};
use crate::{edit_reply, get_user, play_from_db_record, timed};
use crate::user::{discord_id_to_discord_user, User};
use crate::{edit_reply, get_user, timed};
use image::DynamicImage;
use poise::serenity_prelude::futures::future::join_all;
use poise::serenity_prelude::CreateMessage;
use poise::{serenity_prelude as serenity, CreateReply};
use sqlx::query;
// {{{ Score
/// Score management
@ -121,15 +120,13 @@ pub async fn magic(
.with_attachment(file)
.with_fars(maybe_fars)
.with_max_recall(max_recall)
.save(&ctx.data(), &user, &chart)
.await?;
.save(&ctx.data(), &user, &chart)?;
// }}}
// }}}
// {{{ Deliver embed
let (embed, attachment) = timed!("to embed", {
play.to_embed(&ctx.data().db, &user, &song, &chart, i, None)
.await?
play.to_embed(ctx.data(), &user, &song, &chart, i, None)?
});
embeds.push(embed);
@ -183,11 +180,14 @@ pub async fn delete(
let mut count = 0;
for id in ids {
let res = query!("DELETE FROM plays WHERE id=? AND user_id=?", id, user.id)
.execute(&ctx.data().db)
.await?;
let res = ctx
.data()
.db
.get()?
.prepare_cached("DELETE FROM plays WHERE id=? AND user_id=?")?
.execute((id, user.id))?;
if res.rows_affected() == 0 {
if res == 0 {
ctx.reply(format!("No play with id {} found", id)).await?;
} else {
count += 1;
@ -216,36 +216,38 @@ pub async fn show(
let mut embeds = Vec::with_capacity(ids.len());
let mut attachments = Vec::with_capacity(ids.len());
let conn = ctx.data().db.get()?;
for (i, id) in ids.iter().enumerate() {
let res = query!(
"
SELECT
p.id, p.chart_id, p.user_id, p.created_at,
p.max_recall, p.far_notes, s.score,
u.discord_id
FROM plays p
JOIN scores s ON s.play_id = p.id
JOIN users u ON p.user_id = u.id
WHERE s.scoring_system='standard'
AND p.id=?
ORDER BY s.score DESC
LIMIT 1
",
id
)
.fetch_one(&ctx.data().db)
.await
.map_err(|_| format!("Could not find play with id {}", id))?;
let (song, chart, play, discord_id) = conn
.prepare_cached(
"
SELECT
p.id, p.chart_id, p.user_id, p.created_at,
p.max_recall, p.far_notes, s.score,
u.discord_id
FROM plays p
JOIN scores s ON s.play_id = p.id
JOIN users u ON p.user_id = u.id
WHERE s.scoring_system='standard'
AND p.id=?
ORDER BY s.score DESC
LIMIT 1
",
)?
.query_and_then([id], |row| -> Result<_, Error> {
let (song, chart) = ctx.data().song_cache.lookup_chart(row.get("chart_id")?)?;
let play = Play::from_sql(chart, row)?;
let discord_id = row.get::<_, String>("discord_id")?;
Ok((song, chart, play, discord_id))
})?
.next()
.ok_or_else(|| format!("Could not find play with id {}", id))??;
let (song, chart) = ctx.data().song_cache.lookup_chart(res.chart_id as u32)?;
let play = play_from_db_record!(chart, res);
let author = discord_id_to_discord_user(&ctx, &discord_id).await?;
let user = User::by_id(ctx.data(), play.user_id)?;
let author = discord_it_to_discord_user(&ctx, &res.discord_id).await?;
let user = User::by_id(&ctx.data().db, play.user_id).await?;
let (embed, attachment) = play
.to_embed(&ctx.data().db, &user, song, chart, i, Some(&author))
.await?;
let (embed, attachment) =
play.to_embed(ctx.data(), &user, song, chart, i, Some(&author))?;
embeds.push(embed);
attachments.extend(attachment);

View file

@ -5,7 +5,6 @@ use poise::{
serenity_prelude::{CreateAttachment, CreateEmbed},
CreateReply,
};
use sqlx::query;
use crate::{
arcaea::{
@ -26,7 +25,7 @@ use crate::{
context::{Context, Error},
get_user,
logs::debug_image_log,
reply_errors,
reply_errors, timed,
user::User,
};
@ -53,20 +52,20 @@ async fn best_plays(
let user_ctx = ctx.data();
let plays = reply_errors!(
ctx,
get_best_plays(
&user_ctx.db,
&user_ctx.song_cache,
user.id,
scoring_system,
if require_full {
grid_size.0 * grid_size.1
} else {
grid_size.0 * (grid_size.1.max(1) - 1) + 1
} as usize,
(grid_size.0 * grid_size.1) as usize,
None
)
.await?
timed!("get_best_plays", {
get_best_plays(
user_ctx,
user.id,
scoring_system,
if require_full {
grid_size.0 * grid_size.1
} else {
grid_size.0 * (grid_size.1.max(1) - 1) + 1
} as usize,
(grid_size.0 * grid_size.1) as usize,
None,
)?
})
);
// {{{ Layout
@ -463,48 +462,42 @@ pub async fn bany(
#[poise::command(prefix_command, slash_command, user_cooldown = 1)]
async fn meta(ctx: Context<'_>) -> Result<(), Error> {
let user = get_user!(&ctx);
let song_count = query!("SELECT count() as count FROM songs")
.fetch_one(&ctx.data().db)
.await?
.count;
let conn = ctx.data().db.get()?;
let song_count: usize = conn
.prepare_cached("SELECT count() as count FROM songs")?
.query_row((), |row| row.get(0))?;
let chart_count = query!("SELECT count() as count FROM charts")
.fetch_one(&ctx.data().db)
.await?
.count;
let chart_count: usize = conn
.prepare_cached("SELECT count() as count FROM charts")?
.query_row((), |row| row.get(0))?;
let users_count = query!("SELECT count() as count FROM users")
.fetch_one(&ctx.data().db)
.await?
.count;
let users_count: usize = conn
.prepare_cached("SELECT count() as count FROM users")?
.query_row((), |row| row.get(0))?;
let pookie_count = query!(
"
SELECT count() as count
FROM users
WHERE is_pookie=1
"
)
.fetch_one(&ctx.data().db)
.await?
.count;
let pookie_count: usize = conn
.prepare_cached(
"
SELECT count() as count
FROM users
WHERE is_pookie=1
",
)?
.query_row((), |row| row.get(0))?;
let play_count = query!("SELECT count() as count FROM plays")
.fetch_one(&ctx.data().db)
.await?
.count;
let play_count: usize = conn
.prepare_cached("SELECT count() as count FROM plays")?
.query_row((), |row| row.get(0))?;
let your_play_count = query!(
"
let your_play_count: usize = conn
.prepare_cached(
"
SELECT count() as count
FROM plays
WHERE user_id=?
",
user.id
)
.fetch_one(&ctx.data().db)
.await?
.count;
",
)?
.query_row([user.id], |row| row.get(0))?;
let embed = CreateEmbed::default()
.title("Bot statistics")

View file

@ -14,7 +14,7 @@ macro_rules! edit_reply {
#[macro_export]
macro_rules! get_user {
($ctx:expr) => {{
crate::reply_errors!($ctx, crate::user::User::from_context($ctx).await)
crate::reply_errors!($ctx, crate::user::User::from_context($ctx))
}};
}

View file

@ -1,7 +1,6 @@
use r2d2_sqlite::SqliteConnectionManager;
use std::fs;
use sqlx::SqlitePool;
use crate::{
arcaea::{chart::SongCache, jacket::JacketCache},
assets::{get_data_dir, EXO_FONT, GEOSANS_FONT, KAZESAWA_BOLD_FONT, KAZESAWA_FONT},
@ -13,9 +12,11 @@ use crate::{
pub type Error = Box<dyn std::error::Error + Send + Sync>;
pub type Context<'a> = poise::Context<'a, UserContext, Error>;
pub type DbConnection = r2d2::Pool<SqliteConnectionManager>;
// Custom user data passed to all command functions
pub struct UserContext {
pub db: SqlitePool,
pub db: DbConnection,
pub song_cache: SongCache,
pub jacket_cache: JacketCache,
pub ui_measurements: UIMeasurements,
@ -29,11 +30,11 @@ pub struct UserContext {
impl UserContext {
#[inline]
pub async fn new(db: SqlitePool) -> Result<Self, Error> {
pub async fn new(db: DbConnection) -> Result<Self, Error> {
timed!("create_context", {
fs::create_dir_all(get_data_dir())?;
let mut song_cache = timed!("make_song_cache", { SongCache::new(&db).await? });
let mut song_cache = timed!("make_song_cache", { SongCache::new(&db)? });
let jacket_cache = timed!("make_jacket_cache", { JacketCache::new(&mut song_cache)? });
let ui_measurements = timed!("read_ui_measurements", { UIMeasurements::read()? });

View file

@ -22,9 +22,16 @@ mod user;
use arcaea::play::generate_missing_scores;
use assets::get_data_dir;
use context::{Error, UserContext};
use include_dir::{include_dir, Dir};
use poise::serenity_prelude::{self as serenity};
use sqlx::sqlite::SqlitePoolOptions;
use std::{env::var, sync::Arc, time::Duration};
use r2d2::Pool;
use r2d2_sqlite::SqliteConnectionManager;
use rusqlite_migration::Migrations;
use std::{
env::var,
sync::{Arc, LazyLock},
time::Duration,
};
// {{{ Error handler
async fn on_error(error: poise::FrameworkError<'_, UserContext, Error>) {
@ -40,13 +47,27 @@ async fn on_error(error: poise::FrameworkError<'_, UserContext, Error>) {
#[tokio::main]
async fn main() {
let pool = SqlitePoolOptions::new()
.connect(&format!(
"sqlite://{}/db.sqlite",
get_data_dir().to_str().unwrap()
))
.await
.unwrap();
let pool = Pool::new(
SqliteConnectionManager::file(&format!("{}/db.sqlite", get_data_dir().to_str().unwrap()))
.with_init(|conn| {
static MIGRATIONS_DIR: Dir = include_dir!("$CARGO_MANIFEST_DIR/migrations");
static MIGRATIONS: LazyLock<Migrations> = LazyLock::new(|| {
timed!("create_migration_structure", {
Migrations::from_directory(&MIGRATIONS_DIR)
.expect("Could not load migrations")
})
});
timed!("run_migrations", {
MIGRATIONS
.to_latest(conn)
.expect("Could not run migrations");
});
Ok(())
}),
)
.expect("Could not open sqlite database.");
// {{{ Poise options
let options = poise::FrameworkOptions {

View file

@ -1,9 +1,9 @@
use std::str::FromStr;
use poise::serenity_prelude::UserId;
use sqlx::SqlitePool;
use rusqlite::Row;
use crate::context::{Context, Error};
use crate::context::{Context, Error, UserContext};
#[derive(Debug, Clone)]
pub struct User {
@ -13,35 +13,44 @@ pub struct User {
}
impl User {
pub async fn from_context(ctx: &Context<'_>) -> Result<Self, Error> {
let id = ctx.author().id.get().to_string();
let user = sqlx::query!("SELECT * FROM users WHERE discord_id = ?", id)
.fetch_one(&ctx.data().db)
.await
.map_err(|_| "You are not an user in my database, sowwy ^~^")?;
Ok(User {
id: user.id as u32,
discord_id: user.discord_id,
is_pookie: user.is_pookie,
#[inline]
fn from_row<'a, 'b>(row: &'a Row<'b>) -> Result<Self, rusqlite::Error> {
Ok(Self {
id: row.get("id")?,
discord_id: row.get("discord_id")?,
is_pookie: row.get("is_pookie")?,
})
}
pub async fn by_id(db: &SqlitePool, id: u32) -> Result<Self, Error> {
let user = sqlx::query!("SELECT * FROM users WHERE id = ?", id)
.fetch_one(db)
.await?;
pub fn from_context(ctx: &Context<'_>) -> Result<Self, Error> {
let id = ctx.author().id.get().to_string();
let user = ctx
.data()
.db
.get()?
.prepare_cached("SELECT * FROM users WHERE discord_id = ?")?
.query_map([id], Self::from_row)?
.next()
.ok_or_else(|| "You are not an user in my database, sowwy ^~^")??;
Ok(User {
id: user.id as u32,
discord_id: user.discord_id,
is_pookie: user.is_pookie,
})
Ok(user)
}
pub fn by_id(ctx: &UserContext, id: u32) -> Result<Self, Error> {
let user = ctx
.db
.get()?
.prepare_cached("SELECT * FROM users WHERE id = ?")?
.query_map([id], Self::from_row)?
.next()
.ok_or_else(|| "You are not an user in my database, sowwy ^~^")??;
Ok(user)
}
}
#[inline]
pub async fn discord_it_to_discord_user(
pub async fn discord_id_to_discord_user(
&ctx: &Context<'_>,
discord_id: &str,
) -> Result<poise::serenity_prelude::User, Error> {