forked from asklyphe-public/asklyphe
implement caching spellcheck results & other stuff
This commit is contained in:
parent
7e7079dd42
commit
270698c762
1 changed file with 80 additions and 15 deletions
|
@ -1,36 +1,37 @@
|
||||||
|
use once_cell::sync::Lazy;
|
||||||
use tracing::{debug, error};
|
use tracing::{debug, error};
|
||||||
use std::{cmp, mem};
|
use std::{cmp, mem};
|
||||||
|
use std::collections::BTreeMap;
|
||||||
|
use std::sync::Mutex;
|
||||||
// TODO: cache distances of strings/substrings
|
// TODO: cache distances of strings/substrings
|
||||||
// TODO: use binary search to find direct matches, and if that fails, calculate and cache the result in BTreeMap<word: String, closest_match: String>
|
// TODO: use binary search to find direct matches, and if that fails, calculate and cache the result in BTreeMap<word: String, closest_match: String>
|
||||||
// TODO: limit by number of words and word length, not max chars, and use code more like this for better readability & async:
|
|
||||||
/*
|
|
||||||
let words = prepare(query).split_whitespace()
|
|
||||||
.filter(|qword| qword.len() > 0)
|
|
||||||
.map(|qword| qword.to_lowercase());
|
|
||||||
for word in words { // it might need to be while let Some(word) = words.next()
|
|
||||||
tokio::spawn(levenshtein_distance(...))
|
|
||||||
}
|
|
||||||
*/
|
|
||||||
|
|
||||||
include!("./words.txt");
|
include!("./words.txt");
|
||||||
|
|
||||||
|
// a cache of misspelled words and the closest match in the database
|
||||||
|
static MATCH_CACHE: Lazy<Mutex<BTreeMap<String, Option<&str>>>> = Lazy::new(|| Mutex::new(BTreeMap::new()));
|
||||||
|
|
||||||
// max distance before no alternatives are considered
|
// max distance before no alternatives are considered
|
||||||
const MAX_DISTANCE: usize = 6;
|
const MAX_DISTANCE: usize = 6;
|
||||||
// max input text size before spellcheck is not run. on my laptop 13,000 chars of input takes around 4 seconds so this should be fine
|
// max input text size before spellcheck is not run. on my laptop 13,000 chars of input takes around 4 seconds so this should be fine
|
||||||
// update: got a larger word database and it doesn't take 4 seconds anymore lmao
|
// update: got a larger word database and it doesn't take 4 seconds anymore lmao
|
||||||
const MAX_QUERY_SIZE: usize = 1024;
|
// update 2: added binary search & caching and now 50000 chars takes ~2-4 seconds
|
||||||
|
const MAX_QUERY_WORDS: usize = 512;
|
||||||
|
// Not really a huge issue, just used to hopefully reduce the allocations made in levenshtein_distance & provide minor performance improvements
|
||||||
|
// not needed for now
|
||||||
|
// const MAX_WORD_SIZE: usize = 64;
|
||||||
|
|
||||||
pub type SpellCheckResults = Vec<SpellCheckResult>;
|
pub type SpellCheckResults = Vec<SpellCheckResult>;
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct SpellCheckResult {
|
pub struct SpellCheckResult {
|
||||||
pub orig: String,
|
pub orig: String,
|
||||||
pub correction: String,
|
pub correction: &'static str,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn check(query: &String) -> Option<SpellCheckResults> {
|
pub fn check(query: &String) -> Option<SpellCheckResults> {
|
||||||
error!("Query: {}", query);
|
error!("Query: {}", query);
|
||||||
let query: &str = {
|
/*let query: &str = {
|
||||||
if query.len() > MAX_QUERY_SIZE {
|
if query.len() > MAX_QUERY_SIZE {
|
||||||
error!("Query is too large to be spell checked, only checking first {} chars", MAX_QUERY_SIZE);
|
error!("Query is too large to be spell checked, only checking first {} chars", MAX_QUERY_SIZE);
|
||||||
query.get(0..MAX_QUERY_SIZE).unwrap()
|
query.get(0..MAX_QUERY_SIZE).unwrap()
|
||||||
|
@ -38,9 +39,66 @@ pub fn check(query: &String) -> Option<SpellCheckResults> {
|
||||||
} else {
|
} else {
|
||||||
query
|
query
|
||||||
}
|
}
|
||||||
};
|
};*/
|
||||||
|
|
||||||
let distances = prepare(query).split_whitespace()
|
// TODO: look into how 'wc -w' counts words and copy how it splits things
|
||||||
|
let query_flattened = prepare(query);
|
||||||
|
let words = query_flattened
|
||||||
|
.split_whitespace()
|
||||||
|
.filter(|word| word.len() > 0)
|
||||||
|
// .filter(|word|)
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
|
||||||
|
error!("Words in query: {}", words.len());
|
||||||
|
|
||||||
|
if (words.len() > MAX_QUERY_WORDS) {
|
||||||
|
error!("{} is too many words in query to spell check", words.len());
|
||||||
|
// return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut distances: SpellCheckResults = vec![];
|
||||||
|
for qword in words {
|
||||||
|
// error!("Word: {}", qword);
|
||||||
|
// error!("is known: {:?}", KNOWN_WORDS.binary_search(&qword));
|
||||||
|
if KNOWN_WORDS.binary_search(&qword).is_ok() {
|
||||||
|
// error!("Exact word match: {}", qword);
|
||||||
|
} else {
|
||||||
|
let mut cache = MATCH_CACHE.lock().unwrap();
|
||||||
|
if cache.contains_key(qword) {
|
||||||
|
// We don't need to tell the user if there is no suggestion for an unknown word
|
||||||
|
if (cache.get(qword).unwrap().is_some()) {
|
||||||
|
// TODO: don't push duplicate misspelled words
|
||||||
|
distances.push(SpellCheckResult{orig: qword.to_owned(), correction: cache.get(qword).unwrap().unwrap()});
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
let closest_match = KNOWN_WORDS.iter()
|
||||||
|
.map(|kword| (kword, levenshtein_distance(&qword, &kword)))
|
||||||
|
.min_by(|a, b| a.1.cmp(&b.1)).unwrap();
|
||||||
|
|
||||||
|
assert!(closest_match.1 > 0, "Found exact match not caught by binary search, is the word database properly sorted?");
|
||||||
|
|
||||||
|
if closest_match.1 <= MAX_DISTANCE {
|
||||||
|
cache.insert(qword.to_owned(), Some(*closest_match.0));
|
||||||
|
} else {
|
||||||
|
// even though there is no close enough match, cache it anyway so that it doesn't have to be looked up every time
|
||||||
|
cache.insert(qword.to_owned(), None);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// error!("End");
|
||||||
|
}
|
||||||
|
error!("Spell check results:");
|
||||||
|
for word in &distances {
|
||||||
|
debug!("instead of '{}' did you mean '{}'?", word.orig, word.correction);
|
||||||
|
}
|
||||||
|
|
||||||
|
if distances.len() > 0 {
|
||||||
|
Some(distances)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
|
/* let distances = prepare(query).split_whitespace()
|
||||||
.filter(|qword| qword.len() > 0)
|
.filter(|qword| qword.len() > 0)
|
||||||
.map(|qword| qword.to_lowercase())
|
.map(|qword| qword.to_lowercase())
|
||||||
.map(
|
.map(
|
||||||
|
@ -69,7 +127,8 @@ pub fn check(query: &String) -> Option<SpellCheckResults> {
|
||||||
Some(distances)
|
Some(distances)
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
}
|
}*/
|
||||||
|
// None
|
||||||
// vec![]
|
// vec![]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -94,11 +153,15 @@ fn prepare(s: &str) -> String {
|
||||||
.replace("7", "")
|
.replace("7", "")
|
||||||
.replace("8", "")
|
.replace("8", "")
|
||||||
.replace("9", "")
|
.replace("9", "")
|
||||||
|
.to_lowercase()
|
||||||
}
|
}
|
||||||
|
|
||||||
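// intended cost: 2 for add/remove, 1 for replace (NOTE: the recurrence below currently adds 1 for every edit type)
|
// intended cost: 2 for add/remove, 1 for replace (NOTE: the recurrence below currently adds 1 for every edit type)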
|
||||||
fn levenshtein_distance(a: &str, other: &str) -> usize {
|
fn levenshtein_distance(a: &str, other: &str) -> usize {
|
||||||
// debug!("Self: '{}', Other: '{}'", a, other);
|
// debug!("Self: '{}', Other: '{}'", a, other);
|
||||||
|
// let mut dist: &mut [usize; MAX_WORD_SIZE] = &mut [0usize; MAX_WORD_SIZE];
|
||||||
|
// let mut dist_prev: &mut [usize; MAX_WORD_SIZE] = &mut [0usize; MAX_WORD_SIZE];
|
||||||
|
|
||||||
let mut dist = vec![0usize; other.len() + 1];
|
let mut dist = vec![0usize; other.len() + 1];
|
||||||
let mut dist_prev = vec![0usize; other.len() + 1];
|
let mut dist_prev = vec![0usize; other.len() + 1];
|
||||||
|
|
||||||
|
@ -113,6 +176,8 @@ fn levenshtein_distance(a: &str, other: &str) -> usize {
|
||||||
if a.get(i - 1..i).unwrap() == other.get(j - 1..j).unwrap() {
|
if a.get(i - 1..i).unwrap() == other.get(j - 1..j).unwrap() {
|
||||||
dist[j] = dist_prev[j - 1];
|
dist[j] = dist_prev[j - 1];
|
||||||
} else {
|
} else {
|
||||||
|
// TODO: make insertion/deletion cost 1 more than replacement, presumably by adding '+ 1' to two of the three terms below
|
||||||
|
// motivation: 'honex' (from the Bee Movie script) is corrected to 'hone' instead of 'honey'; this change should also improve results in general, and it is what Wikipedia says to do (best reason)
|
||||||
dist[j] = 1 + cmp::min(
|
dist[j] = 1 + cmp::min(
|
||||||
dist.get(j - 1).unwrap(),
|
dist.get(j - 1).unwrap(),
|
||||||
cmp::min(dist_prev.get(j).unwrap(), dist_prev.get(j - 1).unwrap()));
|
cmp::min(dist_prev.get(j).unwrap(), dist_prev.get(j - 1).unwrap()));
|
||||||
|
|
Loading…
Add table
Reference in a new issue