forked from asklyphe-public/asklyphe

Compare commits: 7e7079dd42 ... 623b068cef (2 commits)

| Author | SHA1 | Date |
|---|---|---|
| | 623b068cef | |
| | 270698c762 | |

1 changed file with 97 additions and 17 deletions
```diff
@@ -1,36 +1,52 @@
+use once_cell::sync::Lazy;
 use tracing::{debug, error};
 use std::{cmp, mem};
-// TODO: cache distances of strings/substrings
-// TODO: use binary search to find direct matches, and if that fails, calculate and cache the result in BTreeMap<word: String, closest_match: String>
-// TODO: limit by number of words and word length, not max chars, and use code more like this for better readability & async:
-/*
-	let words = prepare(query).split_whitespace()
-		.filter(|qword| qword.len() > 0)
-		.map(|qword| qword.to_lowercase());
-	for word in words { // it might need to be while let Some(word) = words.next()
-		tokio::spawn(levenshtein_distance(...))
-	}
- */
+use std::collections::BTreeMap;
+use std::sync::Mutex;
 
+// how to generate words.txt:
+// clone https://github.com/en-wl/wordlist && cd wordlist
+// ./scowl wl --deaccent > words0.txt
+// filtered with this python script:
+// -----------------------------------
+// with open("words0.txt", "r") as f:
+// 	out = []
+// 	for line in f:
+// 		line = line.lower()
+// 		if not line in out:
+// 			out.append(line)
+// 	out.sort()
+// 	with open("words.txt", "w") as out_file:
+// 		for line in out:
+// 			out_file.write(f'{line}')
+// ------------------------------------
+// then use regex or similar to enclose every line in quotes and add comma, then add 'static KNOWN_WORDS: &[&str] = &[' to the start and '];' to the end
 include!("./words.txt");
 
+// a cache of misspelled words and the closest match in the database
+static MATCH_CACHE: Lazy<Mutex<BTreeMap<String, Option<&str>>>> = Lazy::new(|| Mutex::new(BTreeMap::new()));
+
 // max distance before no alternatives are considered
 const MAX_DISTANCE: usize = 6;
 // max input text size before spellcheck is not run. on my laptop 13,000 chars of input takes around 4 seconds so this should be fine
 // update: got a larger word database and it doesn't take 4 seconds anymore lmao
-const MAX_QUERY_SIZE: usize = 1024;
+// update 2: added binary search & caching and now 50000 chars takes ~2-4 seconds
+const MAX_QUERY_WORDS: usize = 512;
+// Not really a huge issue, just used to hopefully reduce the allocations made in levenshtein_distance & provide minor performance improvements
+// not needed for now
+// const MAX_WORD_SIZE: usize = 64;
 
 pub type SpellCheckResults = Vec<SpellCheckResult>;
 
 #[derive(Debug)]
 pub struct SpellCheckResult {
 	pub orig: String,
-	pub correction: String,
+	pub correction: &'static str,
 }
 
 pub fn check(query: &String) -> Option<SpellCheckResults> {
 	error!("Query: {}", query);
-	let query: &str = {
+	/*let query: &str = {
 		if query.len() > MAX_QUERY_SIZE {
 			error!("Query is too large to be spell checked, only checking first {} chars", MAX_QUERY_SIZE);
 			query.get(0..MAX_QUERY_SIZE).unwrap()
```
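The comment block added at the top of the file describes how `words.txt` is produced: a SCOWL dump, a small Python dedupe/sort pass, then a manual regex step to wrap every word in quotes and build the `static KNOWN_WORDS` array. A rough sketch of that same pipeline as a single throwaway Rust program is below; the file names follow the comments, but the program itself is an assumption and is not part of this commit. Keeping the output sorted matters because `check` relies on `KNOWN_WORDS.binary_search` for exact matches.

```rust
// Hypothetical generator, not part of this commit: turns words0.txt into a
// words.txt that can be pulled in with include!("./words.txt").
use std::collections::BTreeSet;
use std::fs;

fn main() -> std::io::Result<()> {
    let raw = fs::read_to_string("words0.txt")?;
    // lowercase + BTreeSet handles the dedupe and sort the Python script does by hand
    let words: BTreeSet<String> = raw
        .lines()
        .map(|line| line.trim().to_lowercase())
        .filter(|word| !word.is_empty())
        .collect();
    let mut out = String::from("static KNOWN_WORDS: &[&str] = &[\n");
    for word in &words {
        // {:?} quotes (and escapes) each word, replacing the manual regex step
        out.push_str(&format!("{:?},\n", word));
    }
    out.push_str("];\n");
    fs::write("words.txt", out)
}
```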
```diff
@@ -38,9 +54,66 @@ pub fn check(query: &String) -> Option<SpellCheckResults> {
 		} else {
 			query
 		}
-	};
+	};*/
 
-	let distances = prepare(query).split_whitespace()
+	// TODO: look into how 'wc -w' counts words and copy how it splits things
+	let query_flattened = prepare(query);
+	let words = query_flattened
+		.split_whitespace()
+		.filter(|word| word.len() > 0)
+		// .filter(|word|)
+		.collect::<Vec<_>>();
+
+	error!("Words in query: {}", words.len());
+
+	if (words.len() > MAX_QUERY_WORDS) {
+		error!("{} is too many words in query to spell check", words.len());
+		// return None;
+	}
+
+	let mut distances: SpellCheckResults = vec![];
+	for qword in words {
+		// error!("Word: {}", qword);
+		// error!("is known: {:?}", KNOWN_WORDS.binary_search(&qword));
+		if KNOWN_WORDS.binary_search(&qword).is_ok() {
+			// error!("Exact word match: {}", qword);
+		} else {
+			let mut cache = MATCH_CACHE.lock().unwrap();
+			if cache.contains_key(qword) {
+				// We don't need to tell the user if there is no suggestion for an unknown word
+				if (cache.get(qword).unwrap().is_some()) {
+					// TODO: don't push duplicate misspelled words
+					distances.push(SpellCheckResult{orig: qword.to_owned(), correction: cache.get(qword).unwrap().unwrap()});
+				}
+			} else {
+				let closest_match = KNOWN_WORDS.iter()
+					.map(|kword| (kword, levenshtein_distance(&qword, &kword)))
+					.min_by(|a, b| a.1.cmp(&b.1)).unwrap();
+
+				assert!(closest_match.1 > 0, "Found exact match not caught by binary search, is the word database properly sorted?");
+
+				if closest_match.1 <= MAX_DISTANCE {
+					cache.insert(qword.to_owned(), Some(*closest_match.0));
+				} else {
+					// even though there is no close enough match, cache it anyway so that it doesn't have to be looked up every time
+					cache.insert(qword.to_owned(), None);
+				}
+			}
+		}
+		// error!("End");
+	}
+	error!("Spell check results:");
+	for word in &distances {
+		debug!("instead of '{}' did you mean '{}'?", word.orig, word.correction);
+	}
+
+	if distances.len() > 0 {
+		Some(distances)
+	} else {
+		None
+	}
+
+/*	let distances = prepare(query).split_whitespace()
 		.filter(|qword| qword.len() > 0)
 		.map(|qword| qword.to_lowercase())
 		.map(
```
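For context on the new return type: `check` still takes `&String` and returns `Some` only when at least one correction was found, but `correction` is now a `&'static str` borrowed straight from `KNOWN_WORDS`. A hypothetical call site (the module path is an assumption, it is not shown in this diff) would look roughly like:

```rust
// Hypothetical caller; `spellcheck` as the module name is an assumption.
let query = String::from("restauraunt near me");
if let Some(results) = spellcheck::check(&query) {
    for result in results {
        // result.correction borrows from the static KNOWN_WORDS table
        println!("instead of '{}' did you mean '{}'?", result.orig, result.correction);
    }
}
```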
```diff
@@ -69,7 +142,8 @@ pub fn check(query: &String) -> Option<SpellCheckResults> {
 		Some(distances)
 	} else {
 		None
-	}
+	}*/
+	// None
 	// vec![]
 }
 
```
```diff
@@ -94,11 +168,15 @@ fn prepare(s: &str) -> String {
 		.replace("7", "")
 		.replace("8", "")
 		.replace("9", "")
+		.to_lowercase()
 }
 
 // cost of 2 for add/remove, cost of 1 for replace
 fn levenshtein_distance(a: &str, other: &str) -> usize {
 	// debug!("Self: '{}', Other: '{}'", a, other);
+	// let mut dist: &mut [usize; MAX_WORD_SIZE] = &mut [0usize; MAX_WORD_SIZE];
+	// let mut dist_prev: &mut [usize; MAX_WORD_SIZE] = &mut [0usize; MAX_WORD_SIZE];
+
 	let mut dist = vec![0usize; other.len() + 1];
 	let mut dist_prev = vec![0usize; other.len() + 1];
 
```
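The newly added (but commented out) `MAX_WORD_SIZE` lines sketch an alternative to allocating two fresh `Vec`s on every `levenshtein_distance` call: reuse fixed-size buffers instead. A standalone illustration of that idea is below; it is an assumption about where the comment is heading, not code from this commit, and it keeps the current uniform cost of 1 per edit.

```rust
// Sketch only: same single-row DP as the Vec version, but with stack buffers.
// Callers would have to guarantee other.len() < MAX_WORD_SIZE, e.g. by
// skipping overly long words in the query before calling this.
const MAX_WORD_SIZE: usize = 64;

fn levenshtein_distance_stack(a: &str, other: &str) -> usize {
    debug_assert!(other.len() < MAX_WORD_SIZE);
    let mut dist = [0usize; MAX_WORD_SIZE];
    let mut dist_prev = [0usize; MAX_WORD_SIZE];
    // row 0: turning the empty string into other[..j] takes j insertions
    for j in 0..=other.len() {
        dist_prev[j] = j;
    }
    for i in 1..=a.len() {
        dist[0] = i;
        for j in 1..=other.len() {
            dist[j] = if a.as_bytes()[i - 1] == other.as_bytes()[j - 1] {
                dist_prev[j - 1]
            } else {
                1 + dist[j - 1].min(dist_prev[j]).min(dist_prev[j - 1])
            };
        }
        std::mem::swap(&mut dist, &mut dist_prev);
    }
    // after the final swap the last computed row lives in dist_prev
    dist_prev[other.len()]
}
```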
```diff
@@ -113,6 +191,8 @@ fn levenshtein_distance(a: &str, other: &str) -> usize {
 			if a.get(i - 1..i).unwrap() == other.get(j - 1..j).unwrap() {
 				dist[j] = dist_prev[j - 1];
 			} else {
+				// TODO: make addition/subtraction 1 more expensive than replacement, presumably by adding '+ 1' to 2/3 of these
+				// motivation: honex from bee movie script is turned into hone instead of honey, this will also generally improve results & is what wikipedia says to do (best reason)
 				dist[j] = 1 + cmp::min(
 					dist.get(j - 1).unwrap(),
 					cmp::min(dist_prev.get(j).unwrap(), dist_prev.get(j - 1).unwrap()));
```
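The two TODO comments in the last hunk, together with the existing "cost of 2 for add/remove, cost of 1 for replace" note above `levenshtein_distance`, point at a weighted distance where insertions and deletions cost 2 and substitutions cost 1, so that "honex" resolves to "honey" (one substitution, cost 1) rather than "hone" (one deletion, cost 2). A minimal standalone sketch of that weighting, not the module's own implementation, could look like:

```rust
// Sketch of the weighted edit distance the TODO describes: insert/delete cost 2,
// substitution cost 1. Works on bytes, like the original function.
fn weighted_levenshtein(a: &str, b: &str) -> usize {
    let (a, b) = (a.as_bytes(), b.as_bytes());
    // row 0: building b[..j] from nothing is j insertions at cost 2 each
    let mut prev: Vec<usize> = (0..=b.len()).map(|j| j * 2).collect();
    let mut cur = vec![0usize; b.len() + 1];
    for i in 1..=a.len() {
        cur[0] = i * 2; // deleting the first i bytes of a
        for j in 1..=b.len() {
            let sub = if a[i - 1] == b[j - 1] { 0 } else { 1 };
            cur[j] = (prev[j - 1] + sub) // substitution (or match)
                .min(prev[j] + 2)        // deletion
                .min(cur[j - 1] + 2);    // insertion
        }
        std::mem::swap(&mut prev, &mut cur);
    }
    prev[b.len()]
}
```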