fix bugs in fuzzy matching

This commit is contained in:
Pascal Kuthe 2023-07-20 15:55:59 +02:00
parent e964d42849
commit 33822be2ab
No known key found for this signature in database
GPG Key ID: D715E8655AE166A6
8 changed files with 173 additions and 82 deletions

View File

@ -1,3 +1,5 @@
use std::fmt::{self, Debug, Display};
use crate::chars::case_fold::CASE_FOLDING_SIMPLE; use crate::chars::case_fold::CASE_FOLDING_SIMPLE;
use crate::MatcherConfig; use crate::MatcherConfig;
@ -7,18 +9,52 @@ use crate::MatcherConfig;
mod case_fold; mod case_fold;
mod normalize; mod normalize;
pub trait Char: Copy + Eq + Ord + std::fmt::Debug { pub trait Char: Copy + Eq + Ord + fmt::Debug + fmt::Display {
const ASCII: bool; const ASCII: bool;
fn char_class(self, config: &MatcherConfig) -> CharClass; fn char_class(self, config: &MatcherConfig) -> CharClass;
fn char_class_and_normalize(self, config: &MatcherConfig) -> (Self, CharClass); fn char_class_and_normalize(self, config: &MatcherConfig) -> (Self, CharClass);
fn normalize(self, config: &MatcherConfig) -> Self; fn normalize(self, config: &MatcherConfig) -> Self;
} }
impl Char for u8 { /// repr transparent wrapper around u8 with better formatting and PartialEq<char> implementation
// `repr(transparent)` gives this type the exact layout of `u8`; `AsciiChar::cast`
// depends on that guarantee.
#[repr(transparent)]
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub(crate) struct AsciiChar(u8);

impl AsciiChar {
    /// Reinterprets a byte slice as a slice of `AsciiChar` without copying.
    ///
    /// The bytes are not inspected or validated; every byte is wrapped as-is.
    pub fn cast(bytes: &[u8]) -> &[AsciiChar] {
        let len = bytes.len();
        let data = bytes.as_ptr().cast::<AsciiChar>();
        // SAFETY: `AsciiChar` is `#[repr(transparent)]` over `u8`, so both element
        // types share size, alignment and validity. `data`/`len` come from a live
        // `&[u8]`, and the returned borrow is tied to that same input lifetime.
        unsafe { std::slice::from_raw_parts(data, len) }
    }
}

/// Debug-formats the wrapped byte the way the corresponding `char` would be.
impl fmt::Debug for AsciiChar {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let as_char = char::from(self.0);
        Debug::fmt(&as_char, f)
    }
}

/// Displays the wrapped byte as the corresponding `char`.
impl fmt::Display for AsciiChar {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let as_char = char::from(self.0);
        Display::fmt(&as_char, f)
    }
}

/// Allows `ascii_char == some_char` comparisons.
impl PartialEq<char> for AsciiChar {
    fn eq(&self, other: &char) -> bool {
        *other == char::from(self.0)
    }
}

