From eec8d4f964bbc8b0f99b28cea6a3d1bc360c029f Mon Sep 17 00:00:00 2001 From: prescientmoon Date: Thu, 8 Aug 2024 23:26:13 +0200 Subject: [PATCH] Refactor a huge amount of code! Signed-off-by: prescientmoon --- Cargo.lock | 277 +++---- Cargo.toml | 6 +- data/ui.txt | 18 +- flake.nix | 78 +- src/{ => arcaea}/chart.rs | 0 src/{ => arcaea}/jacket.rs | 6 +- src/arcaea/mod.rs | 4 + src/arcaea/play.rs | 371 +++++++++ src/arcaea/score.rs | 348 ++++++++ src/assets.rs | 2 +- src/bitmap.rs | 2 +- src/commands/chart.rs | 4 +- src/commands/mod.rs | 1 + src/commands/score.rs | 357 ++------ src/commands/stats.rs | 38 +- src/commands/utils.rs | 24 + src/context.rs | 5 +- src/main.rs | 9 +- src/ocr/mod.rs | 1 - src/recognition/fuzzy_song_name.rs | 127 +++ src/recognition/mod.rs | 3 + src/recognition/recognize.rs | 495 +++++++++++ src/{ocr => recognition}/ui.rs | 2 - src/score.rs | 1235 ---------------------------- src/{image.rs => transform.rs} | 0 25 files changed, 1627 insertions(+), 1786 deletions(-) rename src/{ => arcaea}/chart.rs (100%) rename src/{ => arcaea}/jacket.rs (98%) create mode 100644 src/arcaea/mod.rs create mode 100644 src/arcaea/play.rs create mode 100644 src/arcaea/score.rs create mode 100644 src/commands/utils.rs delete mode 100644 src/ocr/mod.rs create mode 100644 src/recognition/fuzzy_song_name.rs create mode 100644 src/recognition/mod.rs create mode 100644 src/recognition/recognize.rs rename src/{ocr => recognition}/ui.rs (99%) delete mode 100644 src/score.rs rename src/{image.rs => transform.rs} (100%) diff --git a/Cargo.lock b/Cargo.lock index f17de62..209c544 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -24,7 +24,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e89da841a80418a9b391ebaea17f5c112ffaaa96f621d2c285b5174da76b9011" dependencies = [ "cfg-if", - "getrandom", "once_cell", "version_check", "zerocopy", @@ -180,28 +179,6 @@ version = "1.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b" -[[package]] -name = "bindgen" -version = "0.64.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c4243e6031260db77ede97ad86c27e501d646a27ab57b59a574f725d98ab1fb4" -dependencies = [ - "bitflags 1.3.2", - "cexpr", - "clang-sys", - "lazy_static", - "lazycell", - "log", - "peeking_take_while", - "proc-macro2", - "quote", - "regex", - "rustc-hash", - "shlex", - "syn 1.0.109", - "which", -] - [[package]] name = "bit_field" version = "0.10.2" @@ -322,15 +299,6 @@ dependencies = [ "once_cell", ] -[[package]] -name = "cexpr" -version = "0.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6fac387a98bb7c37292057cffc56d62ecb629900026402633ae9160df93a8766" -dependencies = [ - "nom", -] - [[package]] name = "cfg-expr" version = "0.15.8" @@ -362,23 +330,21 @@ dependencies = [ "windows-targets 0.52.5", ] -[[package]] -name = "clang-sys" -version = "1.8.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b023947811758c97c59bf9d1c188fd619ad4718dcaa767947df1cadb14f39f4" -dependencies = [ - "glob", - "libc", - "libloading", -] - [[package]] name = "color_quant" version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3d7b894f5411737b7867f4827955924d7c254fc9f4d91a6aad6b097804b1018b" +[[package]] +name = "concurrent-queue" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ca0197aee26d1ae37445ee532fefce43251d24cc7c166799f4d46817f1d3973" +dependencies = [ + "crossbeam-utils", +] + [[package]] name = "const-oid" version = "0.9.6" @@ -742,9 +708,14 @@ dependencies = [ [[package]] name = "event-listener" -version = "2.5.3" +version = "5.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0206175f82b8d6bf6652ff7d71a1e27fd2e4efde587fd368662814d6ec1d9ce0" +checksum = "6032be9bd27023a771701cc49f9f053c751055f71efb2e0ae5c15809093675ba" +dependencies = [ + "concurrent-queue", + "parking", + "pin-project-lite", +] [[package]] name = "exr" @@ -1095,22 +1066,13 @@ dependencies = [ [[package]] name = "hashlink" -version = "0.8.4" +version = "0.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e8094feaf31ff591f651a2664fb9cfd92bba7a60ce3197265e9482ebe753c8f7" +checksum = "6ba4ff7128dee98c7dc9794b6a411377e1404dba1c97deb8d1a55297bd25d8af" dependencies = [ "hashbrown", ] -[[package]] -name = "heck" -version = "0.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" -dependencies = [ - "unicode-segmentation", -] - [[package]] name = "heck" version = "0.5.0" @@ -1239,6 +1201,16 @@ dependencies = [ "tokio-rustls 0.24.1", ] +[[package]] +name = "hypertesseract" +version = "0.1.0" +source = "git+https://github.com/BlueGhostGH/hypertesseract.git?rev=78dd8ab#78dd8ab1bbab9d7985959a5a8ac2746bce17ff5c" +dependencies = [ + "image 0.25.2", + "sys", + "thin", +] + [[package]] name = "iana-time-zone" version = "0.1.60" @@ -1294,12 +1266,12 @@ dependencies = [ [[package]] name = "image" -version = "0.25.1" +version = "0.25.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fd54d660e773627692c524beaad361aca785a4f9f5730ce91f42aabe5bce3d11" +checksum = "99314c8a2152b8ddb211f924cdae532d8c5e4c8bb54728e12fff1b0cd5963a10" dependencies = [ "bytemuck", - "byteorder", + "byteorder-lite", "color_quant", "exr", "gif 0.13.1", @@ -1406,40 +1378,12 @@ dependencies = [ "spin 0.5.2", ] -[[package]] -name = "lazycell" -version = "1.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55" - [[package]] name = "lebe" version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "03087c2bad5e1034e8cace5926dec053fb3790248370865f5117a7d0213354c8" -[[package]] -name = "leptonica-plumbing" -version = "1.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cc7a74c43d6f090d39158d233f326f47cd8bba545217595c93662b4e31156f42" -dependencies = [ - "leptonica-sys", - "libc", - "thiserror", -] - -[[package]] -name = "leptonica-sys" -version = "0.4.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1c924779fadc73838b9390ddda5fc1939f844fb43bd44ef6794c32bd6e52238a" -dependencies = [ - "bindgen", - "pkg-config", - "vcpkg", -] - [[package]] name = "libc" version = "0.2.155" @@ -1485,9 +1429,9 @@ dependencies = [ [[package]] name = "libsqlite3-sys" -version = "0.27.0" +version = "0.28.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cf4e226dcd58b4be396f7bd3c20da8fdee2911400705297ba7d2d7cc2c30f716" +checksum = "0c10584274047cb335c23d3e61bcef8e323adae7c5c8c760540f73610177fc3f" dependencies = [ "cc", "pkg-config", @@ -1764,6 +1708,24 @@ version = "1.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" +[[package]] +name = "openssl-sys" +version = "0.9.103" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f9e8deee91df40a943c71b917e5874b951d32a802526c85721ce3b776c929d6" +dependencies = [ + "cc", + "libc", + "pkg-config", + "vcpkg", +] + +[[package]] +name = "parking" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bb813b8af86854136c6922af0598d719255ecb2179515e6e7730d468f05c9cae" + [[package]] name = "parking_lot" version = "0.12.3" @@ -1812,12 +1774,6 @@ dependencies = [ "rustc_version", ] -[[package]] -name = "peeking_take_while" -version = "0.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "19b17cddbe7ec3f8bc800887bab5e717348c95ea2ca0b1bf0837fb964dc67099" - [[package]] name = "pem-rfc7468" version = "0.7.0" @@ -2283,12 +2239,6 @@ version = "0.1.24" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f" -[[package]] -name = "rustc-hash" -version = "1.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" - [[package]] name = "rustc_version" version = "0.4.0" @@ -2545,21 +2495,15 @@ version = "0.1.0" dependencies = [ "chrono", "freetype-rs", - "image 0.25.1", + "hypertesseract", + "image 0.25.2", "num", "plotters", "poise", "sqlx", - "tesseract", "tokio", ] -[[package]] -name = "shlex" -version = "1.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" - [[package]] name = "signature" version = "2.2.0" @@ -2614,6 +2558,9 @@ name = "smallvec" version = "1.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" +dependencies = [ + "serde", +] [[package]] name = "socket2" @@ -2662,9 +2609,9 @@ dependencies = [ [[package]] name = "sqlx" -version = "0.7.4" +version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c9a2ccff1a000a5a59cd33da541d9f2fdcd9e6e8229cc200565942bff36d0aaa" +checksum = "27144619c6e5802f1380337a209d2ac1c431002dd74c6e60aebff3c506dc4f0c" dependencies = [ "sqlx-core", "sqlx-macros", @@ -2675,11 +2622,10 @@ dependencies = [ [[package]] name = "sqlx-core" -version = "0.7.4" +version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "24ba59a9342a3d9bab6c56c118be528b27c9b60e490080e9711a04dccac83ef6" +checksum = "a999083c1af5b5d6c071d34a708a19ba3e02106ad82ef7bbd69f5e48266b613b" dependencies = [ - "ahash", "atoi", "byteorder", "bytes", @@ -2693,6 +2639,7 @@ dependencies = [ "futures-intrusive", "futures-io", "futures-util", + "hashbrown", "hashlink", "hex", "indexmap", @@ -2715,26 +2662,26 @@ dependencies = [ [[package]] name = "sqlx-macros" -version = "0.7.4" +version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ea40e2345eb2faa9e1e5e326db8c34711317d2b5e08d0d5741619048a803127" +checksum = "a23217eb7d86c584b8cbe0337b9eacf12ab76fe7673c513141ec42565698bb88" dependencies = [ "proc-macro2", "quote", "sqlx-core", "sqlx-macros-core", - "syn 1.0.109", + "syn 2.0.66", ] [[package]] name = "sqlx-macros-core" -version = "0.7.4" +version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5833ef53aaa16d860e92123292f1f6a3d53c34ba8b1969f152ef1a7bb803f3c8" +checksum = "1a099220ae541c5db479c6424bdf1b200987934033c2584f79a0e1693601e776" dependencies = [ "dotenvy", "either", - "heck 0.4.1", + "heck", "hex", "once_cell", "proc-macro2", @@ -2746,7 +2693,7 @@ dependencies = [ "sqlx-mysql", "sqlx-postgres", "sqlx-sqlite", - "syn 1.0.109", + "syn 2.0.66", "tempfile", "tokio", "url", @@ -2754,12 +2701,12 @@ dependencies = [ [[package]] name = "sqlx-mysql" -version = "0.7.4" +version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1ed31390216d20e538e447a7a9b959e06ed9fc51c37b514b46eb758016ecd418" +checksum = "5afe4c38a9b417b6a9a5eeffe7235d0a106716495536e7727d1c7f4b1ff3eba6" dependencies = [ "atoi", - "base64 0.21.7", + "base64 0.22.1", "bitflags 2.5.0", "byteorder", "bytes", @@ -2797,12 +2744,12 @@ dependencies = [ [[package]] name = "sqlx-postgres" -version = "0.7.4" +version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7c824eb80b894f926f89a0b9da0c7f435d27cdd35b8c655b114e58223918577e" +checksum = "b1dbb157e65f10dbe01f729339c06d239120221c9ad9fa0ba8408c4cc18ecf21" dependencies = [ "atoi", - "base64 0.21.7", + "base64 0.22.1", "bitflags 2.5.0", "byteorder", "chrono", @@ -2836,9 +2783,9 @@ dependencies = [ [[package]] name = "sqlx-sqlite" -version = "0.7.4" +version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b244ef0a8414da0bed4bb1910426e890b19e5e9bccc27ada6b797d05c55ae0aa" +checksum = "9b2cdd83c008a622d94499c0006d8ee5f821f36c89b7d625c900e5dc30b5c5ee" dependencies = [ "atoi", "chrono", @@ -2852,10 +2799,10 @@ dependencies = [ "log", "percent-encoding", "serde", + "serde_urlencoded", "sqlx-core", "tracing", "url", - "urlencoding", ] [[package]] @@ -2909,6 +2856,16 @@ version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2047c6ded9c721764247e62cd3b03c09ffc529b2ba5b10ec482ae507a4a70160" +[[package]] +name = "sys" +version = "0.1.0" +source = "git+https://github.com/BlueGhostGH/hypertesseract.git?rev=78dd8ab#78dd8ab1bbab9d7985959a5a8ac2746bce17ff5c" +dependencies = [ + "openssl-sys", + "pkg-config", + "vcpkg", +] + [[package]] name = "system-configuration" version = "0.5.1" @@ -2937,7 +2894,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a3e535eb8dded36d55ec13eddacd30dec501792ff23a0b1682c38601b8cf2349" dependencies = [ "cfg-expr", - "heck 0.5.0", + "heck", "pkg-config", "toml", "version-compare", @@ -2968,37 +2925,11 @@ dependencies = [ ] [[package]] -name = "tesseract" -version = "0.15.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "220d5c325aa2fa6656edd8924ad9a91d7ac7b5e998fe0f083a84f7f06ec9fda7" +name = "thin" +version = "0.1.0" +source = "git+https://github.com/BlueGhostGH/hypertesseract.git?rev=78dd8ab#78dd8ab1bbab9d7985959a5a8ac2746bce17ff5c" dependencies = [ - "tesseract-plumbing", - "tesseract-sys", - "thiserror", -] - -[[package]] -name = "tesseract-plumbing" -version = "0.11.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f7fb02c52201d03517af73dd0a146ac62cbd6f0155ad3dc6455d0140d6112191" -dependencies = [ - "leptonica-plumbing", - "tesseract-sys", - "thiserror", -] - -[[package]] -name = "tesseract-sys" -version = "0.5.15" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bd33f6f216124cfaf0fa86c2c0cdf04da39b6257bd78c5e44fa4fa98c3a5857b" -dependencies = [ - "bindgen", - "leptonica-sys", - "pkg-config", - "vcpkg", + "sys", ] [[package]] @@ -3355,12 +3286,6 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e4259d9d4425d9f0661581b804cb85fe66a4c631cadd8f490d1c13a35d5d9291" -[[package]] -name = "unicode-segmentation" -version = "1.11.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d4c87d22b6e3f4a18d4d40ef354e97c90fcb14dd91d7dc0aa9d8a1172ebf7202" - [[package]] name = "unicode_categories" version = "0.1.1" @@ -3385,12 +3310,6 @@ dependencies = [ "serde", ] -[[package]] -name = "urlencoding" -version = "2.1.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "daf8dba3b7eb870caf1ddeed7bc9d2a049f3cfdfae7cb521b087cc33ae4c49da" - [[package]] name = "utf-8" version = "0.7.6" @@ -3567,18 +3486,6 @@ version = "0.1.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "53a85b86a771b1c87058196170769dd264f66c0782acf1ae6cc51bfd64b39082" -[[package]] -name = "which" -version = "4.4.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87ba24419a2078cd2b0f2ede2691b6c66d8e47836da3b6db8265ebad47afbfc7" -dependencies = [ - "either", - "home", - "once_cell", - "rustix", -] - [[package]] name = "whoami" version = "1.5.1" diff --git a/Cargo.toml b/Cargo.toml index b201539..6b65011 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,12 +6,12 @@ edition = "2021" [dependencies] chrono = "0.4.38" freetype-rs = "0.36.0" -image = "0.25.1" +image = "0.25.2" num = "0.4.3" plotters = { git="https://github.com/starlitcanopy/plotters.git", rev="986cd959362a2dbec8d1b25670fd083b904d7b8c", features=["bitmap_backend"] } poise = "0.6.1" -sqlx = { version = "0.7.4", features = ["sqlite", "runtime-tokio", "chrono"] } -tesseract = "0.15.1" +sqlx = { version = "0.8.0", features = ["sqlite", "runtime-tokio", "chrono"] } +hypertesseract = { features=["image"], git="https://github.com/BlueGhostGH/hypertesseract.git", rev="78dd8ab" } tokio = {version="1.38.0", features=["rt-multi-thread"]} [profile.dev.package."*"] diff --git a/data/ui.txt b/data/ui.txt index 10b42b3..3af2c2b 100644 --- a/data/ui.txt +++ b/data/ui.txt @@ -3,12 +3,12 @@ 1037 462 476 91 Score screen — score 274 434 614 611 Score screen — jacket 378 332 161 34 Score screen — difficulty -1288 849 82 39 Score screen — pures -1288 909 82 39 Score screen — fars -1288 969 82 39 Score screen — losts +1288 846 82 45 Score screen — pures +1288 906 82 45 Score screen — fars +1288 966 82 45 Score screen — losts 584 377 74 31 Score screen — max recall 634 116 1252 102 Score screen — title - 95 256 278 49 Song select — score + 95 246 278 69 Song select — score 465 319 730 45 Song select — jacket 89 153 0 0 Song select — PST 269 153 0 0 Song select — PRS @@ -20,14 +20,14 @@ 841 682 500 94 Score screen — score 51 655 633 632 Score screen — jacket 155 546 167 38 Score screen — difficulty -1104 1087 87 34 Score screen — pures -1104 1150 87 34 Score screen — fars -1104 1212 87 34 Score screen — losts +1104 1084 87 40 Score screen — pures +1104 1147 87 40 Score screen — fars +1104 1209 87 40 Score screen — losts 364 593 87 34 Score screen — max recall 438 324 1244 104 Score screen — title - 15 264 291 52 Song select — score + 15 254 291 72 Song select — score 158 411 909 74 Song select — jacket 12 159 0 0 Song select — PST 199 159 0 0 Song select — PRS 389 159 0 0 Song select — FTR - 579 159 0 0 Song select — ETR/BYD + 581 159 0 0 Song select — ETR/BYD diff --git a/flake.nix b/flake.nix index 9dc28cf..9b80ee3 100644 --- a/flake.nix +++ b/flake.nix @@ -6,54 +6,52 @@ fenix.inputs.nixpkgs.follows = "nixpkgs"; }; - outputs = { self, ... }@inputs: - inputs.flake-utils.lib.eachSystem - (with inputs.flake-utils.lib.system; [ x86_64-linux ]) - (system: - let - pkgs = inputs.nixpkgs.legacyPackages.${system}.extend - inputs.fenix.overlays.default; - inherit (pkgs) lib; - in - { - devShell = pkgs.mkShell rec { - packages = with pkgs; [ - (fenix.complete.withComponents [ - "cargo" - "clippy" - "rust-src" - "rustc" - "rustfmt" - ]) - rust-analyzer-nightly - ruff - imagemagick - fontconfig - freetype + outputs = + { ... }@inputs: + inputs.flake-utils.lib.eachSystem (with inputs.flake-utils.lib.system; [ x86_64-linux ]) ( + system: + let + pkgs = inputs.nixpkgs.legacyPackages.${system}.extend inputs.fenix.overlays.default; + inherit (pkgs) lib; + in + { + devShell = pkgs.mkShell rec { + packages = with pkgs; [ + (fenix.complete.withComponents [ + "cargo" + "clippy" + "rust-src" + "rustc" + "rustfmt" + ]) + rust-analyzer-nightly + ruff + imagemagick + fontconfig + freetype - clang - llvmPackages.clang - pkg-config + clang + llvmPackages.clang + pkg-config - leptonica - tesseract - openssl - sqlite - ]; + leptonica + tesseract + openssl + sqlite + ]; - LD_LIBRARY_PATH = lib.makeLibraryPath packages; + LD_LIBRARY_PATH = lib.makeLibraryPath packages; - # compilation of -sys packages requires manually setting LIBCLANG_PATH - LIBCLANG_PATH = "${pkgs.llvmPackages.libclang.lib}/lib"; - }; - }); + # compilation of -sys packages requires manually setting LIBCLANG_PATH + LIBCLANG_PATH = "${pkgs.llvmPackages.libclang.lib}/lib"; + }; + } + ); # {{{ Caching and whatnot # TODO: persist trusted substituters file nixConfig = { - extra-substituters = [ - "https://nix-community.cachix.org" - ]; + extra-substituters = [ "https://nix-community.cachix.org" ]; extra-trusted-public-keys = [ "nix-community.cachix.org-1:mB9FSh9qf2dCimDSUo8Zy7bkq5CX+/rkCWyvRCYg3Fs=" diff --git a/src/chart.rs b/src/arcaea/chart.rs similarity index 100% rename from src/chart.rs rename to src/arcaea/chart.rs diff --git a/src/jacket.rs b/src/arcaea/jacket.rs similarity index 98% rename from src/jacket.rs rename to src/arcaea/jacket.rs index 747739a..b7abec8 100644 --- a/src/jacket.rs +++ b/src/arcaea/jacket.rs @@ -4,10 +4,10 @@ use image::{imageops::FilterType, GenericImageView, Rgba}; use num::Integer; use crate::{ + arcaea::chart::{Difficulty, Jacket, SongCache}, assets::{get_assets_dir, should_skip_jacket_art}, - chart::{Difficulty, Jacket, SongCache}, context::Error, - score::guess_chart_name, + recognition::fuzzy_song_name::guess_chart_name, }; /// How many sub-segments to split each side into @@ -78,7 +78,7 @@ pub struct JacketCache { } impl JacketCache { - // {{{ Generate tree + // {{{ Generate // This is a bit inefficient (using a hash set), but only runs once pub fn new(data_dir: &PathBuf, song_cache: &mut SongCache) -> Result { let jacket_dir = data_dir.join("jackets"); diff --git a/src/arcaea/mod.rs b/src/arcaea/mod.rs new file mode 100644 index 0000000..a86bb49 --- /dev/null +++ b/src/arcaea/mod.rs @@ -0,0 +1,4 @@ +pub mod chart; +pub mod jacket; +pub mod play; +pub mod score; diff --git a/src/arcaea/play.rs b/src/arcaea/play.rs new file mode 100644 index 0000000..993f523 --- /dev/null +++ b/src/arcaea/play.rs @@ -0,0 +1,371 @@ +use std::str::FromStr; + +use num::traits::Euclid; +use poise::serenity_prelude::{ + Attachment, AttachmentId, CreateAttachment, CreateEmbed, CreateEmbedAuthor, Timestamp, +}; +use sqlx::{query_as, SqlitePool}; + +use crate::arcaea::chart::{Chart, Song}; +use crate::context::{Error, UserContext}; +use crate::user::User; + +use super::score::Score; + +// {{{ Create play +#[derive(Debug, Clone)] +pub struct CreatePlay { + chart_id: u32, + user_id: u32, + discord_attachment_id: Option, + + // Actual score data + score: Score, + zeta_score: Score, + + // Optional score details + max_recall: Option, + far_notes: Option, + + // Creation data + creation_ptt: Option, + creation_zeta_ptt: Option, +} + +impl CreatePlay { + #[inline] + pub fn new(score: Score, chart: &Chart, user: &User) -> Self { + Self { + chart_id: chart.id, + user_id: user.id, + discord_attachment_id: None, + score, + zeta_score: score.to_zeta(chart.note_count as u32), + max_recall: None, + far_notes: None, + // TODO: populate these + creation_ptt: None, + creation_zeta_ptt: None, + } + } + + #[inline] + pub fn with_attachment(mut self, attachment: &Attachment) -> Self { + self.discord_attachment_id = Some(attachment.id); + self + } + + #[inline] + pub fn with_fars(mut self, far_count: Option) -> Self { + self.far_notes = far_count; + self + } + + #[inline] + pub fn with_max_recall(mut self, max_recall: Option) -> Self { + self.max_recall = max_recall; + self + } + + // {{{ Save + pub async fn save(self, ctx: &UserContext) -> Result { + let attachment_id = self.discord_attachment_id.map(|i| i.get() as i64); + let play = sqlx::query!( + " + INSERT INTO plays( + user_id,chart_id,discord_attachment_id, + score,zeta_score,max_recall,far_notes + ) + VALUES(?,?,?,?,?,?,?) + RETURNING id, created_at + ", + self.user_id, + self.chart_id, + attachment_id, + self.score.0, + self.zeta_score.0, + self.max_recall, + self.far_notes + ) + .fetch_one(&ctx.db) + .await?; + + Ok(Play { + id: play.id as u32, + created_at: play.created_at, + chart_id: self.chart_id, + user_id: self.user_id, + discord_attachment_id: self.discord_attachment_id, + score: self.score, + zeta_score: self.zeta_score, + max_recall: self.max_recall, + far_notes: self.far_notes, + creation_ptt: self.creation_ptt, + creation_zeta_ptt: self.creation_zeta_ptt, + }) + } + // }}} +} +// }}} +// {{{ DbPlay +/// Version of `Play` matching the format sqlx expects +#[derive(Debug, Clone, sqlx::FromRow)] +pub struct DbPlay { + pub id: i64, + pub chart_id: i64, + pub user_id: i64, + pub discord_attachment_id: Option, + pub score: i64, + pub zeta_score: i64, + pub max_recall: Option, + pub far_notes: Option, + pub created_at: chrono::NaiveDateTime, + pub creation_ptt: Option, + pub creation_zeta_ptt: Option, +} + +impl DbPlay { + #[inline] + pub fn to_play(self) -> Play { + Play { + id: self.id as u32, + chart_id: self.chart_id as u32, + user_id: self.user_id as u32, + score: Score(self.score as u32), + zeta_score: Score(self.zeta_score as u32), + max_recall: self.max_recall.map(|r| r as u32), + far_notes: self.far_notes.map(|r| r as u32), + created_at: self.created_at, + discord_attachment_id: self + .discord_attachment_id + .and_then(|s| AttachmentId::from_str(&s).ok()), + creation_ptt: self.creation_ptt.map(|r| r as u32), + creation_zeta_ptt: self.creation_zeta_ptt.map(|r| r as u32), + } + } +} +// }}} +// {{{ Play +#[derive(Debug, Clone)] +pub struct Play { + pub id: u32, + pub chart_id: u32, + pub user_id: u32, + + #[allow(unused)] + pub discord_attachment_id: Option, + + // Actual score data + pub score: Score, + pub zeta_score: Score, + + // Optional score details + pub max_recall: Option, + pub far_notes: Option, + + // Creation data + pub created_at: chrono::NaiveDateTime, + + #[allow(unused)] + pub creation_ptt: Option, + + #[allow(unused)] + pub creation_zeta_ptt: Option, +} + +impl Play { + // {{{ Play => distribution + pub fn distribution(&self, note_count: u32) -> Option<(u32, u32, u32, u32)> { + if let Some(fars) = self.far_notes { + let (_, shinies, units) = self.score.analyse(note_count); + let (pures, rem) = units.checked_sub(fars)?.div_rem_euclid(&2); + if rem == 1 { + println!("The impossible happened: got an invalid amount of far notes!"); + return None; + } + + let lost = note_count.checked_sub(fars + pures)?; + let non_max_pures = pures.checked_sub(shinies)?; + Some((shinies, non_max_pures, fars, lost)) + } else { + None + } + } + // }}} + // {{{ Play => status + #[inline] + pub fn status(&self, chart: &Chart) -> Option { + let score = self.score.0; + if score >= 10_000_000 { + if score > chart.note_count + 10_000_000 { + return None; + } + + let non_max_pures = (chart.note_count + 10_000_000).checked_sub(score)?; + if non_max_pures == 0 { + Some("MPM".to_string()) + } else { + Some(format!("PM (-{})", non_max_pures)) + } + } else if let Some(distribution) = self.distribution(chart.note_count) { + // if no lost notes... + if distribution.3 == 0 { + Some(format!("FR (-{}/-{})", distribution.1, distribution.2)) + } else { + Some(format!( + "C (-{}/-{}/-{})", + distribution.1, distribution.2, distribution.3 + )) + } + } else { + None + } + } + + #[inline] + pub fn short_status(&self, chart: &Chart) -> Option { + let score = self.score.0; + if score >= 10_000_000 { + let non_max_pures = (chart.note_count + 10_000_000).checked_sub(score)?; + if non_max_pures == 0 { + Some('M') + } else { + Some('P') + } + } else if let Some(distribution) = self.distribution(chart.note_count) + && distribution.3 == 0 + { + Some('F') + } else { + Some('C') + } + } + // }}} + // {{{ Play to embed + /// Creates a discord embed for this play. + /// + /// The `index` variable is only used to create distinct filenames. + pub async fn to_embed( + &self, + db: &SqlitePool, + user: &User, + song: &Song, + chart: &Chart, + index: usize, + author: Option<&poise::serenity_prelude::User>, + ) -> Result<(CreateEmbed, Option), Error> { + // {{{ Get previously best score + let previously_best = query_as!( + DbPlay, + " + SELECT * FROM plays + WHERE user_id=? + AND chart_id=? + AND created_at Some(CreateAttachment::bytes(jacket.raw, &attachement_name)), + None => None, + }; + + let mut embed = CreateEmbed::default() + .title(format!( + "{} [{:?} {}]", + &song.title, chart.difficulty, chart.level + )) + .field("Score", format!("{} (+?)", self.score), true) + .field( + "Rating", + format!( + "{:.2} (+?)", + self.score.play_rating_f32(chart.chart_constant) + ), + true, + ) + .field("Grade", format!("{}", self.score.grade()), true) + .field("ξ-Score", format!("{} (+?)", self.zeta_score), true) + // {{{ ξ-Rating + .field( + "ξ-Rating", + { + let play_rating = self.zeta_score.play_rating_f32(chart.chart_constant); + if let Some(previous) = previously_best { + let previous_play_rating = + previous.zeta_score.play_rating_f32(chart.chart_constant); + + if play_rating >= previous_play_rating { + format!( + "{:.2} (+{})", + play_rating, + play_rating - previous_play_rating + ) + } else { + format!( + "{:.2} (-{})", + play_rating, + play_rating - previous_play_rating + ) + } + } else { + format!("{:.2}", play_rating) + } + }, + true, + ) + // }}} + .field("ξ-Grade", format!("{}", self.zeta_score.grade()), true) + .field( + "Status", + self.status(chart).unwrap_or("-".to_string()), + true, + ) + .field( + "Max recall", + if let Some(max_recall) = self.max_recall { + format!("{}", max_recall) + } else { + format!("-") + }, + true, + ) + .field("ID", format!("{}", self.id), true); + + if icon_attachement.is_some() { + embed = embed.thumbnail(format!("attachment://{}", &attachement_name)); + } + + if let Some(user) = author { + let mut embed_author = CreateEmbedAuthor::new(&user.name); + if let Some(url) = user.avatar_url() { + embed_author = embed_author.icon_url(url); + } + + embed = embed + .timestamp(Timestamp::from_millis( + self.created_at.and_utc().timestamp_millis(), + )?) + .author(embed_author); + } + + Ok((embed, icon_attachement)) + } + // }}} +} +// }}} diff --git a/src/arcaea/score.rs b/src/arcaea/score.rs new file mode 100644 index 0000000..d005c3d --- /dev/null +++ b/src/arcaea/score.rs @@ -0,0 +1,348 @@ +use std::fmt::Display; + +use num::Rational64; + +use crate::context::Error; + +// {{{ Grade +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +pub enum Grade { + EXP, + EX, + AA, + A, + B, + C, + D, +} + +impl Grade { + pub const GRADE_STRINGS: [&'static str; 7] = ["EX+", "EX", "AA", "A", "B", "C", "D"]; + pub const GRADE_SHORTHANDS: [&'static str; 7] = ["exp", "ex", "aa", "a", "b", "c", "d"]; + + #[inline] + pub fn to_index(self) -> usize { + self as usize + } +} + +impl Display for Grade { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", Self::GRADE_STRINGS[self.to_index()]) + } +} +// }}} +// {{{ Score +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +pub struct Score(pub u32); + +impl Score { + // {{{ Score analysis + // {{{ Mini getters + #[inline] + pub fn to_zeta(self, note_count: u32) -> Score { + self.analyse(note_count).0 + } + + #[inline] + pub fn shinies(self, note_count: u32) -> u32 { + self.analyse(note_count).1 + } + + #[inline] + pub fn units(self, note_count: u32) -> u32 { + self.analyse(note_count).2 + } + // }}} + + #[inline] + pub fn increment(note_count: u32) -> Rational64 { + Rational64::new_raw(5_000_000, note_count as i64).reduced() + } + + /// Remove the contribution made by shinies to a score. + #[inline] + pub fn forget_shinies(self, note_count: u32) -> Self { + Self( + (Self::increment(note_count) * Rational64::from_integer(self.units(note_count) as i64)) + .floor() + .to_integer() as u32, + ) + } + + /// Compute a score without making a distinction between shinies and pures. That is, the given + /// value for `pures` must refer to the sum of `pure` and `shiny` notes. + /// + /// This is the simplest way to compute a score, and is useful for error analysis. + #[inline] + pub fn compute_naive(note_count: u32, pures: u32, fars: u32) -> Self { + Self( + (Self::increment(note_count) * Rational64::from_integer((2 * pures + fars) as i64)) + .floor() + .to_integer() as u32, + ) + } + + /// Returns the zeta score, the number of shinies, and the number of score units. + /// + /// Pure (and higher) notes reward two score units, far notes reward one, and lost notes reward + /// none. + pub fn analyse(self, note_count: u32) -> (Score, u32, u32) { + // Smallest possible difference between (zeta-)scores + let increment = Self::increment(note_count); + let zeta_increment = Rational64::new_raw(2_000_000, note_count as i64).reduced(); + + let score = Rational64::from_integer(self.0 as i64); + let score_units = (score / increment).floor(); + + let non_shiny_score = (score_units * increment).floor(); + let shinies = score - non_shiny_score; + + let zeta_score_units = Rational64::from_integer(2) * score_units + shinies; + let zeta_score = Score((zeta_increment * zeta_score_units).floor().to_integer() as u32); + + ( + zeta_score, + shinies.to_integer() as u32, + score_units.to_integer() as u32, + ) + } + // }}} + // {{{ Score => Play rating + #[inline] + pub fn play_rating(self, chart_constant: u32) -> i32 { + chart_constant as i32 + + if self.0 >= 10_000_000 { + 200 + } else if self.0 >= 9_800_000 { + 100 + (self.0 as i32 - 9_800_000) / 2_000 + } else { + (self.0 as i32 - 9_500_000) / 3_000 + } + } + + #[inline] + pub fn play_rating_f32(self, chart_constant: u32) -> f32 { + (self.play_rating(chart_constant)) as f32 / 100.0 + } + // }}} + // {{{ Score => grade + #[inline] + // TODO: Perhaps make an enum for this + pub fn grade(self) -> Grade { + let score = self.0; + if score > 9900000 { + Grade::EXP + } else if score > 9800000 { + Grade::EX + } else if score > 9500000 { + Grade::AA + } else if score > 9200000 { + Grade::A + } else if score > 8900000 { + Grade::B + } else if score > 8600000 { + Grade::C + } else { + Grade::D + } + } + // }}} + // {{{ Scores & Distribution => score + pub fn resolve_ambiguities( + scores: Vec, + read_distribution: Option<(u32, u32, u32)>, + note_count: u32, + ) -> Result<(Score, Option, Option<&'static str>), Error> { + if scores.len() == 0 { + return Err("No scores in list to disambiguate from.")?; + } + + let mut no_shiny_scores: Vec<_> = scores + .iter() + .map(|score| score.forget_shinies(note_count)) + .collect(); + no_shiny_scores.sort(); + no_shiny_scores.dedup(); + + if let Some(read_distribution) = read_distribution { + let pures = read_distribution.0; + let fars = read_distribution.1; + let losts = read_distribution.2; + + // Compute score from note breakdown subpairs + let pf_score = Score::compute_naive(note_count, pures, fars); + let fl_score = Score::compute_naive( + note_count, + note_count.checked_sub(losts + fars).unwrap_or(0), + fars, + ); + let lp_score = Score::compute_naive( + note_count, + pures, + note_count.checked_sub(losts + pures).unwrap_or(0), + ); + + if no_shiny_scores.len() == 1 { + // {{{ Score is fixed, gotta figure out the exact distribution + let score = *scores.iter().max().unwrap(); + + // {{{ Look for consensus among recomputed scores + // Lemma: if two computed scores agree, then so will the third + let consensus_fars = if pf_score == fl_score { + Some(fars) + } else { + // Due to the above lemma, we know all three scores must be distinct by + // this point. + // + // Our strategy is to check which of the three scores agrees with the real + // score, and to then trust the `far` value that contributed to that pair. + let no_shiny_score = score.forget_shinies(note_count); + let pf_appears = no_shiny_score == pf_score; + let fl_appears = no_shiny_score == fl_score; + let lp_appears = no_shiny_score == lp_score; + + match (pf_appears, fl_appears, lp_appears) { + (true, false, false) => Some(fars), + (false, true, false) => Some(fars), + (false, false, true) => Some(note_count - pures - losts), + _ => None, + } + }; + // }}} + + if scores.len() == 1 { + Ok((score, consensus_fars, None)) + } else { + Ok((score, consensus_fars, Some("Due to a reading error, I could not make sure the shiny-amount I calculated is accurate!"))) + } + + // }}} + } else { + // {{{ Score is not fixed, gotta figure out everything at once + // Some of the values in the note distribution are likely wrong (due to reading + // errors). To get around this, we take each pair from the triplet, compute the score + // it induces, and figure out if there's any consensus as to which value in the + // provided score list is the real one. + // + // Note that sometimes the note distribution cannot resolve any of the issues. This is + // usually the case when the disagreement comes from the number of shinies. + + // {{{ Look for consensus among recomputed scores + // Lemma: if two computed scores agree, then so will the third + let (trusted_pure_count, consensus_computed_score, consensus_fars) = if pf_score + == fl_score + { + (true, pf_score, fars) + } else { + // Due to the above lemma, we know all three scores must be distinct by + // this point. + // + // Our strategy is to check which of the three scores appear in the + // provided score list. + let pf_appears = no_shiny_scores.contains(&pf_score); + let fl_appears = no_shiny_scores.contains(&fl_score); + let lp_appears = no_shiny_scores.contains(&lp_score); + + match (pf_appears, fl_appears, lp_appears) { + (true, false, false) => (true, pf_score, fars), + (false, true, false) => (false, fl_score, fars), + (false, false, true) => (true, lp_score, note_count - pures - losts), + _ => Err(format!("Cannot disambiguate scores {:?}. Multiple disjoint note breakdown subpair scores appear on the possibility list", scores))? + } + }; + // }}} + // {{{ Collect all scores that agree with the consensus score. + let agreement: Vec<_> = scores + .iter() + .filter(|score| score.forget_shinies(note_count) == consensus_computed_score) + .filter(|score| { + let shinies = score.shinies(note_count); + shinies <= note_count && (!trusted_pure_count || shinies <= pures) + }) + .map(|v| *v) + .collect(); + // }}} + // {{{ Case 1: Disagreement in the amount of shinies! + if agreement.len() > 1 { + let agreement_shiny_amounts: Vec<_> = + agreement.iter().map(|v| v.shinies(note_count)).collect(); + + println!( + "Shiny count disagreement. Possible scores: {:?}. Possible shiny amounts: {:?}, Read distribution: {:?}", + scores, agreement_shiny_amounts, read_distribution + ); + + let msg = Some( + "Due to a reading error, I could not make sure the shiny-amount I calculated is accurate!" + ); + + Ok(( + agreement.into_iter().max().unwrap(), + Some(consensus_fars), + msg, + )) + // }}} + // {{{ Case 2: Total agreement! + } else if agreement.len() == 1 { + Ok((agreement[0], Some(consensus_fars), None)) + // }}} + // {{{ Case 3: No agreement! + } else { + Err(format!("Could not disambiguate between possible scores {:?}. Note distribution does not agree with any possibility, leading to a score of {:?}.", scores, consensus_computed_score))? + } + // }}} + // }}} + } + } else { + if no_shiny_scores.len() == 1 { + if scores.len() == 1 { + Ok((scores[0], None, None)) + } else { + Ok((scores.into_iter().max().unwrap(), None, Some("Due to a reading error, I could not make sure the shiny-amount I calculated is accurate!"))) + } + } else { + Err("Cannot disambiguate between more than one score without a note distribution.")? + } + } + } + // }}} +} + +impl Display for Score { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let score = self.0; + write!( + f, + "{}'{:0>3}'{:0>3}", + score / 1000000, + (score / 1000) % 1000, + score % 1000 + ) + } +} +// }}} +// {{{ Tests +#[cfg(test)] +mod score_tests { + use super::*; + + #[test] + fn zeta_score_consistent_with_pms() { + // note counts + for note_count in 200..=2000 { + for shiny_count in 0..=note_count { + let score = Score(10000000 + shiny_count); + let zeta_score_units = 4 * (note_count - shiny_count) + 5 * shiny_count; + let (zeta_score, computed_shiny_count, units) = score.analyse(note_count); + let expected_zeta_score = Rational64::from_integer(zeta_score_units as i64) + * Rational64::new_raw(2000000, note_count as i64).reduced(); + + assert_eq!(zeta_score, Score(expected_zeta_score.to_integer() as u32)); + assert_eq!(computed_shiny_count, shiny_count); + assert_eq!(units, 2 * note_count); + } + } + } +} +// }}} diff --git a/src/assets.rs b/src/assets.rs index aba0dd8..621243b 100644 --- a/src/assets.rs +++ b/src/assets.rs @@ -4,7 +4,7 @@ use std::{cell::RefCell, env::var, path::PathBuf, str::FromStr, sync::OnceLock}; use freetype::{Face, Library}; use image::{imageops::FilterType, ImageBuffer, Rgb, Rgba}; -use crate::chart::Difficulty; +use crate::arcaea::chart::Difficulty; #[inline] pub fn get_data_dir() -> PathBuf { diff --git a/src/bitmap.rs b/src/bitmap.rs index 7e8158b..daf3d75 100644 --- a/src/bitmap.rs +++ b/src/bitmap.rs @@ -34,7 +34,7 @@ impl Color { #[inline] pub const fn from_bytes(bytes: [u8; 4]) -> Self { - Self(bytes[0], bytes[1], bytes[1], bytes[3]) + Self(bytes[0], bytes[1], bytes[2], bytes[3]) } #[inline] diff --git a/src/commands/chart.rs b/src/commands/chart.rs index d4ca796..4675726 100644 --- a/src/commands/chart.rs +++ b/src/commands/chart.rs @@ -2,9 +2,9 @@ use poise::serenity_prelude::{CreateAttachment, CreateEmbed, CreateMessage}; use sqlx::query; use crate::{ - chart::Side, + arcaea::chart::Side, context::{Context, Error}, - score::guess_song_and_chart, + recognition::fuzzy_song_name::guess_song_and_chart, }; // {{{ Chart diff --git a/src/commands/mod.rs b/src/commands/mod.rs index 3e75f27..8b5968b 100644 --- a/src/commands/mod.rs +++ b/src/commands/mod.rs @@ -3,6 +3,7 @@ use crate::context::{Context, Error}; pub mod chart; pub mod score; pub mod stats; +mod utils; // {{{ Help /// Show this help menu diff --git a/src/commands/score.rs b/src/commands/score.rs index bc048b3..7590b44 100644 --- a/src/commands/score.rs +++ b/src/commands/score.rs @@ -1,10 +1,10 @@ -use std::fmt::Display; - +use crate::arcaea::play::{CreatePlay, Play}; +use crate::arcaea::score::Score; use crate::context::{Context, Error}; -use crate::score::{CreatePlay, ImageCropper, Play, Score, ScoreKind}; +use crate::recognition::recognize::{ImageAnalyzer, ScoreKind}; use crate::user::{discord_it_to_discord_user, User}; -use image::imageops::FilterType; -use poise::serenity_prelude::{CreateAttachment, CreateEmbed, CreateMessage}; +use crate::{edit_reply, get_user}; +use poise::serenity_prelude::CreateMessage; use poise::{serenity_prelude as serenity, CreateReply}; use sqlx::query; @@ -21,46 +21,13 @@ pub async fn score(_ctx: Context<'_>) -> Result<(), Error> { } // }}} // {{{ Score magic -// {{{ Send error embed with image -async fn error_with_image( - ctx: Context<'_>, - bytes: &[u8], - filename: &str, - message: &str, - err: impl Display, -) -> Result<(), Error> { - let error_attachement = CreateAttachment::bytes(bytes, filename); - let msg = CreateMessage::default().embed( - CreateEmbed::default() - .title(message) - .attachment(filename) - .description(format!("{}", err)), - ); - - ctx.channel_id() - .send_files(ctx.http(), [error_attachement], msg) - .await?; - - Ok(()) -} -// }}} - /// Identify scores from attached images. #[poise::command(prefix_command, slash_command)] pub async fn magic( ctx: Context<'_>, #[description = "Images containing scores"] files: Vec, ) -> Result<(), Error> { - let user = match User::from_context(&ctx).await { - Ok(user) => user, - Err(_) => { - ctx.say("You are not an user in my database, sorry!") - .await?; - return Ok(()); - } - }; - - println!("Handling command from user {:?}", user.discord_id); + let user = get_user!(&ctx); if files.len() == 0 { ctx.reply("No images found attached to message").await?; @@ -71,246 +38,103 @@ pub async fn magic( .reply(format!("Processed 0/{} scores", files.len())) .await?; + let mut analyzer = ImageAnalyzer::default(); + for (i, file) in files.iter().enumerate() { if let Some(_) = file.dimensions() { - // {{{ Image pre-processing let bytes = file.download().await?; let mut image = image::load_from_memory(&bytes)?; // image = image.resize(1024, 1024, FilterType::Nearest); - // }}} - // {{{ Detection - // Create cropper and run OCR - let mut cropper = ImageCropper::default(); - let edited = CreateReply::default() - .reply(true) - .content(format!("Image {}: reading jacket", i + 1)); - handle.edit(ctx, edited).await?; + let result: Result<(), Error> = try { + // {{{ Detection + // This makes OCR more likely to work + let mut ocr_image = image.grayscale().blur(1.); - // This makes OCR more likely to work - let mut ocr_image = image.grayscale().blur(1.); + edit_reply!(ctx, handle, "Image {}: reading kind", i + 1).await?; + let kind = analyzer.read_score_kind(ctx.data(), &ocr_image)?; - // {{{ Kind - let edited = CreateReply::default() - .reply(true) - .content(format!("Image {}: reading kind", i + 1)); - handle.edit(ctx, edited).await?; + edit_reply!(ctx, handle, "Image {}: reading difficulty", i + 1).await?; + // Do not use `ocr_image` because this reads the colors + let difficulty = analyzer.read_difficulty(ctx.data(), &image, kind)?; - let kind = match cropper.read_score_kind(ctx.data(), &ocr_image) { - // {{{ OCR error handling - Err(err) => { - error_with_image( - ctx, - &cropper.bytes, - &file.filename, - "Could not read kind from picture", - &err, - ) + edit_reply!(ctx, handle, "Image {}: reading jacket", i + 1).await?; + let (song, chart) = analyzer + .read_jacket(ctx.data(), &mut image, kind, difficulty) .await?; - continue; - } - // }}} - Ok(k) => k, - }; - // }}} - // {{{ Difficulty - let edited = CreateReply::default() - .reply(true) - .content(format!("Image {}: reading difficulty", i + 1)); - handle.edit(ctx, edited).await?; + ocr_image.invert(); - // Do not use `ocr_image` because this reads the colors - let difficulty = match cropper.read_difficulty(ctx.data(), &image, kind) { - // {{{ OCR error handling - Err(err) => { - error_with_image( - ctx, - &cropper.bytes, - &file.filename, - "Could not read difficulty from picture", - &err, - ) - .await?; + let (note_distribution, max_recall) = match kind { + ScoreKind::ScoreScreen => { + edit_reply!(ctx, handle, "Image {}: reading distribution", i + 1) + .await?; + let note_distribution = + Some(analyzer.read_distribution(ctx.data(), &image)?); - continue; - } - // }}} - Ok(d) => d, - }; + edit_reply!(ctx, handle, "Image {}: reading max recall", i + 1).await?; + let max_recall = Some(analyzer.read_max_recall(ctx.data(), &image)?); - println!("{difficulty:?}"); - // }}} - // {{{ Jacket & distribution - let mut jacket_rect = None; - let song_by_jacket = cropper - .read_jacket(ctx.data(), &mut image, kind, difficulty, &mut jacket_rect) - .await; - // image.invert(); - ocr_image.invert(); - let note_distribution = match kind { - ScoreKind::ScoreScreen => Some(cropper.read_distribution(ctx.data(), &image)?), - ScoreKind::SongSelect => None, - }; - // }}} - // {{{ Title - let edited = CreateReply::default() - .reply(true) - .content(format!("Image {}: reading title", i + 1)); - handle.edit(ctx, edited).await?; - - let song_by_name = match kind { - ScoreKind::SongSelect => None, - ScoreKind::ScoreScreen => { - Some(cropper.read_song(ctx.data(), &ocr_image, difficulty)) - } - }; - - let (song, chart) = match (song_by_jacket, song_by_name) { - // {{{ Only name succeeded - (Err(err_jacket), Some(Ok(by_name))) => { - println!("Could not recognise jacket with error: {}", err_jacket); - by_name - } - // }}} - // {{{ Both succeeded - (Ok(by_jacket), Some(Ok(by_name))) => { - if by_name.0.id != by_jacket.0.id { - println!( - "Got diverging choices between '{}' and '{}'", - by_jacket.0.title, by_name.0.title - ); - }; - - by_jacket - } // }}} - // {{{ Only jacket succeeded - (Ok(by_jacket), err_name) => { - if let Some(err) = err_name { - println!("Could not read name with error: {:?}", err.unwrap_err()); + (note_distribution, max_recall) } + ScoreKind::SongSelect => (None, None), + }; - by_jacket - } - // }}} - // {{{ Both errors - (Err(err_jacket), err_name) => { - if let Some(rect) = jacket_rect { - cropper.crop_image_to_bytes(&image, rect)?; - error_with_image( - ctx, - &cropper.bytes, - &file.filename, - "Hey! I could not read the score in the provided picture.", - &format!( - "This can mean one of three things: -1. The image you provided is *not that of an Arcaea score -2. The image you provided contains a newly added chart that is not in my database yet -3. The image you provided contains character art that covers the chart name. When this happens, I try to make use of the jacket art in order to determine the chart. Contact `@prescientmoon` on discord to try and resolve the issue! + edit_reply!(ctx, handle, "Image {}: reading score", i + 1).await?; + let score_possibilities = analyzer.read_score( + ctx.data(), + Some(chart.note_count), + &ocr_image, + kind, + )?; -Nerdy info: -``` -Jacket error: {} -Title error: {:?} -```" , - err_jacket, err_name - ), - ) - .await?; - } else { - ctx.reply(format!( - "This is a weird error that should never happen... -Nerdy info: -``` -Jacket error: {} -Title error: {:?} -```", - err_jacket, err_name - )) - .await?; - } - continue; - } // }}} - }; - - println!("{}", song.title); - // }}} - // {{{ Score - let edited = CreateReply::default() - .reply(true) - .content(format!("Image {}: reading score", i + 1)); - handle.edit(ctx, edited).await?; - - let score_possibilities = match cropper.read_score( - ctx.data(), - Some(chart.note_count), - &ocr_image, - kind, - ) { - // {{{ OCR error handling - Err(err) => { - error_with_image( - ctx, - &cropper.bytes, - &file.filename, - "Could not read score from picture", - &err, - ) - .await?; - - continue; - } - // }}} - Ok(scores) => scores, - }; - // }}} - // {{{ Build play - let (score, maybe_fars, score_warning) = Score::resolve_ambiguities( - score_possibilities, - note_distribution, - chart.note_count, - ) - .map_err(|err| { - format!( - "Error occurred when disambiguating scores for '{}' [{:?}] by {}: {}", - song.title, difficulty, song.artist, err + // {{{ Build play + let (score, maybe_fars, score_warning) = Score::resolve_ambiguities( + score_possibilities, + note_distribution, + chart.note_count, ) - })?; - println!( - "Maybe fars {:?}, distribution {:?}", - maybe_fars, note_distribution - ); - let play = CreatePlay::new(score, &chart, &user) - .with_attachment(file) - .with_fars(maybe_fars) - .save(&ctx.data()) - .await?; - // }}} - // }}} - // {{{ Deliver embed - let (mut embed, attachment) = play - .to_embed(&ctx.data().db, &user, &song, &chart, i, None) - .await?; - if let Some(warning) = score_warning { - embed = embed.description(warning); - } + .map_err(|err| { + format!( + "Error occurred when disambiguating scores for '{}' [{:?}] by {}: {}", + song.title, difficulty, song.artist, err + ) + })?; - embeds.push(embed); - attachments.extend(attachment); - // }}} + let play = CreatePlay::new(score, &chart, &user) + .with_attachment(file) + .with_fars(maybe_fars) + .with_max_recall(max_recall) + .save(&ctx.data()) + .await?; + // }}} + // }}} + // {{{ Deliver embed + let (mut embed, attachment) = play + .to_embed(&ctx.data().db, &user, &song, &chart, i, None) + .await?; + + if let Some(warning) = score_warning { + embed = embed.description(warning); + } + + embeds.push(embed); + attachments.extend(attachment); + // }}} + }; + + if let Err(err) = result { + analyzer + .send_discord_error(ctx, &image, &file.filename, err) + .await?; + } } else { ctx.reply("One of the attached files is not an image!") .await?; continue; } - let edited = CreateReply::default().reply(true).content(format!( - "Processed {}/{} scores", - i + 1, - files.len() - )); - - handle.edit(ctx, edited).await?; + edit_reply!(ctx, handle, "Processed {}/{} scores", i + 1, files.len()).await?; } handle.delete(ctx).await?; @@ -330,14 +154,7 @@ pub async fn delete( ctx: Context<'_>, #[description = "Id of score to delete"] ids: Vec, ) -> Result<(), Error> { - let user = match User::from_context(&ctx).await { - Ok(user) => user, - Err(_) => { - ctx.say("You are not an user in my database, sorry!") - .await?; - return Ok(()); - } - }; + let user = get_user!(&ctx); if ids.len() == 0 { ctx.reply("Empty ID list provided").await?; @@ -383,14 +200,14 @@ pub async fn show( for (i, id) in ids.iter().enumerate() { let res = query!( " - SELECT - p.id,p.chart_id,p.user_id,p.score,p.zeta_score, - p.max_recall,p.created_at,p.far_notes, - u.discord_id - FROM plays p - JOIN users u ON p.user_id = u.id - WHERE p.id=? - ", + SELECT + p.id,p.chart_id,p.user_id,p.score,p.zeta_score, + p.max_recall,p.created_at,p.far_notes, + u.discord_id + FROM plays p + JOIN users u ON p.user_id = u.id + WHERE p.id=? + ", id ) .fetch_one(&ctx.data().db) diff --git a/src/commands/stats.rs b/src/commands/stats.rs index dfdd389..872905a 100644 --- a/src/commands/stats.rs +++ b/src/commands/stats.rs @@ -17,17 +17,20 @@ use poise::{ use sqlx::query_as; use crate::{ + arcaea::chart::{Chart, Song}, + arcaea::jacket::BITMAP_IMAGE_SIZE, + arcaea::play::{DbPlay, Play}, + arcaea::score::Score, assets::{ get_b30_background, get_count_background, get_difficulty_background, get_grade_background, get_name_backgound, get_ptt_emblem, get_score_background, get_status_background, get_top_backgound, EXO_FONT, }, bitmap::{Align, BitmapCanvas, Color, LayoutDrawer, LayoutManager, Rect}, - chart::{Chart, Song}, context::{Context, Error}, - jacket::BITMAP_IMAGE_SIZE, - score::{guess_song_and_chart, DbPlay, Play, Score}, - user::{discord_it_to_discord_user, User}, + get_user, + recognition::fuzzy_song_name::guess_song_and_chart, + user::discord_it_to_discord_user, }; // {{{ Stats @@ -63,14 +66,7 @@ pub async fn best( #[description = "Name of chart to show (difficulty at the end)"] name: String, ) -> Result<(), Error> { - let user = match User::from_context(&ctx).await { - Ok(user) => user, - Err(_) => { - ctx.say("You are not an user in my database, sorry!") - .await?; - return Ok(()); - } - }; + let user = get_user!(&ctx); let (song, chart) = guess_song_and_chart(&ctx.data(), &name)?; let play = query_as!( @@ -121,14 +117,7 @@ pub async fn plot( #[description = "Name of chart to show (difficulty at the end)"] name: String, ) -> Result<(), Error> { - let user = match User::from_context(&ctx).await { - Ok(user) => user, - Err(_) => { - ctx.say("You are not an user in my database, sorry!") - .await?; - return Ok(()); - } - }; + let user = get_user!(&ctx); let (song, chart) = guess_song_and_chart(&ctx.data(), &name)?; @@ -240,14 +229,7 @@ pub async fn plot( /// Show the 30 best scores #[poise::command(prefix_command, slash_command)] pub async fn b30(ctx: Context<'_>) -> Result<(), Error> { - let user = match User::from_context(&ctx).await { - Ok(user) => user, - Err(_) => { - ctx.say("You are not an user in my database, sorry!") - .await?; - return Ok(()); - } - }; + let user = get_user!(&ctx); let plays: Vec = query_as( " diff --git a/src/commands/utils.rs b/src/commands/utils.rs new file mode 100644 index 0000000..21c81d9 --- /dev/null +++ b/src/commands/utils.rs @@ -0,0 +1,24 @@ +#[macro_export] +macro_rules! edit_reply { + ($ctx:expr, $handle:expr, $($arg:tt)*) => {{ + let content = format!($($arg)*); + let edited = CreateReply::default() + .reply(true) + .content(content); + $handle.edit($ctx, edited) + }}; +} + +#[macro_export] +macro_rules! get_user { + ($ctx:expr) => { + match crate::user::User::from_context($ctx).await { + Ok(user) => user, + Err(_) => { + $ctx.say("You are not an user in my database, sorry!") + .await?; + return Ok(()); + } + } + }; +} diff --git a/src/context.rs b/src/context.rs index 974c349..7198fef 100644 --- a/src/context.rs +++ b/src/context.rs @@ -2,7 +2,9 @@ use std::{fs, path::PathBuf}; use sqlx::SqlitePool; -use crate::{chart::SongCache, jacket::JacketCache, ocr::ui::UIMeasurements}; +use crate::{ + arcaea::chart::SongCache, arcaea::jacket::JacketCache, recognition::ui::UIMeasurements, +}; // Types used by all command functions pub type Error = Box; @@ -12,6 +14,7 @@ pub type Context<'a> = poise::Context<'a, UserContext, Error>; pub struct UserContext { #[allow(dead_code)] pub data_dir: PathBuf, + pub db: SqlitePool, pub song_cache: SongCache, pub jacket_cache: JacketCache, diff --git a/src/main.rs b/src/main.rs index 0406839..d2b113b 100644 --- a/src/main.rs +++ b/src/main.rs @@ -3,17 +3,16 @@ #![feature(let_chains)] #![feature(array_try_map)] #![feature(async_closure)] +#![feature(try_blocks)] +mod arcaea; mod assets; mod bitmap; -mod chart; mod commands; mod context; -mod image; -mod jacket; mod levenshtein; -mod ocr; -mod score; +mod recognition; +mod transform; mod user; use assets::get_data_dir; diff --git a/src/ocr/mod.rs b/src/ocr/mod.rs deleted file mode 100644 index 6bae95d..0000000 --- a/src/ocr/mod.rs +++ /dev/null @@ -1 +0,0 @@ -pub mod ui; diff --git a/src/recognition/fuzzy_song_name.rs b/src/recognition/fuzzy_song_name.rs new file mode 100644 index 0000000..4e3f4ef --- /dev/null +++ b/src/recognition/fuzzy_song_name.rs @@ -0,0 +1,127 @@ +use crate::arcaea::chart::{Chart, Difficulty, Song, SongCache}; +use crate::context::{Error, UserContext}; +use crate::levenshtein::edit_distance_with; + +/// Similar to `.strip_suffix`, but case insensitive +#[inline] +fn strip_case_insensitive_suffix<'a>(string: &'a str, suffix: &str) -> Option<&'a str> { + let suffix = suffix.to_lowercase(); + if string.to_lowercase().ends_with(&suffix) { + Some(&string[0..string.len() - suffix.len()]) + } else { + None + } +} + +// {{{ Guess song and chart by name +pub fn guess_song_and_chart<'a>( + ctx: &'a UserContext, + name: &'a str, +) -> Result<(&'a Song, &'a Chart), Error> { + let name = name.trim(); + let (name, difficulty) = name + .strip_suffix("PST") + .zip(Some(Difficulty::PST)) + .or_else(|| strip_case_insensitive_suffix(name, "[PST]").zip(Some(Difficulty::PST))) + .or_else(|| strip_case_insensitive_suffix(name, "PRS").zip(Some(Difficulty::PRS))) + .or_else(|| strip_case_insensitive_suffix(name, "[PRS]").zip(Some(Difficulty::PRS))) + .or_else(|| strip_case_insensitive_suffix(name, "FTR").zip(Some(Difficulty::FTR))) + .or_else(|| strip_case_insensitive_suffix(name, "[FTR]").zip(Some(Difficulty::FTR))) + .or_else(|| strip_case_insensitive_suffix(name, "ETR").zip(Some(Difficulty::ETR))) + .or_else(|| strip_case_insensitive_suffix(name, "[ETR]").zip(Some(Difficulty::ETR))) + .or_else(|| strip_case_insensitive_suffix(name, "BYD").zip(Some(Difficulty::BYD))) + .or_else(|| strip_case_insensitive_suffix(name, "[BYD]").zip(Some(Difficulty::BYD))) + .unwrap_or((&name, Difficulty::FTR)); + + guess_chart_name(name, &ctx.song_cache, Some(difficulty), true) +} +// }}} +// {{{ Guess chart by name +/// Runs a specialized fuzzy-search through all charts in the game. +/// +/// The `unsafe_heuristics` toggle increases the amount of resolvable queries, but might let in +/// some false positives. We turn it on for simple user-search commands, but disallow it for things +/// like OCR-generated text. +pub fn guess_chart_name<'a>( + raw_text: &str, + cache: &'a SongCache, + difficulty: Option, + unsafe_heuristics: bool, +) -> Result<(&'a Song, &'a Chart), Error> { + let raw_text = raw_text.trim(); // not quite raw 🤔 + let mut text: &str = &raw_text.to_lowercase(); + + // Cached vec used by the levenshtein distance function + let mut levenshtein_vec = Vec::with_capacity(20); + // Cached vec used to store distance calculations + let mut distance_vec = Vec::with_capacity(3); + + let (song, chart) = loop { + let mut close_enough: Vec<_> = cache + .songs() + .filter_map(|item| { + let song = &item.song; + let chart = if let Some(difficulty) = difficulty { + item.lookup(difficulty).ok()? + } else { + item.charts().next()? + }; + + let song_title = &song.lowercase_title; + distance_vec.clear(); + + let base_distance = edit_distance_with(&text, &song_title, &mut levenshtein_vec); + if base_distance < 1.max(song.title.len() / 3) { + distance_vec.push(base_distance * 10 + 2); + } + + let shortest_len = Ord::min(song_title.len(), text.len()); + if let Some(sliced) = &song_title.get(..shortest_len) + && (text.len() >= 6 || unsafe_heuristics) + { + let slice_distance = edit_distance_with(&text, sliced, &mut levenshtein_vec); + if slice_distance < 1 { + distance_vec.push(slice_distance * 10 + 3); + } + } + + if let Some(shorthand) = &chart.shorthand + && unsafe_heuristics + { + let short_distance = edit_distance_with(&text, shorthand, &mut levenshtein_vec); + if short_distance < 1.max(shorthand.len() / 3) { + distance_vec.push(short_distance * 10 + 1); + } + } + + distance_vec + .iter() + .min() + .map(|distance| (song, chart, *distance)) + }) + .collect(); + + if close_enough.len() == 0 { + if text.len() <= 1 { + Err(format!( + "Could not find match for chart name '{}' [{:?}]", + raw_text, difficulty + ))?; + } else { + text = &text[..text.len() - 1]; + } + } else if close_enough.len() == 1 { + break (close_enough[0].0, close_enough[0].1); + } else { + if unsafe_heuristics { + close_enough.sort_by_key(|(_, _, distance)| *distance); + break (close_enough[0].0, close_enough[0].1); + } else { + return Err(format!("Name '{}' is too vague to choose a match", raw_text).into()); + }; + }; + }; + + Ok((song, chart)) +} +// }}} diff --git a/src/recognition/mod.rs b/src/recognition/mod.rs new file mode 100644 index 0000000..1c552f5 --- /dev/null +++ b/src/recognition/mod.rs @@ -0,0 +1,3 @@ +pub mod fuzzy_song_name; +pub mod recognize; +pub mod ui; diff --git a/src/recognition/recognize.rs b/src/recognition/recognize.rs new file mode 100644 index 0000000..0864ef8 --- /dev/null +++ b/src/recognition/recognize.rs @@ -0,0 +1,495 @@ +use std::fmt::Display; +use std::io::Cursor; +use std::str::FromStr; +use std::{env, fs}; + +use hypertesseract::{PageSegMode, Tesseract}; +use image::{DynamicImage, GenericImageView}; +use image::{ImageBuffer, Rgba}; +use num::integer::Roots; +use poise::serenity_prelude::{CreateAttachment, CreateEmbed, CreateMessage, Timestamp}; + +use crate::arcaea::chart::{Chart, Difficulty, Song, DIFFICULTY_MENU_PIXEL_COLORS}; +use crate::arcaea::jacket::IMAGE_VEC_DIM; +use crate::arcaea::score::Score; +use crate::bitmap::{Color, Rect}; +use crate::context::{Context, Error, UserContext}; +use crate::levenshtein::edit_distance; +use crate::recognition::fuzzy_song_name::guess_chart_name; +use crate::recognition::ui::{ + ScoreScreenRect, SongSelectRect, UIMeasurementRect, UIMeasurementRect::*, +}; +use crate::transform::rotate; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ScoreKind { + SongSelect, + ScoreScreen, +} + +/// Caches a byte vector in order to prevent reallocation +#[derive(Debug, Clone, Default)] +pub struct ImageAnalyzer { + /// cached byte array + pub bytes: Vec, + + /// Last rect used to crop something + last_rect: Option<(UIMeasurementRect, Rect)>, +} + +impl ImageAnalyzer { + /// Similar to reinitializing this, but without deallocating memory + #[inline] + pub fn clear(&mut self) { + self.bytes.clear(); + self.last_rect = None; + } + + // {{{ Crop + pub fn crop_image_to_bytes(&mut self, image: &DynamicImage, rect: Rect) -> Result<(), Error> { + self.clear(); + let image = image.crop_imm(rect.x as u32, rect.y as u32, rect.width, rect.height); + let mut cursor = Cursor::new(&mut self.bytes); + image.write_to(&mut cursor, image::ImageFormat::Png)?; + + fs::write(format!("./logs/{}.png", Timestamp::now()), &self.bytes)?; + + Ok(()) + } + + #[inline] + pub fn crop(&mut self, image: &DynamicImage, rect: Rect) -> ImageBuffer, Vec> { + if env::var("SHIMMERING_DEBUG_IMGS") + .map(|s| s == "1") + .unwrap_or(false) + { + self.crop_image_to_bytes(image, rect).unwrap(); + } + + image + .crop_imm(rect.x as u32, rect.y as u32, rect.width, rect.height) + .to_rgba8() + } + + #[inline] + pub fn interp_crop( + &mut self, + ctx: &UserContext, + image: &DynamicImage, + ui_rect: UIMeasurementRect, + ) -> Result, Vec>, Error> { + let rect = ctx.ui_measurements.interpolate(ui_rect, image)?; + self.last_rect = Some((ui_rect, rect)); + Ok(self.crop(image, rect)) + } + // }}} + // {{{ Error handling + pub async fn send_discord_error( + &mut self, + ctx: Context<'_>, + image: &DynamicImage, + filename: &str, + err: impl Display, + ) -> Result<(), Error> { + let mut embed = CreateEmbed::default().description(format!( + "Nerdy info +``` +{} +```", + err + )); + + if let Some((ui_rect, rect)) = self.last_rect { + self.crop_image_to_bytes(image, rect)?; + + let bytes = std::mem::take(&mut self.bytes); + let error_attachement = CreateAttachment::bytes(bytes, filename); + + embed = embed.attachment(filename).title(format!( + "An error occurred, around the time I was extracting data for {ui_rect:?}" + )); + + let msg = CreateMessage::default().embed(embed); + ctx.channel_id() + .send_files(ctx.http(), [error_attachement], msg) + .await?; + } else { + embed = embed.title("An error occurred"); + + let msg = CreateMessage::default().embed(embed); + ctx.channel_id().send_files(ctx.http(), [], msg).await?; + } + + Ok(()) + } + // }}} + // {{{ Read score + pub fn read_score( + &mut self, + ctx: &UserContext, + note_count: Option, + image: &DynamicImage, + kind: ScoreKind, + ) -> Result, Error> { + let image = self.interp_crop( + ctx, + image, + if kind == ScoreKind::ScoreScreen { + ScoreScreen(ScoreScreenRect::Score) + } else { + SongSelect(SongSelectRect::Score) + }, + )?; + + let mut results = vec![]; + for mode in [ + PageSegMode::SingleWord, + PageSegMode::RawLine, + PageSegMode::SingleLine, + PageSegMode::SparseText, + PageSegMode::SingleBlock, + ] { + let result: Result<_, Error> = try { + // {{{ Read score using tesseract + let text = Tesseract::builder() + .language(hypertesseract::Language::English) + .whitelist_str("0123456789'/")? + .page_seg_mode(mode) + .assume_numeric_input() + .build()? + .load_image(&image)? + .recognize()? + .get_text()?; + + let text: String = text + .trim() + .chars() + .map(|char| if char == '/' { '7' } else { char }) + .filter(|char| *char != ' ' && *char != '\'') + .collect(); + + let score = u32::from_str_radix(&text, 10)?; + Score(score) + // }}} + }; + + match result { + Ok(result) => { + results.push(result.0); + } + Err(err) => { + println!("OCR score result error: {}", err); + } + } + } + + // {{{ Score correction + // The OCR sometimes fails to read "74" with the arcaea font, + // so we try to detect that and fix it + loop { + let old_stack_len = results.len(); + println!("Results {:?}", results); + results = results + .iter() + .flat_map(|result| { + // If the length is correct, we are good to go! + if *result >= 8_000_000 { + vec![*result] + } else { + let mut results = vec![]; + for i in [0, 1, 3, 4] { + let d = 10u32.pow(i); + if (*result / d) % 10 == 4 && (*result / d) % 100 != 74 { + let n = d * 10; + results.push((*result / n) * n * 10 + 7 * n + (*result % n)); + } + } + + results + } + }) + .collect(); + + if old_stack_len == results.len() { + break; + } + } + // }}} + // {{{ Return score if consensus exists + // 1. Discard scores that are known to be impossible + let mut results: Vec<_> = results + .into_iter() + .filter(|result| { + 8_000_000 <= *result + && *result <= 10_010_000 + && note_count + .map(|note_count| { + let (zeta, shinies, score_units) = Score(*result).analyse(note_count); + 8_000_000 <= zeta.0 + && zeta.0 <= 10_000_000 && shinies <= note_count + && score_units <= 2 * note_count + }) + .unwrap_or(true) + }) + .map(|r| Score(r)) + .collect(); + println!("Results {:?}", results); + + // 2. Look for consensus + for result in results.iter() { + if results.iter().filter(|e| **e == *result).count() > results.len() / 2 { + return Ok(vec![*result]); + } + } + // }}} + + // If there's no consensus, we return everything + results.sort(); + results.dedup(); + println!("Results {:?}", results); + + Ok(results) + } + // }}} + // {{{ Read difficulty + pub fn read_difficulty( + &mut self, + ctx: &UserContext, + image: &DynamicImage, + kind: ScoreKind, + ) -> Result { + if kind == ScoreKind::SongSelect { + let min = DIFFICULTY_MENU_PIXEL_COLORS + .iter() + .zip(Difficulty::DIFFICULTIES) + .min_by_key(|(c, d)| { + let rect = ctx + .ui_measurements + .interpolate( + SongSelect(match d { + Difficulty::PST => SongSelectRect::Past, + Difficulty::PRS => SongSelectRect::Present, + Difficulty::FTR => SongSelectRect::Future, + _ => SongSelectRect::Beyond, + }), + image, + ) + .unwrap(); + + // rect.width = 100; + // rect.height = 100; + // self.crop_image_to_bytes(image, rect).unwrap(); + + let image_color = image.get_pixel(rect.x as u32, rect.y as u32); + let image_color = Color::from_bytes(image_color.0); + + let distance = c.distance(image_color); + (distance * 10000.0) as u32 + }) + .unwrap(); + + return Ok(min.1); + } + + let mut ocr = Tesseract::builder() + .language(hypertesseract::Language::English) + .page_seg_mode(PageSegMode::RawLine) + .build()?; + + ocr.load_image(&self.interp_crop(ctx, image, ScoreScreen(ScoreScreenRect::Difficulty))?)? + .recognize()?; + + let text: &str = &ocr.get_text()?; + let text = text.trim().to_lowercase(); + + // let conf = t.mean_text_conf(); + // if conf < 10 && conf != 0 { + // Err(format!( + // "Difficulty text is not readable (confidence = {}, text = {}).", + // conf, text + // ))?; + // } + + let difficulty = Difficulty::DIFFICULTIES + .iter() + .zip(Difficulty::DIFFICULTY_STRINGS) + .min_by_key(|(_, difficulty_string)| edit_distance(difficulty_string, &text)) + .map(|(difficulty, _)| *difficulty) + .ok_or_else(|| format!("Unrecognised difficulty '{}'", text))?; + + Ok(difficulty) + } + // }}} + // {{{ Read score kind + pub fn read_score_kind( + &mut self, + ctx: &UserContext, + image: &DynamicImage, + ) -> Result { + let text = Tesseract::builder() + .language(hypertesseract::Language::English) + .page_seg_mode(PageSegMode::RawLine) + .build()? + .load_image(&self.interp_crop(ctx, image, PlayKind)?)? + .recognize()? + .get_text()? + .trim() + .to_string(); + + // let conf = t.mean_text_conf(); + // if conf < 10 && conf != 0 { + // Err(format!( + // "Score kind text is not readable (confidence = {}, text = {}).", + // conf, text + // ))?; + // } + + let result = if edit_distance(&text, "Result") < edit_distance(&text, "Select a song") { + ScoreKind::ScoreScreen + } else { + ScoreKind::SongSelect + }; + + Ok(result) + } + // }}} + // {{{ Read song + pub fn read_song<'a>( + &mut self, + ctx: &'a UserContext, + image: &DynamicImage, + difficulty: Difficulty, + ) -> Result<(&'a Song, &'a Chart), Error> { + let text = Tesseract::builder() + .language(hypertesseract::Language::English) + .page_seg_mode(PageSegMode::SingleLine) + .whitelist_str("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789,.()- ")? + .build()? + .load_image(&self.interp_crop(ctx, image, ScoreScreen(ScoreScreenRect::Title))?)? + .recognize()? + .get_text()?; + + // let conf = t.mean_text_conf(); + // if conf < 20 && conf != 0 { + // Err(format!( + // "Title text is not readable (confidence = {}, text = {}).", + // conf, + // raw_text.trim() + // ))?; + // } + + guess_chart_name(&text, &ctx.song_cache, Some(difficulty), false) + } + // }}} + // {{{ Read jacket + pub async fn read_jacket<'a>( + &mut self, + ctx: &'a UserContext, + image: &mut DynamicImage, + kind: ScoreKind, + difficulty: Difficulty, + ) -> Result<(&'a Song, &'a Chart), Error> { + let rect = ctx.ui_measurements.interpolate( + if kind == ScoreKind::ScoreScreen { + ScoreScreen(ScoreScreenRect::Jacket) + } else { + SongSelect(SongSelectRect::Jacket) + }, + image, + )?; + + let cropped = if kind == ScoreKind::ScoreScreen { + image.view(rect.x as u32, rect.y as u32, rect.width, rect.height) + } else { + let angle = f32::atan2(rect.height as f32, rect.width as f32); + let side = rect.height + rect.width; + rotate( + image, + Rect::new(rect.x, rect.y, side, side), + (rect.x, rect.y + rect.height as i32), + angle, + ); + + let len = (rect.width.pow(2) + rect.height.pow(2)).sqrt(); + + image.view(rect.x as u32, rect.y as u32 + rect.height, len, len) + }; + let (distance, song_id) = ctx + .jacket_cache + .recognise(&*cropped) + .ok_or_else(|| "Could not recognise jacket")?; + + if distance > (IMAGE_VEC_DIM * 3) as f32 { + Err("No known jacket looks like this")?; + } + + let item = ctx.song_cache.lookup(*song_id)?; + let chart = item.lookup(difficulty)?; + + // NOTE: this will reallocate a few strings, but it is what it is + Ok((&item.song, chart)) + } + // }}} + // {{{ Read distribution + pub fn read_distribution( + &mut self, + ctx: &UserContext, + image: &DynamicImage, + ) -> Result<(u32, u32, u32), Error> { + let mut ocr = Tesseract::builder() + .language(hypertesseract::Language::English) + .page_seg_mode(PageSegMode::SparseText) + .whitelist_str("0123456789")? + .assume_numeric_input() + .build()?; + + let mut out = [0; 3]; + + use ScoreScreenRect::*; + static KINDS: [ScoreScreenRect; 3] = [Pure, Far, Lost]; + + for i in 0..3 { + let text = ocr + .load_image(&self.interp_crop(ctx, image, ScoreScreen(KINDS[i]))?)? + .recognize()? + .get_text()?; + + println!("Raw '{}'", text.trim()); + out[i] = u32::from_str(&text.trim()).unwrap_or(0); + } + println!("Ditribution {out:?}"); + + Ok((out[0], out[1], out[2])) + } + // }}} + // {{{ Read max recall + pub fn read_max_recall<'a>( + &mut self, + ctx: &'a UserContext, + image: &DynamicImage, + ) -> Result { + let text = Tesseract::builder() + .language(hypertesseract::Language::English) + .page_seg_mode(PageSegMode::SingleLine) + .whitelist_str("0123456789")? + .assume_numeric_input() + .build()? + .load_image(&self.interp_crop(ctx, image, ScoreScreen(ScoreScreenRect::MaxRecall))?)? + .recognize()? + .get_text()?; + + let max_recall = u32::from_str_radix(text.trim(), 10)?; + + // let conf = t.mean_text_conf(); + // if conf < 20 && conf != 0 { + // Err(format!( + // "Title text is not readable (confidence = {}, text = {}).", + // conf, + // raw_text.trim() + // ))?; + // } + + Ok(max_recall) + } + // }}} +} diff --git a/src/ocr/ui.rs b/src/recognition/ui.rs similarity index 99% rename from src/ocr/ui.rs rename to src/recognition/ui.rs index 3a9a670..2c6061a 100644 --- a/src/ocr/ui.rs +++ b/src/recognition/ui.rs @@ -1,5 +1,3 @@ -#![allow(dead_code)] - use std::{fs, path::PathBuf}; use image::GenericImage; diff --git a/src/score.rs b/src/score.rs deleted file mode 100644 index 8b44321..0000000 --- a/src/score.rs +++ /dev/null @@ -1,1235 +0,0 @@ -#![allow(dead_code)] -use std::fmt::Display; -use std::fs; -use std::io::Cursor; -use std::str::FromStr; - -use image::{imageops::FilterType, DynamicImage, GenericImageView}; -use num::integer::Roots; -use num::{traits::Euclid, Rational64}; -use poise::serenity_prelude::{ - Attachment, AttachmentId, CreateAttachment, CreateEmbed, CreateEmbedAuthor, Timestamp, -}; -use sqlx::{query_as, SqlitePool}; -use tesseract::{PageSegMode, Tesseract}; - -use crate::bitmap::{Color, Rect}; -use crate::chart::{Chart, Difficulty, Song, SongCache, DIFFICULTY_MENU_PIXEL_COLORS}; -use crate::context::{Error, UserContext}; -use crate::image::rotate; -use crate::jacket::IMAGE_VEC_DIM; -use crate::levenshtein::{edit_distance, edit_distance_with}; -use crate::ocr::ui::{ScoreScreenRect, SongSelectRect, UIMeasurementRect}; -use crate::user::User; - -// {{{ Grade -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] -pub enum Grade { - EXP, - EX, - AA, - A, - B, - C, - D, -} - -impl Grade { - pub const GRADE_STRINGS: [&'static str; 7] = ["EX+", "EX", "AA", "A", "B", "C", "D"]; - pub const GRADE_SHORTHANDS: [&'static str; 7] = ["exp", "ex", "aa", "a", "b", "c", "d"]; - - #[inline] - pub fn to_index(self) -> usize { - self as usize - } -} - -impl Display for Grade { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", Self::GRADE_STRINGS[self.to_index()]) - } -} -// }}} -// {{{ Score -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] -pub struct Score(pub u32); - -impl Score { - // {{{ Score analysis - // {{{ Mini getters - #[inline] - pub fn to_zeta(self, note_count: u32) -> Score { - self.analyse(note_count).0 - } - - #[inline] - pub fn shinies(self, note_count: u32) -> u32 { - self.analyse(note_count).1 - } - - #[inline] - pub fn units(self, note_count: u32) -> u32 { - self.analyse(note_count).2 - } - // }}} - - #[inline] - pub fn increment(note_count: u32) -> Rational64 { - Rational64::new_raw(5_000_000, note_count as i64).reduced() - } - - /// Remove the contribution made by shinies to a score. - #[inline] - pub fn forget_shinies(self, note_count: u32) -> Self { - Self( - (Self::increment(note_count) * Rational64::from_integer(self.units(note_count) as i64)) - .floor() - .to_integer() as u32, - ) - } - - /// Compute a score without making a distinction between shinies and pures. That is, the given - /// value for `pures` must refer to the sum of `pure` and `shiny` notes. - /// - /// This is the simplest way to compute a score, and is useful for error analysis. - #[inline] - pub fn compute_naive(note_count: u32, pures: u32, fars: u32) -> Self { - Self( - (Self::increment(note_count) * Rational64::from_integer((2 * pures + fars) as i64)) - .floor() - .to_integer() as u32, - ) - } - - /// Returns the zeta score, the number of shinies, and the number of score units. - /// - /// Pure (and higher) notes reward two score units, far notes reward one, and lost notes reward - /// none. - pub fn analyse(self, note_count: u32) -> (Score, u32, u32) { - // Smallest possible difference between (zeta-)scores - let increment = Self::increment(note_count); - let zeta_increment = Rational64::new_raw(2_000_000, note_count as i64).reduced(); - - let score = Rational64::from_integer(self.0 as i64); - let score_units = (score / increment).floor(); - - let non_shiny_score = (score_units * increment).floor(); - let shinies = score - non_shiny_score; - - let zeta_score_units = Rational64::from_integer(2) * score_units + shinies; - let zeta_score = Score((zeta_increment * zeta_score_units).floor().to_integer() as u32); - - ( - zeta_score, - shinies.to_integer() as u32, - score_units.to_integer() as u32, - ) - } - // }}} - // {{{ Score => Play rating - #[inline] - pub fn play_rating(self, chart_constant: u32) -> i32 { - chart_constant as i32 - + if self.0 >= 10_000_000 { - 200 - } else if self.0 >= 9_800_000 { - 100 + (self.0 as i32 - 9_800_000) / 2_000 - } else { - (self.0 as i32 - 9_500_000) / 3_000 - } - } - - #[inline] - pub fn play_rating_f32(self, chart_constant: u32) -> f32 { - (self.play_rating(chart_constant)) as f32 / 100.0 - } - // }}} - // {{{ Score => grade - #[inline] - // TODO: Perhaps make an enum for this - pub fn grade(self) -> Grade { - let score = self.0; - if score > 9900000 { - Grade::EXP - } else if score > 9800000 { - Grade::EX - } else if score > 9500000 { - Grade::AA - } else if score > 9200000 { - Grade::A - } else if score > 8900000 { - Grade::B - } else if score > 8600000 { - Grade::C - } else { - Grade::D - } - } - // }}} - // {{{ Scores & Distribution => score - pub fn resolve_ambiguities( - scores: Vec, - read_distribution: Option<(u32, u32, u32)>, - note_count: u32, - ) -> Result<(Score, Option, Option<&'static str>), Error> { - if scores.len() == 0 { - return Err("No scores in list to disambiguate from.")?; - } - - let mut no_shiny_scores: Vec<_> = scores - .iter() - .map(|score| score.forget_shinies(note_count)) - .collect(); - no_shiny_scores.sort(); - no_shiny_scores.dedup(); - - if let Some(read_distribution) = read_distribution { - let pures = read_distribution.0; - let fars = read_distribution.1; - let losts = read_distribution.2; - - // Compute score from note breakdown subpairs - let pf_score = Score::compute_naive(note_count, pures, fars); - let fl_score = Score::compute_naive( - note_count, - note_count.checked_sub(losts + fars).unwrap_or(0), - fars, - ); - let lp_score = Score::compute_naive( - note_count, - pures, - note_count.checked_sub(losts + pures).unwrap_or(0), - ); - - if no_shiny_scores.len() == 1 { - // {{{ Score is fixed, gotta figure out the exact distribution - let score = *scores.iter().max().unwrap(); - - // {{{ Look for consensus among recomputed scores - // Lemma: if two computed scores agree, then so will the third - let consensus_fars = if pf_score == fl_score { - Some(fars) - } else { - // Due to the above lemma, we know all three scores must be distinct by - // this point. - // - // Our strategy is to check which of the three scores agrees with the real - // score, and to then trust the `far` value that contributed to that pair. - let no_shiny_score = score.forget_shinies(note_count); - let pf_appears = no_shiny_score == pf_score; - let fl_appears = no_shiny_score == fl_score; - let lp_appears = no_shiny_score == lp_score; - - match (pf_appears, fl_appears, lp_appears) { - (true, false, false) => Some(fars), - (false, true, false) => Some(fars), - (false, false, true) => Some(note_count - pures - losts), - _ => None, - } - }; - // }}} - - if scores.len() == 1 { - Ok((score, consensus_fars, None)) - } else { - Ok((score, consensus_fars, Some("Due to a reading error, I could not make sure the shiny-amount I calculated is accurate!"))) - } - - // }}} - } else { - // {{{ Score is not fixed, gotta figure out everything at once - // Some of the values in the note distribution are likely wrong (due to reading - // errors). To get around this, we take each pair from the triplet, compute the score - // it induces, and figure out if there's any consensus as to which value in the - // provided score list is the real one. - // - // Note that sometimes the note distribution cannot resolve any of the issues. This is - // usually the case when the disagreement comes from the number of shinies. - - // {{{ Look for consensus among recomputed scores - // Lemma: if two computed scores agree, then so will the third - let (trusted_pure_count, consensus_computed_score, consensus_fars) = if pf_score - == fl_score - { - (true, pf_score, fars) - } else { - // Due to the above lemma, we know all three scores must be distinct by - // this point. - // - // Our strategy is to check which of the three scores appear in the - // provided score list. - let pf_appears = no_shiny_scores.contains(&pf_score); - let fl_appears = no_shiny_scores.contains(&fl_score); - let lp_appears = no_shiny_scores.contains(&lp_score); - - match (pf_appears, fl_appears, lp_appears) { - (true, false, false) => (true, pf_score, fars), - (false, true, false) => (false, fl_score, fars), - (false, false, true) => (true, lp_score, note_count - pures - losts), - _ => Err(format!("Cannot disambiguate scores {:?}. Multiple disjoint note breakdown subpair scores appear on the possibility list", scores))? - } - }; - // }}} - // {{{ Collect all scores that agree with the consensus score. - let agreement: Vec<_> = scores - .iter() - .filter(|score| score.forget_shinies(note_count) == consensus_computed_score) - .filter(|score| { - let shinies = score.shinies(note_count); - shinies <= note_count && (!trusted_pure_count || shinies <= pures) - }) - .map(|v| *v) - .collect(); - // }}} - // {{{ Case 1: Disagreement in the amount of shinies! - if agreement.len() > 1 { - let agreement_shiny_amounts: Vec<_> = - agreement.iter().map(|v| v.shinies(note_count)).collect(); - - println!( - "Shiny count disagreement. Possible scores: {:?}. Possible shiny amounts: {:?}, Read distribution: {:?}", - scores, agreement_shiny_amounts, read_distribution - ); - - let msg = Some( - "Due to a reading error, I could not make sure the shiny-amount I calculated is accurate!" - ); - - Ok(( - agreement.into_iter().max().unwrap(), - Some(consensus_fars), - msg, - )) - // }}} - // {{{ Case 2: Total agreement! - } else if agreement.len() == 1 { - Ok((agreement[0], Some(consensus_fars), None)) - // }}} - // {{{ Case 3: No agreement! - } else { - Err(format!("Could not disambiguate between possible scores {:?}. Note distribution does not agree with any possibility, leading to a score of {:?}.", scores, consensus_computed_score))? - } - // }}} - // }}} - } - } else { - if no_shiny_scores.len() == 1 { - if scores.len() == 1 { - Ok((scores[0], None, None)) - } else { - Ok((scores.into_iter().max().unwrap(), None, Some("Due to a reading error, I could not make sure the shiny-amount I calculated is accurate!"))) - } - } else { - Err("Cannot disambiguate between more than one score without a note distribution.")? - } - } - } - // }}} -} - -impl Display for Score { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let score = self.0; - write!( - f, - "{}'{:0>3}'{:0>3}", - score / 1000000, - (score / 1000) % 1000, - score % 1000 - ) - } -} -// }}} -// {{{ Plays -// {{{ Create play -#[derive(Debug, Clone)] -pub struct CreatePlay { - chart_id: u32, - user_id: u32, - discord_attachment_id: Option, - - // Actual score data - score: Score, - zeta_score: Score, - - // Optional score details - max_recall: Option, - far_notes: Option, - - // Creation data - creation_ptt: Option, - creation_zeta_ptt: Option, -} - -impl CreatePlay { - #[inline] - pub fn new(score: Score, chart: &Chart, user: &User) -> Self { - Self { - chart_id: chart.id, - user_id: user.id, - discord_attachment_id: None, - score, - zeta_score: score.to_zeta(chart.note_count as u32), - max_recall: None, - far_notes: None, - // TODO: populate these - creation_ptt: None, - creation_zeta_ptt: None, - } - } - - #[inline] - pub fn with_attachment(mut self, attachment: &Attachment) -> Self { - self.discord_attachment_id = Some(attachment.id); - self - } - - #[inline] - pub fn with_fars(mut self, far_count: Option) -> Self { - self.far_notes = far_count; - self - } - - // {{{ Save - pub async fn save(self, ctx: &UserContext) -> Result { - let attachment_id = self.discord_attachment_id.map(|i| i.get() as i64); - let play = sqlx::query!( - " - INSERT INTO plays( - user_id,chart_id,discord_attachment_id, - score,zeta_score,max_recall,far_notes - ) - VALUES(?,?,?,?,?,?,?) - RETURNING id, created_at - ", - self.user_id, - self.chart_id, - attachment_id, - self.score.0, - self.zeta_score.0, - self.max_recall, - self.far_notes - ) - .fetch_one(&ctx.db) - .await?; - - Ok(Play { - id: play.id as u32, - created_at: play.created_at, - chart_id: self.chart_id, - user_id: self.user_id, - discord_attachment_id: self.discord_attachment_id, - score: self.score, - zeta_score: self.zeta_score, - max_recall: self.max_recall, - far_notes: self.far_notes, - creation_ptt: self.creation_ptt, - creation_zeta_ptt: self.creation_zeta_ptt, - }) - } - // }}} -} -// }}} -// {{{ DbPlay -/// Version of `Play` matching the format sqlx expects -#[derive(Debug, Clone, sqlx::FromRow)] -pub struct DbPlay { - pub id: i64, - pub chart_id: i64, - pub user_id: i64, - pub discord_attachment_id: Option, - pub score: i64, - pub zeta_score: i64, - pub max_recall: Option, - pub far_notes: Option, - pub created_at: chrono::NaiveDateTime, - pub creation_ptt: Option, - pub creation_zeta_ptt: Option, -} - -impl DbPlay { - #[inline] - pub fn to_play(self) -> Play { - Play { - id: self.id as u32, - chart_id: self.chart_id as u32, - user_id: self.user_id as u32, - score: Score(self.score as u32), - zeta_score: Score(self.zeta_score as u32), - max_recall: self.max_recall.map(|r| r as u32), - far_notes: self.far_notes.map(|r| r as u32), - created_at: self.created_at, - discord_attachment_id: self - .discord_attachment_id - .and_then(|s| AttachmentId::from_str(&s).ok()), - creation_ptt: self.creation_ptt.map(|r| r as u32), - creation_zeta_ptt: self.creation_zeta_ptt.map(|r| r as u32), - } - } -} -// }}} -// {{{ Play -#[derive(Debug, Clone)] -pub struct Play { - pub id: u32, - pub chart_id: u32, - pub user_id: u32, - pub discord_attachment_id: Option, - - // Actual score data - pub score: Score, - pub zeta_score: Score, - - // Optional score details - pub max_recall: Option, - pub far_notes: Option, - - // Creation data - pub created_at: chrono::NaiveDateTime, - pub creation_ptt: Option, - pub creation_zeta_ptt: Option, -} - -impl Play { - // {{{ Play => distribution - pub fn distribution(&self, note_count: u32) -> Option<(u32, u32, u32, u32)> { - if let Some(fars) = self.far_notes { - let (_, shinies, units) = self.score.analyse(note_count); - let (pures, rem) = units.checked_sub(fars)?.div_rem_euclid(&2); - if rem == 1 { - println!("The impossible happened: got an invalid amount of far notes!"); - return None; - } - - let lost = note_count.checked_sub(fars + pures)?; - let non_max_pures = pures.checked_sub(shinies)?; - Some((shinies, non_max_pures, fars, lost)) - } else { - None - } - } - // }}} - // {{{ Play => status - #[inline] - pub fn status(&self, chart: &Chart) -> Option { - let score = self.score.0; - if score >= 10_000_000 { - if score > chart.note_count + 10_000_000 { - return None; - } - - let non_max_pures = (chart.note_count + 10_000_000).checked_sub(score)?; - if non_max_pures == 0 { - Some("MPM".to_string()) - } else { - Some(format!("PM (-{})", non_max_pures)) - } - } else if let Some(distribution) = self.distribution(chart.note_count) { - // if no lost notes... - if distribution.3 == 0 { - Some(format!("FR (-{}/-{})", distribution.1, distribution.2)) - } else { - Some(format!( - "C (-{}/-{}/-{})", - distribution.1, distribution.2, distribution.3 - )) - } - } else { - None - } - } - - #[inline] - pub fn short_status(&self, chart: &Chart) -> Option { - let score = self.score.0; - if score >= 10_000_000 { - let non_max_pures = (chart.note_count + 10_000_000).checked_sub(score)?; - if non_max_pures == 0 { - Some('M') - } else { - Some('P') - } - } else if let Some(distribution) = self.distribution(chart.note_count) - && distribution.3 == 0 - { - Some('F') - } else { - Some('C') - } - } - // }}} - // {{{ Play to embed - /// Creates a discord embed for this play. - /// - /// The `index` variable is only used to create distinct filenames. - pub async fn to_embed( - &self, - db: &SqlitePool, - user: &User, - song: &Song, - chart: &Chart, - index: usize, - author: Option<&poise::serenity_prelude::User>, - ) -> Result<(CreateEmbed, Option), Error> { - // {{{ Get previously best score - let previously_best = query_as!( - DbPlay, - " - SELECT * FROM plays - WHERE user_id=? - AND chart_id=? - AND created_at Some(CreateAttachment::bytes(jacket.raw, &attachement_name)), - None => None, - }; - - let mut embed = CreateEmbed::default() - .title(format!( - "{} [{:?} {}]", - &song.title, chart.difficulty, chart.level - )) - .field("Score", format!("{} (+?)", self.score), true) - .field( - "Rating", - format!( - "{:.2} (+?)", - self.score.play_rating_f32(chart.chart_constant) - ), - true, - ) - .field("Grade", format!("{}", self.score.grade()), true) - .field("ξ-Score", format!("{} (+?)", self.zeta_score), true) - // {{{ ξ-Rating - .field( - "ξ-Rating", - { - let play_rating = self.zeta_score.play_rating_f32(chart.chart_constant); - if let Some(previous) = previously_best { - let previous_play_rating = - previous.zeta_score.play_rating_f32(chart.chart_constant); - - if play_rating >= previous_play_rating { - format!( - "{:.2} (+{})", - play_rating, - play_rating - previous_play_rating - ) - } else { - format!( - "{:.2} (-{})", - play_rating, - play_rating - previous_play_rating - ) - } - } else { - format!("{:.2}", play_rating) - } - }, - true, - ) - // }}} - .field("ξ-Grade", format!("{}", self.zeta_score.grade()), true) - .field( - "Status", - self.status(chart).unwrap_or("-".to_string()), - true, - ) - .field("Max recall", "—", true) - .field("ID", format!("{}", self.id), true); - - if icon_attachement.is_some() { - embed = embed.thumbnail(format!("attachment://{}", &attachement_name)); - } - - if let Some(user) = author { - let mut embed_author = CreateEmbedAuthor::new(&user.name); - if let Some(url) = user.avatar_url() { - embed_author = embed_author.icon_url(url); - } - - embed = embed - .timestamp(Timestamp::from_millis( - self.created_at.and_utc().timestamp_millis(), - )?) - .author(embed_author); - } - - Ok((embed, icon_attachement)) - } - // }}} -} -// }}} -// {{{ Tests -#[cfg(test)] -mod score_tests { - use super::*; - - #[test] - fn zeta_score_consistent_with_pms() { - // note counts - for note_count in 200..=2000 { - for shiny_count in 0..=note_count { - let score = Score(10000000 + shiny_count); - let zeta_score_units = 4 * (note_count - shiny_count) + 5 * shiny_count; - let (zeta_score, computed_shiny_count, units) = score.analyse(note_count); - let expected_zeta_score = Rational64::from_integer(zeta_score_units as i64) - * Rational64::new_raw(2000000, note_count as i64).reduced(); - - assert_eq!(zeta_score, Score(expected_zeta_score.to_integer() as u32)); - assert_eq!(computed_shiny_count, shiny_count); - assert_eq!(units, 2 * note_count); - } - } - } -} -// }}} -// }}} -// {{{ Score image kind -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum ScoreKind { - SongSelect, - ScoreScreen, -} -// }}} -// {{{ Recognise chart -fn strip_case_insensitive_suffix<'a>(string: &'a str, suffix: &str) -> Option<&'a str> { - let suffix = suffix.to_lowercase(); - if string.to_lowercase().ends_with(&suffix) { - Some(&string[0..string.len() - suffix.len()]) - } else { - None - } -} - -pub fn guess_song_and_chart<'a>( - ctx: &'a UserContext, - name: &'a str, -) -> Result<(&'a Song, &'a Chart), Error> { - let name = name.trim(); - let (name, difficulty) = name - .strip_suffix("PST") - .zip(Some(Difficulty::PST)) - .or_else(|| strip_case_insensitive_suffix(name, "[PST]").zip(Some(Difficulty::PST))) - .or_else(|| strip_case_insensitive_suffix(name, "PRS").zip(Some(Difficulty::PRS))) - .or_else(|| strip_case_insensitive_suffix(name, "[PRS]").zip(Some(Difficulty::PRS))) - .or_else(|| strip_case_insensitive_suffix(name, "FTR").zip(Some(Difficulty::FTR))) - .or_else(|| strip_case_insensitive_suffix(name, "[FTR]").zip(Some(Difficulty::FTR))) - .or_else(|| strip_case_insensitive_suffix(name, "ETR").zip(Some(Difficulty::ETR))) - .or_else(|| strip_case_insensitive_suffix(name, "[ETR]").zip(Some(Difficulty::ETR))) - .or_else(|| strip_case_insensitive_suffix(name, "BYD").zip(Some(Difficulty::BYD))) - .or_else(|| strip_case_insensitive_suffix(name, "[BYD]").zip(Some(Difficulty::BYD))) - .unwrap_or((&name, Difficulty::FTR)); - - guess_chart_name(name, &ctx.song_cache, Some(difficulty), true) -} - -/// Runs a specialized fuzzy-search through all charts in the game. -/// -/// The `unsafe_heuristics` toggle increases the amount of resolvable queries, but might let in -/// some false positives. We turn it on for simple user-search commands, but disallow it for things -/// like OCR-generated text. -pub fn guess_chart_name<'a>( - raw_text: &str, - cache: &'a SongCache, - difficulty: Option, - unsafe_heuristics: bool, -) -> Result<(&'a Song, &'a Chart), Error> { - let raw_text = raw_text.trim(); // not quite raw 🤔 - let mut text: &str = &raw_text.to_lowercase(); - - // Cached vec used by the levenshtein distance function - let mut levenshtein_vec = Vec::with_capacity(20); - // Cached vec used to store distance calculations - let mut distance_vec = Vec::with_capacity(3); - - let (song, chart) = loop { - let mut close_enough: Vec<_> = cache - .songs() - .filter_map(|item| { - let song = &item.song; - let chart = if let Some(difficulty) = difficulty { - item.lookup(difficulty).ok()? - } else { - item.charts().next()? - }; - - let song_title = &song.lowercase_title; - distance_vec.clear(); - - let base_distance = edit_distance_with(&text, &song_title, &mut levenshtein_vec); - if base_distance < 1.max(song.title.len() / 3) { - distance_vec.push(base_distance * 10 + 2); - } - - let shortest_len = Ord::min(song_title.len(), text.len()); - if let Some(sliced) = &song_title.get(..shortest_len) - && (text.len() >= 6 || unsafe_heuristics) - { - let slice_distance = edit_distance_with(&text, sliced, &mut levenshtein_vec); - if slice_distance < 1 { - distance_vec.push(slice_distance * 10 + 3); - } - } - - if let Some(shorthand) = &chart.shorthand - && unsafe_heuristics - { - let short_distance = edit_distance_with(&text, shorthand, &mut levenshtein_vec); - if short_distance < 1.max(shorthand.len() / 3) { - distance_vec.push(short_distance * 10 + 1); - } - } - - distance_vec - .iter() - .min() - .map(|distance| (song, chart, *distance)) - }) - .collect(); - - if close_enough.len() == 0 { - if text.len() <= 1 { - Err(format!( - "Could not find match for chart name '{}' [{:?}]", - raw_text, difficulty - ))?; - } else { - text = &text[..text.len() - 1]; - } - } else if close_enough.len() == 1 { - break (close_enough[0].0, close_enough[0].1); - } else { - if unsafe_heuristics { - close_enough.sort_by_key(|(_, _, distance)| *distance); - break (close_enough[0].0, close_enough[0].1); - } else { - Err(format!( - "Name '{}' is too vague to choose a match", - raw_text - ))?; - }; - }; - }; - - // NOTE: this will reallocate a few strings, but it is what it is - Ok((song, chart)) -} -// }}} -// {{{ Run OCR -/// Caches a byte vector in order to prevent reallocation -#[derive(Debug, Clone, Default)] -pub struct ImageCropper { - /// cached byte array - pub bytes: Vec, -} - -impl ImageCropper { - pub fn crop_image_to_bytes(&mut self, image: &DynamicImage, rect: Rect) -> Result<(), Error> { - self.bytes.clear(); - let image = image.crop_imm(rect.x as u32, rect.y as u32, rect.width, rect.height); - let mut cursor = Cursor::new(&mut self.bytes); - image.write_to(&mut cursor, image::ImageFormat::Png)?; - - fs::write(format!("./logs/{}.png", Timestamp::now()), &self.bytes)?; - - Ok(()) - } - - // {{{ Read score - pub fn read_score( - &mut self, - ctx: &UserContext, - note_count: Option, - image: &DynamicImage, - kind: ScoreKind, - ) -> Result, Error> { - println!("kind {kind:?}"); - self.crop_image_to_bytes( - &image.resize_exact(image.width(), image.height(), FilterType::Nearest), - ctx.ui_measurements.interpolate( - if kind == ScoreKind::ScoreScreen { - UIMeasurementRect::ScoreScreen(ScoreScreenRect::Score) - } else { - UIMeasurementRect::SongSelect(SongSelectRect::Score) - }, - image, - )?, - )?; - - let mut results = vec![]; - for mode in [ - PageSegMode::PsmSingleWord, - PageSegMode::PsmRawLine, - PageSegMode::PsmSingleLine, - PageSegMode::PsmSparseText, - PageSegMode::PsmSingleBlock, - ] { - let result = self.read_score_with_mode(mode, "0123456789'/"); - match result { - Ok(result) => { - results.push(result.0); - } - Err(err) => { - println!("OCR score result error: {}", err); - } - } - } - - // {{{ Score correction - // The OCR sometimes fails to read "74" with the arcaea font, - // so we try to detect that and fix it - loop { - let old_stack_len = results.len(); - println!("Results {:?}", results); - results = results - .iter() - .flat_map(|result| { - // If the length is correct, we are good to go! - if *result >= 8_000_000 { - vec![*result] - } else { - let mut results = vec![]; - for i in [0, 1, 3, 4] { - let d = 10u32.pow(i); - if (*result / d) % 10 == 4 && (*result / d) % 100 != 74 { - let n = d * 10; - results.push((*result / n) * n * 10 + 7 * n + (*result % n)); - } - } - - results - } - }) - .collect(); - - if old_stack_len == results.len() { - break; - } - } - // }}} - // {{{ Return score if consensus exists - // 1. Discard scores that are known to be impossible - let mut results: Vec<_> = results - .into_iter() - .filter(|result| { - 8_000_000 <= *result - && *result <= 10_010_000 - && note_count - .map(|note_count| { - let (zeta, shinies, score_units) = Score(*result).analyse(note_count); - 8_000_000 <= zeta.0 - && zeta.0 <= 10_000_000 && shinies <= note_count - && score_units <= 2 * note_count - }) - .unwrap_or(true) - }) - .map(|r| Score(r)) - .collect(); - println!("Results {:?}", results); - - // 2. Look for consensus - for result in results.iter() { - if results.iter().filter(|e| **e == *result).count() > results.len() / 2 { - return Ok(vec![*result]); - } - } - // }}} - - // If there's no consensus, we return everything - results.sort(); - results.dedup(); - println!("Results {:?}", results); - - Ok(results) - } - - fn read_score_with_mode(&mut self, mode: PageSegMode, whitelist: &str) -> Result { - let mut t = Tesseract::new(None, Some("eng"))? - .set_variable("classify_bln_numeric_mode", "1")? - .set_variable("tessedit_char_whitelist", whitelist)? - .set_image_from_mem(&self.bytes)?; - t.set_page_seg_mode(mode); - t = t.recognize()?; - - // Disabled, as this was super unreliable - // let conf = t.mean_text_conf(); - // if conf < 10 && conf != 0 { - // Err(format!( - // "Score text is not readable (confidence = {}, text = {}).", - // conf, - // t.get_text()?.trim() - // ))?; - // } - - let text: String = t.get_text()?.trim().to_string(); - - let text: String = text - .chars() - .map(|char| if char == '/' { '7' } else { char }) - .filter(|char| *char != ' ' && *char != '\'') - .collect(); - - let score = u32::from_str_radix(&text, 10)?; - Ok(Score(score)) - } - // }}} - // {{{ Read difficulty - pub fn read_difficulty( - &mut self, - ctx: &UserContext, - image: &DynamicImage, - kind: ScoreKind, - ) -> Result { - if kind == ScoreKind::SongSelect { - let min = DIFFICULTY_MENU_PIXEL_COLORS - .iter() - .zip(Difficulty::DIFFICULTIES) - .min_by_key(|(c, d)| { - let rect = ctx - .ui_measurements - .interpolate( - UIMeasurementRect::SongSelect(match d { - Difficulty::PST => SongSelectRect::Past, - Difficulty::PRS => SongSelectRect::Present, - Difficulty::FTR => SongSelectRect::Future, - _ => SongSelectRect::Beyond, - }), - image, - ) - .unwrap(); - - let image_color = image.get_pixel(rect.x as u32, rect.y as u32); - let image_color = Color::from_bytes(image_color.0); - - let distance = c.distance(image_color); - (distance * 10000.0) as u32 - }) - .unwrap(); - - return Ok(min.1); - } - - self.crop_image_to_bytes( - image, - ctx.ui_measurements.interpolate( - UIMeasurementRect::ScoreScreen(ScoreScreenRect::Difficulty), - image, - )?, - )?; - - let mut t = Tesseract::new(None, Some("eng"))?.set_image_from_mem(&self.bytes)?; - t.set_page_seg_mode(PageSegMode::PsmRawLine); - t = t.recognize()?; - - let text: &str = &t.get_text()?; - let text = text.trim().to_lowercase(); - - let conf = t.mean_text_conf(); - if conf < 10 && conf != 0 { - Err(format!( - "Difficulty text is not readable (confidence = {}, text = {}).", - conf, text - ))?; - } - - let difficulty = Difficulty::DIFFICULTIES - .iter() - .zip(Difficulty::DIFFICULTY_STRINGS) - .min_by_key(|(_, difficulty_string)| edit_distance(difficulty_string, &text)) - .map(|(difficulty, _)| *difficulty) - .ok_or_else(|| format!("Unrecognised difficulty '{}'", text))?; - - Ok(difficulty) - } - // }}} - // {{{ Read score kind - pub fn read_score_kind( - &mut self, - ctx: &UserContext, - image: &DynamicImage, - ) -> Result { - self.crop_image_to_bytes( - &image, - ctx.ui_measurements - .interpolate(UIMeasurementRect::PlayKind, image)?, - )?; - - let mut t = Tesseract::new(None, Some("eng"))?.set_image_from_mem(&self.bytes)?; - t.set_page_seg_mode(PageSegMode::PsmRawLine); - t = t.recognize()?; - - let text: &str = &t.get_text()?; - let text = text.trim().to_lowercase(); - - let conf = t.mean_text_conf(); - if conf < 10 && conf != 0 { - Err(format!( - "Score kind text is not readable (confidence = {}, text = {}).", - conf, text - ))?; - } - - let result = if edit_distance(&text, "Result") < edit_distance(&text, "Select a song") { - ScoreKind::ScoreScreen - } else { - ScoreKind::SongSelect - }; - - Ok(result) - } - // }}} - // {{{ Read song - pub fn read_song<'a>( - &mut self, - ctx: &'a UserContext, - image: &DynamicImage, - difficulty: Difficulty, - ) -> Result<(&'a Song, &'a Chart), Error> { - self.crop_image_to_bytes( - &image, - ctx.ui_measurements.interpolate( - UIMeasurementRect::ScoreScreen(ScoreScreenRect::Title), - image, - )?, - )?; - - let mut t = Tesseract::new(None, Some("eng"))? - .set_variable( - "tessedit_char_whitelist", - "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789,.()- ", - )? - .set_image_from_mem(&self.bytes)?; - t.set_page_seg_mode(PageSegMode::PsmSingleLine); - t = t.recognize()?; - - let raw_text: &str = &t.get_text()?; - - // let conf = t.mean_text_conf(); - // if conf < 20 && conf != 0 { - // Err(format!( - // "Title text is not readable (confidence = {}, text = {}).", - // conf, - // raw_text.trim() - // ))?; - // } - - guess_chart_name(raw_text, &ctx.song_cache, Some(difficulty), false) - } - // }}} - // {{{ Read jacket - pub async fn read_jacket<'a>( - &mut self, - ctx: &'a UserContext, - image: &mut DynamicImage, - kind: ScoreKind, - difficulty: Difficulty, - out_rect: &mut Option, - ) -> Result<(&'a Song, &'a Chart), Error> { - let rect = ctx.ui_measurements.interpolate( - if kind == ScoreKind::ScoreScreen { - UIMeasurementRect::ScoreScreen(ScoreScreenRect::Jacket) - } else { - UIMeasurementRect::SongSelect(SongSelectRect::Jacket) - }, - image, - )?; - - let cropped = if kind == ScoreKind::ScoreScreen { - *out_rect = Some(rect); - image.view(rect.x as u32, rect.y as u32, rect.width, rect.height) - } else { - let angle = f32::atan2(rect.height as f32, rect.width as f32); - let side = rect.height + rect.width; - rotate( - image, - Rect::new(rect.x, rect.y, side, side), - (rect.x, rect.y + rect.height as i32), - angle, - ); - - let len = (rect.width.pow(2) + rect.height.pow(2)).sqrt(); - - *out_rect = Some(Rect::new(rect.x, rect.y + rect.height as i32, len, len)); - image.view(rect.x as u32, rect.y as u32 + rect.height, len, len) - }; - let (distance, song_id) = ctx - .jacket_cache - .recognise(&*cropped) - .ok_or_else(|| "Could not recognise jacket")?; - - if distance > (IMAGE_VEC_DIM * 3) as f32 { - Err("No known jacket looks like this")?; - } - - let item = ctx.song_cache.lookup(*song_id)?; - let chart = item.lookup(difficulty)?; - - // NOTE: this will reallocate a few strings, but it is what it is - Ok((&item.song, chart)) - } - // }}} - // {{{ Read distribution - pub fn read_distribution( - &mut self, - ctx: &UserContext, - image: &DynamicImage, - ) -> Result<(u32, u32, u32), Error> { - let mut t = Tesseract::new(None, Some("eng"))? - .set_variable("classify_bln_numeric_mode", "1")? - .set_variable("tessedit_char_whitelist", "0123456789")?; - t.set_page_seg_mode(PageSegMode::PsmSparseText); - - self.crop_image_to_bytes( - &image, - ctx.ui_measurements - .interpolate(UIMeasurementRect::ScoreScreen(ScoreScreenRect::Pure), image)?, - )?; - - t = t.set_image_from_mem(&self.bytes)?.recognize()?; - let pure_notes = u32::from_str(&t.get_text()?.trim()).unwrap_or(0); - println!("Raw {}", t.get_text()?.trim()); - - self.crop_image_to_bytes( - &image, - ctx.ui_measurements - .interpolate(UIMeasurementRect::ScoreScreen(ScoreScreenRect::Far), image)?, - )?; - - t = t.set_image_from_mem(&self.bytes)?.recognize()?; - let far_notes = u32::from_str(&t.get_text()?.trim()).unwrap_or(0); - println!("Raw {}", t.get_text()?.trim()); - - self.crop_image_to_bytes( - &image, - ctx.ui_measurements - .interpolate(UIMeasurementRect::ScoreScreen(ScoreScreenRect::Lost), image)?, - )?; - - t = t.set_image_from_mem(&self.bytes)?.recognize()?; - let lost_notes = u32::from_str(&t.get_text()?.trim()).unwrap_or(0); - println!("Raw {}", t.get_text()?.trim()); - - Ok((pure_notes, far_notes, lost_notes)) - } - // }}} -} -// }}} diff --git a/src/image.rs b/src/transform.rs similarity index 100% rename from src/image.rs rename to src/transform.rs