add API for greedly algorithm and make sure it passes tests

This commit is contained in:
Pascal Kuthe 2023-07-20 16:33:14 +02:00
parent d844ab7f3b
commit 52f1712a78
No known key found for this signature in database
GPG Key ID: D715E8655AE166A6
7 changed files with 144 additions and 53 deletions

View File

@ -4,7 +4,7 @@ use crate::Matcher;
impl Matcher { impl Matcher {
/// greedy fallback algorithm, much faster (linear time) but reported scores/indicies /// greedy fallback algorithm, much faster (linear time) but reported scores/indicies
/// might not be the best match /// might not be the best match
pub(crate) fn fuzzy_match_greedy<const INDICES: bool, H: Char + PartialEq<N>, N: Char>( pub(crate) fn fuzzy_match_greedy_<const INDICES: bool, H: Char + PartialEq<N>, N: Char>(
&mut self, &mut self,
haystack: &[H], haystack: &[H],
needle: &[N], needle: &[N],

View File

@ -23,7 +23,7 @@ impl Matcher {
// to avoid the slow O(mn) time complexity for large inputs. Furthermore, it allows // to avoid the slow O(mn) time complexity for large inputs. Furthermore, it allows
// us to treat needle indices as u16 // us to treat needle indices as u16
let Some(mut matrix) = self.slab.alloc(&haystack[start..end], needle.len()) else { let Some(mut matrix) = self.slab.alloc(&haystack[start..end], needle.len()) else {
return self.fuzzy_match_greedy::<INDICES, H, N>( return self.fuzzy_match_greedy_::<INDICES, H, N>(
haystack, haystack,
needle, needle,
start, start,

View File

@ -91,7 +91,7 @@ impl Matcher {
); );
match (haystack, needle_) { match (haystack, needle_) {
(Utf32Str::Ascii(haystack), Utf32Str::Ascii(needle)) => { (Utf32Str::Ascii(haystack), Utf32Str::Ascii(needle)) => {
let (start, greedy_end, end) = self.prefilter_ascii(haystack, needle)?; let (start, greedy_end, end) = self.prefilter_ascii(haystack, needle, false)?;
self.fuzzy_match_optimal::<INDICES, AsciiChar, AsciiChar>( self.fuzzy_match_optimal::<INDICES, AsciiChar, AsciiChar>(
AsciiChar::cast(haystack), AsciiChar::cast(haystack),
AsciiChar::cast(needle), AsciiChar::cast(needle),
@ -107,7 +107,7 @@ impl Matcher {
None None
} }
(Utf32Str::Unicode(haystack), Utf32Str::Ascii(needle)) => { (Utf32Str::Unicode(haystack), Utf32Str::Ascii(needle)) => {
let (start, end) = self.prefilter_non_ascii(haystack, needle_)?; let (start, end) = self.prefilter_non_ascii(haystack, needle_, false)?;
self.fuzzy_match_optimal::<INDICES, char, AsciiChar>( self.fuzzy_match_optimal::<INDICES, char, AsciiChar>(
haystack, haystack,
AsciiChar::cast(needle), AsciiChar::cast(needle),
@ -118,7 +118,7 @@ impl Matcher {
) )
} }
(Utf32Str::Unicode(haystack), Utf32Str::Unicode(needle)) => { (Utf32Str::Unicode(haystack), Utf32Str::Unicode(needle)) => {
let (start, end) = self.prefilter_non_ascii(haystack, needle_)?; let (start, end) = self.prefilter_non_ascii(haystack, needle_, false)?;
self.fuzzy_match_optimal::<INDICES, char, char>( self.fuzzy_match_optimal::<INDICES, char, char>(
haystack, haystack,
needle, needle,
@ -130,30 +130,77 @@ impl Matcher {
} }
} }
} }
pub fn fuzzy_match_greedy(
// pub fn fuzzy_indices( &mut self,
// &mut self, haystack: Utf32Str<'_>,
// query: &Query, needle: Utf32Str<'_>,
// mut haystack: Utf32Str<'_>, ) -> Option<u16> {
// indices: &mut Vec<u32>, assert!(haystack.len() <= u32::MAX as usize);
// ) -> Option<u16> { self.fuzzy_match_greedy_impl::<false>(haystack, needle, &mut Vec::new())
// if haystack.len() > u32::MAX as usize { }
// haystack = &haystack[..u32::MAX as usize]
// } pub fn fuzzy_indices_greedy(
// println!( &mut self,
// "start {haystack:?}, {:?} {} {}", haystack: Utf32Str<'_>,
// query.needle_chars, query.ignore_case, query.is_ascii needle: Utf32Str<'_>,
// ); indidies: &mut Vec<u32>,
// if self.config.use_v1 { ) -> Option<u16> {
// if query.is_ascii && !self.config.normalize { assert!(haystack.len() <= u32::MAX as usize);
// self.fuzzy_matcher_v1::<true, true>(query, haystack, indices) self.fuzzy_match_greedy_impl::<true>(haystack, needle, indidies)
// } else { }
// self.fuzzy_matcher_v1::<true, false>(query, haystack, indices)
// } fn fuzzy_match_greedy_impl<const INDICES: bool>(
// } else if query.is_ascii && !self.config.normalize { &mut self,
// self.fuzzy_matcher_v2::<true, true>(query, haystack, indices) haystack: Utf32Str<'_>,
// } else { needle_: Utf32Str<'_>,
// self.fuzzy_matcher_v2::<true, false>(query, haystack, indices) indidies: &mut Vec<u32>,
// } ) -> Option<u16> {
// } if needle_.len() > haystack.len() {
return None;
}
// if needle_.len() == haystack.len() {
// return self.exact_match();
// }
assert!(
haystack.len() <= u32::MAX as usize,
"fuzzy matching is only support for up to 2^32-1 codepoints"
);
match (haystack, needle_) {
(Utf32Str::Ascii(haystack), Utf32Str::Ascii(needle)) => {
let (start, greedy_end, _) = self.prefilter_ascii(haystack, needle, true)?;
self.fuzzy_match_greedy_::<INDICES, AsciiChar, AsciiChar>(
AsciiChar::cast(haystack),
AsciiChar::cast(needle),
start,
greedy_end,
indidies,
)
}
(Utf32Str::Ascii(_), Utf32Str::Unicode(_)) => {
// a purely ascii haystack can never be transformed to match
// a needle that contains non-ascii chars since we don't allow gaps
None
}
(Utf32Str::Unicode(haystack), Utf32Str::Ascii(needle)) => {
let (start, _) = self.prefilter_non_ascii(haystack, needle_, true)?;
self.fuzzy_match_greedy_::<INDICES, char, AsciiChar>(
haystack,
AsciiChar::cast(needle),
start,
start + 1,
indidies,
)
}
(Utf32Str::Unicode(haystack), Utf32Str::Unicode(needle)) => {
let (start, _) = self.prefilter_non_ascii(haystack, needle_, true)?;
self.fuzzy_match_greedy_::<INDICES, char, char>(
haystack,
needle,
start,
start + 1,
indidies,
)
}
}
}
} }

View File

@ -27,31 +27,41 @@ impl Matcher {
&self, &self,
mut haystack: &[u8], mut haystack: &[u8],
needle: &[u8], needle: &[u8],
only_greedy: bool,
) -> Option<(usize, usize, usize)> { ) -> Option<(usize, usize, usize)> {
if self.config.ignore_case { if self.config.ignore_case {
let start = find_ascii_ignore_case(needle[0], haystack)?; let start = find_ascii_ignore_case(needle[0], haystack)?;
let mut eager_end = start + 1; let mut greedy_end = start + 1;
haystack = &haystack[eager_end..]; haystack = &haystack[greedy_end..];
for &c in &needle[1..] { for &c in &needle[1..] {
let idx = find_ascii_ignore_case(c, haystack)? + 1; let idx = find_ascii_ignore_case(c, haystack)? + 1;
eager_end += idx; greedy_end += idx;
haystack = &haystack[idx..]; haystack = &haystack[idx..];
} }
let end = eager_end if only_greedy {
Some((start, greedy_end, greedy_end))
} else {
let end = greedy_end
+ find_ascii_ignore_case_rev(*needle.last().unwrap(), haystack) + find_ascii_ignore_case_rev(*needle.last().unwrap(), haystack)
.map_or(0, |i| i + 1); .map_or(0, |i| i + 1);
Some((start, eager_end, end)) Some((start, greedy_end, end))
}
} else { } else {
let start = memchr(needle[0], haystack)?; let start = memchr(needle[0], haystack)?;
let mut eager_end = start + 1; let mut greedy_end = start + 1;
haystack = &haystack[eager_end..]; haystack = &haystack[greedy_end..];
for &c in &needle[1..] { for &c in &needle[1..] {
let idx = memchr(c, haystack)? + 1; let idx = memchr(c, haystack)? + 1;
eager_end += idx; greedy_end += idx;
haystack = &haystack[idx..]; haystack = &haystack[idx..];
} }
let end = eager_end + memrchr(*needle.last().unwrap(), haystack).map_or(0, |i| i + 1); if only_greedy {
Some((start, eager_end, end)) Some((start, greedy_end, greedy_end))
} else {
let end =
greedy_end + memrchr(*needle.last().unwrap(), haystack).map_or(0, |i| i + 1);
Some((start, greedy_end, end))
}
} }
} }
@ -59,12 +69,16 @@ impl Matcher {
&self, &self,
haystack: &[char], haystack: &[char],
needle: Utf32Str<'_>, needle: Utf32Str<'_>,
only_greedy: bool,
) -> Option<(usize, usize)> { ) -> Option<(usize, usize)> {
let needle_char = needle.get(0); let needle_char = needle.get(0);
let start = haystack let start = haystack
.iter() .iter()
.position(|c| c.normalize(&self.config) == needle_char)?; .position(|c| c.normalize(&self.config) == needle_char)?;
let needle_char = needle.last(); let needle_char = needle.last();
if only_greedy {
Some((start, start + 1))
} else {
let end = start + haystack.len() let end = start + haystack.len()
- haystack[start..] - haystack[start..]
.iter() .iter()
@ -74,3 +88,4 @@ impl Matcher {
Some((start, end)) Some((start, end))
} }
} }
}

View File

@ -78,6 +78,7 @@ impl Matcher {
indices: &mut Vec<u32>, indices: &mut Vec<u32>,
) -> u16 { ) -> u16 {
if INDICES { if INDICES {
indices.clear();
indices.reserve(needle.len()); indices.reserve(needle.len());
} }
@ -95,15 +96,18 @@ impl Matcher {
if INDICES { if INDICES {
indices.push(start as u32) indices.push(start as u32)
} }
let mut first_bonus = self.bonus_for(prev_class, haystack[0].char_class(&self.config)); let class = haystack[start].char_class(&self.config);
let mut first_bonus = self.bonus_for(prev_class, class);
let mut score = SCORE_MATCH + first_bonus * BONUS_FIRST_CHAR_MULTIPLIER; let mut score = SCORE_MATCH + first_bonus * BONUS_FIRST_CHAR_MULTIPLIER;
prev_class = class;
needle_char = *needle_iter.next().unwrap_or(&needle_char);
for (i, c) in haystack[start + 1..end].iter().enumerate() { for (i, c) in haystack[start + 1..end].iter().enumerate() {
let class = c.char_class(&self.config); let class = c.char_class(&self.config);
let c = c.normalize(&self.config); let c = c.normalize(&self.config);
if c == needle_char { if c == needle_char {
if INDICES { if INDICES {
indices.push(i as u32 + start as u32) indices.push(i as u32 + start as u32 + 1)
} }
let mut bonus = self.bonus_for(prev_class, class); let mut bonus = self.bonus_for(prev_class, class);
if consecutive == 0 { if consecutive == 0 {

View File

@ -35,7 +35,11 @@ pub fn assert_matches(
let haystack = Utf32Str::new(haystack, &mut haystack_buf); let haystack = Utf32Str::new(haystack, &mut haystack_buf);
score += needle.len() as u16 * SCORE_MATCH; score += needle.len() as u16 * SCORE_MATCH;
let res = matcher.fuzzy_indices(haystack, needle, &mut indices); let res = if use_v1 {
matcher.fuzzy_indices_greedy(haystack, needle, &mut indices)
} else {
matcher.fuzzy_indices(haystack, needle, &mut indices)
};
let match_chars: Vec<_> = indices let match_chars: Vec<_> = indices
.iter() .iter()
.map(|&i| haystack.get(i).normalize(&matcher.config)) .map(|&i| haystack.get(i).normalize(&matcher.config))

View File

@ -1,5 +1,5 @@
use std::ops::{Bound, RangeBounds}; use std::ops::{Bound, RangeBounds};
use std::slice; use std::{fmt, slice};
/// A UTF32 encoded (char array) String that can be used as an input to fuzzy matching. /// A UTF32 encoded (char array) String that can be used as an input to fuzzy matching.
/// ///
@ -27,7 +27,7 @@ use std::slice;
/// produce char indices (instead of utf8 offsets) annyway. With a /// produce char indices (instead of utf8 offsets) annyway. With a
/// codepoint basec representation like this the indices can be used /// codepoint basec representation like this the indices can be used
/// directly /// directly
#[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash, Debug)] #[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash)]
pub enum Utf32Str<'a> { pub enum Utf32Str<'a> {
/// A string represented as ASCII encoded bytes. /// A string represented as ASCII encoded bytes.
/// Correctness invariant: must only contain valid ASCII (<=127) /// Correctness invariant: must only contain valid ASCII (<=127)
@ -116,6 +116,27 @@ impl<'a> Utf32Str<'a> {
} }
} }
} }
impl fmt::Debug for Utf32Str<'_> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "\"")?;
for c in self.chars() {
for c in c.escape_debug() {
write!(f, "{c}")?
}
}
write!(f, "\"")
}
}
impl fmt::Display for Utf32Str<'_> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "\"")?;
for c in self.chars() {
write!(f, "{c}")?
}
write!(f, "\"")
}
}
pub enum Chars<'a> { pub enum Chars<'a> {
Ascii(slice::Iter<'a, u8>), Ascii(slice::Iter<'a, u8>),