/// Allows `some_char == ascii_char` comparisons.
impl PartialEq<AsciiChar> for char {
    fn eq(&self, other: &AsciiChar) -> bool {
        *self == char::from(other.0)
    }
}
impl Char for AsciiChar {
const ASCII: bool = true; const ASCII: bool = true;
#[inline] #[inline]
fn char_class(self, config: &MatcherConfig) -> CharClass { fn char_class(self, config: &MatcherConfig) -> CharClass {
let c = self; let c = self.0;
// using manual if conditions instead optimizes better // using manual if conditions instead optimizes better
if c >= b'a' && c <= b'z' { if c >= b'a' && c <= b'z' {
CharClass::Lower CharClass::Lower
@ -36,23 +72,20 @@ impl Char for u8 {
} }
#[inline(always)] #[inline(always)]
fn char_class_and_normalize(self, config: &MatcherConfig) -> (Self, CharClass) { fn char_class_and_normalize(mut self, config: &MatcherConfig) -> (Self, CharClass) {
let char_class = self.char_class(config); let char_class = self.char_class(config);
let normalized = if config.ignore_case && char_class == CharClass::Upper { if config.ignore_case && char_class == CharClass::Upper {
self + 32 self.0 += 32
} else { }
self (self, char_class)
};
(normalized, char_class)
} }
#[inline(always)] #[inline(always)]
fn normalize(self, config: &MatcherConfig) -> Self { fn normalize(mut self, config: &MatcherConfig) -> Self {
if config.ignore_case && self >= b'A' && self <= b'Z' { if config.ignore_case && self.0 >= b'A' && self.0 <= b'Z' {
self + 32 self.0 += 32
} else {
self
} }
self
} }
} }
fn char_class_non_ascii(c: char) -> CharClass { fn char_class_non_ascii(c: char) -> CharClass {
@ -75,7 +108,7 @@ impl Char for char {
#[inline(always)] #[inline(always)]
fn char_class(self, config: &MatcherConfig) -> CharClass { fn char_class(self, config: &MatcherConfig) -> CharClass {
if self.is_ascii() { if self.is_ascii() {
return (self as u8).char_class(config); return AsciiChar(self as u8).char_class(config);
} }
char_class_non_ascii(self) char_class_non_ascii(self)
} }
@ -83,8 +116,8 @@ impl Char for char {
#[inline(always)] #[inline(always)]
fn char_class_and_normalize(mut self, config: &MatcherConfig) -> (Self, CharClass) { fn char_class_and_normalize(mut self, config: &MatcherConfig) -> (Self, CharClass) {
if self.is_ascii() { if self.is_ascii() {
let (c, class) = (self as u8).char_class_and_normalize(config); let (c, class) = AsciiChar(self as u8).char_class_and_normalize(config);
return (c as char, class); return (c.0 as char, class);
} }
let char_class = char_class_non_ascii(self); let char_class = char_class_non_ascii(self);
if char_class == CharClass::Upper { if char_class == CharClass::Upper {

View File

@ -32,7 +32,6 @@ impl Matcher {
let mut needle_iter = needle.iter().rev().copied(); let mut needle_iter = needle.iter().rev().copied();
let mut needle_char = needle_iter.next().unwrap(); let mut needle_char = needle_iter.next().unwrap();
for (i, &c) in haystack[start..end].iter().enumerate().rev() { for (i, &c) in haystack[start..end].iter().enumerate().rev() {
println!("{c:?} {i} {needle_char:?}");
if c == needle_char { if c == needle_char {
let Some(next_needle_char) = needle_iter.next() else { let Some(next_needle_char) = needle_iter.next() else {
start += i; start += i;

View File

@ -1,4 +1,5 @@
use std::cmp::max; use std::cmp::max;
use std::mem::take;
use crate::chars::{Char, CharClass}; use crate::chars::{Char, CharClass};
use crate::matrix::{haystack, rows_mut, Matrix, MatrixCell, MatrixRow}; use crate::matrix::{haystack, rows_mut, Matrix, MatrixCell, MatrixRow};
@ -54,8 +55,6 @@ impl Matcher {
if INDICIES { if INDICIES {
matrix.reconstruct_optimal_path(needle, start as u32, indicies, best_match_end); matrix.reconstruct_optimal_path(needle, start as u32, indicies, best_match_end);
} }
println!("{indicies:?}");
println!("{}", max_score);
Some(max_score) Some(max_score)
} }
} }
@ -70,6 +69,7 @@ impl<H: Char> Matrix<'_, H> {
where where
H: PartialEq<N>, H: PartialEq<N>,
{ {
let haystack_len = self.haystack.len() as u16;
let mut row_iter = needle.iter().copied().zip(self.row_offs.iter_mut()); let mut row_iter = needle.iter().copied().zip(self.row_offs.iter_mut());
let (mut needle_char, mut row_start) = row_iter.next().unwrap(); let (mut needle_char, mut row_start) = row_iter.next().unwrap();
@ -86,6 +86,7 @@ impl<H: Char> Matrix<'_, H> {
let mut prev_score = 0u16; let mut prev_score = 0u16;
let mut matched = false; let mut matched = false;
let first_needle_char = needle[0]; let first_needle_char = needle[0];
let mut matrix_cells = 0;
for (i, ((c, matrix_cell), bonus_)) in col_iter { for (i, ((c, matrix_cell), bonus_)) in col_iter {
let class = c.char_class(config); let class = c.char_class(config);
@ -97,23 +98,21 @@ impl<H: Char> Matrix<'_, H> {
prev_class = class; prev_class = class;
let i = i as u16; let i = i as u16;
println!("{i} {needle_char:?} {c:?}");
if *c == needle_char { if *c == needle_char {
// save the first idx of each char // save the first idx of each char
if let Some(next) = row_iter.next() { if let Some(next) = row_iter.next() {
matrix_cells += haystack_len - i;
*row_start = i; *row_start = i;
(needle_char, row_start) = next; (needle_char, row_start) = next;
} else { } else if !matched {
if !matched { matrix_cells += haystack_len - i;
*row_start = i; *row_start = i;
}
// we have atleast one match // we have atleast one match
matched = true; matched = true;
} }
} }
if *c == first_needle_char { if *c == first_needle_char {
let score = SCORE_MATCH + bonus * BONUS_FIRST_CHAR_MULTIPLIER; let score = SCORE_MATCH + bonus * BONUS_FIRST_CHAR_MULTIPLIER;
println!("start match {score}");
matrix_cell.consecutive_chars = 1; matrix_cell.consecutive_chars = 1;
if needle.len() == 1 && score > max_score { if needle.len() == 1 && score > max_score {
max_score = score; max_score = score;
@ -137,7 +136,7 @@ impl<H: Char> Matrix<'_, H> {
} }
prev_score = matrix_cell.score; prev_score = matrix_cell.score;
} }
self.cells = &mut take(&mut self.cells)[..matrix_cells as usize];
(max_score_pos, max_score, matched) (max_score_pos, max_score, matched)
} }
@ -208,7 +207,6 @@ impl<H: Char> Matrix<'_, H> {
} }
in_gap = score1 < score2; in_gap = score1 < score2;
let score = max(score1, score2); let score = max(score1, score2);
println!("{score} {score1} {score2}");
if i == needle.len() - 1 && score > max_score { if i == needle.len() - 1 && score > max_score {
max_score = score; max_score = score;
max_score_end = col as u16; max_score_end = col as u16;
@ -231,7 +229,7 @@ impl<H: Char> Matrix<'_, H> {
) { ) {
indicies.resize(needle.len(), 0); indicies.resize(needle.len(), 0);
let mut row_iter = self.rows_rev().zip(indicies.iter_mut()).peekable(); let mut row_iter = self.rows_rev().zip(indicies.iter_mut().rev()).peekable();
let (mut row, mut matched_col_idx) = row_iter.next().unwrap(); let (mut row, mut matched_col_idx) = row_iter.next().unwrap();
let mut next_row: Option<MatrixRow> = None; let mut next_row: Option<MatrixRow> = None;
let mut col = best_match_end; let mut col = best_match_end;
@ -239,7 +237,7 @@ impl<H: Char> Matrix<'_, H> {
let haystack_len = self.haystack.len() as u16; let haystack_len = self.haystack.len() as u16;
loop { loop {
let score = row.cells[col as usize].score; let score = row[col].score;
let mut score1 = 0; let mut score1 = 0;
let mut score2 = 0; let mut score2 = 0;
if let Some(&(prev_row, _)) = row_iter.peek() { if let Some(&(prev_row, _)) = row_iter.peek() {
@ -250,19 +248,20 @@ impl<H: Char> Matrix<'_, H> {
if col > row.off { if col > row.off {
score2 = row[col - 1].score; score2 = row[col - 1].score;
} }
println!("{score} {score2} {score1} {prefer_match}");
let mut new_prefer_match = row[col].consecutive_chars > 1; let mut new_prefer_match = row[col].consecutive_chars > 1;
if !new_prefer_match && col + 1 < haystack_len { if !new_prefer_match && col + 1 < haystack_len {
if let Some(next_row) = next_row { if let Some(next_row) = next_row {
new_prefer_match = next_row[col + 1].consecutive_chars > 0 if col + 1 > next_row.off {
new_prefer_match = next_row[col + 1].consecutive_chars > 0
}
} }
} }
if score > score1 && (score > score2 || score == score2 && prefer_match) { if score > score1 && (score > score2 || score == score2 && prefer_match) {
*matched_col_idx = col as u32 + start; *matched_col_idx = col as u32 + start;
next_row = Some(row); next_row = Some(row);
let Some(next) = row_iter.next() else { let Some(next) = row_iter.next() else {
break; break;
}; };
(row, matched_col_idx) = next (row, matched_col_idx) = next
} }
prefer_match = new_prefer_match; prefer_match = new_prefer_match;

View File

@ -10,11 +10,12 @@ mod prefilter;
mod score; mod score;
mod utf32_str; mod utf32_str;
// #[cfg(test)] #[cfg(test)]
// mod tests; mod tests;
pub use config::MatcherConfig; pub use config::MatcherConfig;
use crate::chars::AsciiChar;
use crate::matrix::MatrixSlab; use crate::matrix::MatrixSlab;
use crate::utf32_str::Utf32Str; use crate::utf32_str::Utf32Str;
@ -61,12 +62,29 @@ impl Matcher {
assert!(haystack.len() <= u32::MAX as usize); assert!(haystack.len() <= u32::MAX as usize);
self.fuzzy_matcher_impl::<false>(haystack, needle, &mut Vec::new()) self.fuzzy_matcher_impl::<false>(haystack, needle, &mut Vec::new())
} }
/// Fuzzy-matches `needle` against `haystack` with index reporting enabled.
///
/// Thin wrapper that forwards to `fuzzy_matcher_impl` with the `INDICIES`
/// const parameter set to `true`, handing `indidies` through as the output
/// buffer (presumably filled with the matched haystack positions — the fill
/// happens in code outside this method).
///
/// Returns the `Option<u16>` produced by `fuzzy_matcher_impl` unchanged.
///
/// # Panics
///
/// Panics if `haystack.len()` exceeds `u32::MAX`.
pub fn fuzzy_indicies(
&mut self,
haystack: Utf32Str<'_>,
needle: Utf32Str<'_>,
indidies: &mut Vec<u32>,
) -> Option<u16> {
// Checked here before delegating to the shared implementation.
assert!(haystack.len() <= u32::MAX as usize);
self.fuzzy_matcher_impl::<true>(haystack, needle, indidies)
}
fn fuzzy_matcher_impl<const INDICIES: bool>( fn fuzzy_matcher_impl<const INDICIES: bool>(
&mut self, &mut self,
haystack: Utf32Str<'_>, haystack: Utf32Str<'_>,
needle_: Utf32Str<'_>, needle_: Utf32Str<'_>,
indidies: &mut Vec<u32>, indidies: &mut Vec<u32>,
) -> Option<u16> { ) -> Option<u16> {
if needle_.len() > haystack.len() {
return None;
}
// if needle_.len() == haystack.len() {
// return self.exact_match();
// }
assert!( assert!(
haystack.len() <= u32::MAX as usize, haystack.len() <= u32::MAX as usize,
"fuzzy matching is only support for up to 2^32-1 codepoints" "fuzzy matching is only support for up to 2^32-1 codepoints"
@ -74,8 +92,13 @@ impl Matcher {
match (haystack, needle_) { match (haystack, needle_) {
(Utf32Str::Ascii(haystack), Utf32Str::Ascii(needle)) => { (Utf32Str::Ascii(haystack), Utf32Str::Ascii(needle)) => {
let (start, greedy_end, end) = self.prefilter_ascii(haystack, needle)?; let (start, greedy_end, end) = self.prefilter_ascii(haystack, needle)?;
self.fuzzy_match_optimal::<INDICIES, u8, u8>( self.fuzzy_match_optimal::<INDICIES, AsciiChar, AsciiChar>(
haystack, needle, start, greedy_end, end, indidies, AsciiChar::cast(haystack),
AsciiChar::cast(needle),
start,
greedy_end,
end,
indidies,
) )
} }
(Utf32Str::Ascii(_), Utf32Str::Unicode(_)) => { (Utf32Str::Ascii(_), Utf32Str::Unicode(_)) => {
@ -84,16 +107,15 @@ impl Matcher {
None None
} }
(Utf32Str::Unicode(haystack), Utf32Str::Ascii(needle)) => { (Utf32Str::Unicode(haystack), Utf32Str::Ascii(needle)) => {
todo!() let (start, end) = self.prefilter_non_ascii(haystack, needle_)?;
// let (start, end) = self.prefilter_non_ascii(haystack, needle_)?; self.fuzzy_match_optimal::<INDICIES, char, AsciiChar>(
// self.fuzzy_match_optimal::<INDICIES, char, u8>( haystack,
// haystack, AsciiChar::cast(needle),
// needle, start,
// start, start + 1,
// start + 1, end,
// end, indidies,
// indidies, )
// )
} }
(Utf32Str::Unicode(haystack), Utf32Str::Unicode(needle)) => { (Utf32Str::Unicode(haystack), Utf32Str::Unicode(needle)) => {
let (start, end) = self.prefilter_non_ascii(haystack, needle_)?; let (start, end) = self.prefilter_non_ascii(haystack, needle_)?;

View File

@ -82,7 +82,7 @@ pub(crate) struct MatrixCell {
impl Debug for MatrixCell { impl Debug for MatrixCell {
fn fmt(&self, f: &mut Formatter<'_>) -> Result { fn fmt(&self, f: &mut Formatter<'_>) -> Result {
(self.score, self.consecutive_chars).fmt(f) write!(f, "({}, {})", self.score, self.consecutive_chars)
} }
} }
@ -94,7 +94,7 @@ pub(crate) struct HaystackChar<C: Char> {
impl<C: Char> Debug for HaystackChar<C> { impl<C: Char> Debug for HaystackChar<C> {
fn fmt(&self, f: &mut Formatter<'_>) -> Result { fn fmt(&self, f: &mut Formatter<'_>) -> Result {
(self.char, self.bonus).fmt(f) write!(f, "({:?}, {})", self.char, self.bonus)
} }
} }
@ -103,18 +103,26 @@ pub(crate) struct MatrixRow<'a> {
pub off: u16, pub off: u16,
pub cells: &'a [MatrixCell], pub cells: &'a [MatrixCell],
} }
/// Indexing returns the cell that corresponds to column `col` in this row.
/// This is not the same as directly indexing the cells array, because every row
/// starts at a column offset which needs to be accounted for.
impl Index<u16> for MatrixRow<'_> { impl Index<u16> for MatrixRow<'_> {
type Output = MatrixCell; type Output = MatrixCell;
fn index(&self, index: u16) -> &Self::Output { #[inline(always)]
&self.cells[index as usize] fn index(&self, col: u16) -> &Self::Output {
&self.cells[(col - self.off) as usize]
} }
} }
impl Debug for MatrixRow<'_> { impl Debug for MatrixRow<'_> {
fn fmt(&self, f: &mut Formatter<'_>) -> Result { fn fmt(&self, f: &mut Formatter<'_>) -> Result {
let mut f = f.debug_list(); let mut f = f.debug_list();
f.entries((0..self.off).map(|_| &(0, 0))); f.entries((0..self.off).map(|_| &MatrixCell {
score: 0,
consecutive_chars: 0,
}));
f.entries(self.cells.iter()); f.entries(self.cells.iter());
f.finish() f.finish()
} }
@ -250,7 +258,7 @@ impl MatrixSlab {
let matrix_layout = MatrixLayout::<C>::new( let matrix_layout = MatrixLayout::<C>::new(
haystack_.len(), haystack_.len(),
needle_len, needle_len,
(haystack_.len() - needle_len / 2) * needle_len, (haystack_.len() + 1 - needle_len / 2) * needle_len,
); );
if matrix_layout.layout.size() > size_of::<MatrixData>() { if matrix_layout.layout.size() > size_of::<MatrixData>() {
return None; return None;

View File

@ -38,7 +38,8 @@ impl Matcher {
haystack = &haystack[idx..]; haystack = &haystack[idx..];
} }
let end = eager_end let end = eager_end
+ find_ascii_ignore_case_rev(*needle.last().unwrap(), haystack).unwrap_or(0); + find_ascii_ignore_case_rev(*needle.last().unwrap(), haystack)
.map_or(0, |i| i + 1);
Some((start, eager_end, end)) Some((start, eager_end, end))
} else { } else {
let start = memchr(needle[0], haystack)?; let start = memchr(needle[0], haystack)?;
@ -49,7 +50,7 @@ impl Matcher {
eager_end += idx; eager_end += idx;
haystack = &haystack[idx..]; haystack = &haystack[idx..];
} }
let end = eager_end + memrchr(*needle.last().unwrap(), haystack).unwrap_or(0); let end = eager_end + memrchr(*needle.last().unwrap(), haystack).map_or(0, |i| i + 1);
Some((start, eager_end, end)) Some((start, eager_end, end))
} }
} }
@ -64,9 +65,11 @@ impl Matcher {
.iter() .iter()
.position(|c| c.normalize(&self.config) == needle_char)?; .position(|c| c.normalize(&self.config) == needle_char)?;
let needle_char = needle.last(); let needle_char = needle.last();
let end = haystack[start..] let end = start + haystack.len()
.iter() - haystack[start..]
.position(|c| c.normalize(&self.config) == needle_char)?; .iter()
.rev()
.position(|c| c.normalize(&self.config) == needle_char)?;
Some((start, end)) Some((start, end))
} }

View File

@ -1,8 +1,10 @@
use crate::config::{ use crate::chars::Char;
use crate::score::{
BONUS_BOUNDARY, BONUS_CAMEL123, BONUS_CONSECUTIVE, BONUS_FIRST_CHAR_MULTIPLIER, BONUS_NON_WORD, BONUS_BOUNDARY, BONUS_CAMEL123, BONUS_CONSECUTIVE, BONUS_FIRST_CHAR_MULTIPLIER, BONUS_NON_WORD,
PENALTY_GAP_EXTENSION, PENALTY_GAP_START, SCORE_MATCH, PENALTY_GAP_EXTENSION, PENALTY_GAP_START, SCORE_MATCH,
}; };
use crate::{CaseMatching, Matcher, MatcherConfig}; use crate::utf32_str::Utf32Str;
use crate::{Matcher, MatcherConfig};
pub fn assert_matches( pub fn assert_matches(
use_v1: bool, use_v1: bool,
@ -12,13 +14,8 @@ pub fn assert_matches(
cases: &[(&str, &str, u32, u32, u16)], cases: &[(&str, &str, u32, u32, u16)],
) { ) {
let mut config = MatcherConfig { let mut config = MatcherConfig {
use_v1,
normalize, normalize,
case_matching: if case_sensitive { ignore_case: !case_sensitive,
CaseMatching::Respect
} else {
CaseMatching::Ignore
},
..MatcherConfig::DEFAULT ..MatcherConfig::DEFAULT
}; };
if path { if path {
@ -26,11 +23,31 @@ pub fn assert_matches(
} }
let mut matcher = Matcher::new(config); let mut matcher = Matcher::new(config);
let mut indicies = Vec::new(); let mut indicies = Vec::new();
let mut needle_buf = Vec::new();
let mut haystack_buf = Vec::new();
for &(haystack, needle, start, end, mut score) in cases { for &(haystack, needle, start, end, mut score) in cases {
score += needle.chars().count() as u16 * SCORE_MATCH; let needle = if !case_sensitive {
let query = matcher.compile_query(needle); needle.to_lowercase()
let res = matcher.fuzzy_indicies(&query, haystack, &mut indicies); } else {
assert_eq!(res, Some(score), "{needle:?} did not match {haystack:?}"); needle.to_owned()
};
let needle = Utf32Str::new(&needle, &mut needle_buf);
let haystack = Utf32Str::new(haystack, &mut haystack_buf);
score += needle.len() as u16 * SCORE_MATCH;
let res = matcher.fuzzy_indicies(haystack, needle, &mut indicies);
let match_chars: Vec<_> = indicies
.iter()
.map(|&i| haystack.get(i).normalize(&matcher.config))
.collect();
let needle_chars: Vec<_> = needle.chars().collect();
assert_eq!(
res,
Some(score),
"{needle:?} did not match {haystack:?}: {match_chars:?}"
);
assert_eq!(match_chars, needle_chars, "match indicies are incorrect");
assert_eq!( assert_eq!(
indicies.first().copied()..indicies.last().map(|&i| i + 1), indicies.first().copied()..indicies.last().map(|&i| i + 1),
Some(start)..Some(end), Some(start)..Some(end),

View File

@ -1,4 +1,5 @@
use std::ops::{Bound, RangeBounds}; use std::ops::{Bound, RangeBounds};
use std::slice;
/// A UTF32 encoded (char array) String that can be used as an input to fuzzy matching. /// A UTF32 encoded (char array) String that can be used as an input to fuzzy matching.
/// ///
@ -108,16 +109,25 @@ impl<'a> Utf32Str<'a> {
Utf32Str::Unicode(codepoints) => codepoints[codepoints.len()], Utf32Str::Unicode(codepoints) => codepoints[codepoints.len()],
} }
} }
/// Returns an iterator over the elements of this string.
///
/// The returned `Chars` variant mirrors the variant of `self`: it wraps a
/// slice iterator over either the ASCII bytes or the Unicode codepoints.
pub fn chars(&self) -> Chars<'_> {
    match self {
        Utf32Str::Ascii(ascii) => {
            let byte_iter = ascii.iter();
            Chars::Ascii(byte_iter)
        }
        Utf32Str::Unicode(unicode) => {
            let codepoint_iter = unicode.iter();
            Chars::Unicode(codepoint_iter)
        }
    }
}
} }
// impl Str for &[char] { pub enum Chars<'a> {
// type Chars; Ascii(slice::Iter<'a, u8>),
Unicode(slice::Iter<'a, char>),
}
impl<'a> Iterator for Chars<'a> {
type Item = char;
// fn chars(&self) -> Self::Chars { fn next(&mut self) -> Option<Self::Item> {
// todo!() match self {
// } Chars::Ascii(iter) => iter.next().map(|&c| c as char),
Chars::Unicode(iter) => iter.next().copied(),
// fn slice(&self, range: impl RangeBounds<u32>) { }
// todo!() }
// } }
// }