add API for greedy algorithm and make sure it passes tests

This commit is contained in:
Pascal Kuthe 2023-07-20 16:33:14 +02:00
parent d844ab7f3b
commit 52f1712a78
No known key found for this signature in database
GPG Key ID: D715E8655AE166A6
7 changed files with 144 additions and 53 deletions

View File

@ -4,7 +4,7 @@ use crate::Matcher;
impl Matcher {
/// greedy fallback algorithm, much faster (linear time) but reported scores/indices
/// might not be the best match
pub(crate) fn fuzzy_match_greedy<const INDICES: bool, H: Char + PartialEq<N>, N: Char>(
pub(crate) fn fuzzy_match_greedy_<const INDICES: bool, H: Char + PartialEq<N>, N: Char>(
&mut self,
haystack: &[H],
needle: &[N],

View File

@ -23,7 +23,7 @@ impl Matcher {
// to avoid the slow O(mn) time complexity for large inputs. Furthermore, it allows
// us to treat needle indices as u16
let Some(mut matrix) = self.slab.alloc(&haystack[start..end], needle.len()) else {
return self.fuzzy_match_greedy::<INDICES, H, N>(
return self.fuzzy_match_greedy_::<INDICES, H, N>(
haystack,
needle,
start,

View File

@ -91,7 +91,7 @@ impl Matcher {
);
match (haystack, needle_) {
(Utf32Str::Ascii(haystack), Utf32Str::Ascii(needle)) => {
let (start, greedy_end, end) = self.prefilter_ascii(haystack, needle)?;
let (start, greedy_end, end) = self.prefilter_ascii(haystack, needle, false)?;
self.fuzzy_match_optimal::<INDICES, AsciiChar, AsciiChar>(
AsciiChar::cast(haystack),
AsciiChar::cast(needle),
@ -107,7 +107,7 @@ impl Matcher {
None
}
(Utf32Str::Unicode(haystack), Utf32Str::Ascii(needle)) => {
let (start, end) = self.prefilter_non_ascii(haystack, needle_)?;
let (start, end) = self.prefilter_non_ascii(haystack, needle_, false)?;
self.fuzzy_match_optimal::<INDICES, char, AsciiChar>(
haystack,
AsciiChar::cast(needle),
@ -118,7 +118,7 @@ impl Matcher {
)
}
(Utf32Str::Unicode(haystack), Utf32Str::Unicode(needle)) => {
let (start, end) = self.prefilter_non_ascii(haystack, needle_)?;
let (start, end) = self.prefilter_non_ascii(haystack, needle_, false)?;
self.fuzzy_match_optimal::<INDICES, char, char>(
haystack,
needle,
@ -130,30 +130,77 @@ impl Matcher {
}
}
}
/// Scores `needle` against `haystack` using the greedy matcher, without
/// reporting which haystack positions were matched.
///
/// This delegates to `fuzzy_match_greedy_impl::<false>` and hands it a
/// throwaway index buffer. The result is whatever that routine yields; in
/// particular it is `None` when the needle is longer than the haystack.
///
/// # Panics
///
/// Panics if `haystack.len()` is greater than `u32::MAX`.
pub fn fuzzy_match_greedy(
    &mut self,
    haystack: Utf32Str<'_>,
    needle: Utf32Str<'_>,
) -> Option<u16> {
    assert!(haystack.len() <= u32::MAX as usize);
    // The shared implementation always takes an index buffer; with
    // `INDICES = false` its contents are of no interest to the caller.
    let mut unused_indices = Vec::new();
    self.fuzzy_match_greedy_impl::<false>(haystack, needle, &mut unused_indices)
}
// pub fn fuzzy_indices(
// &mut self,
// query: &Query,
// mut haystack: Utf32Str<'_>,
// indices: &mut Vec<u32>,
// ) -> Option<u16> {
// if haystack.len() > u32::MAX as usize {
// haystack = &haystack[..u32::MAX as usize]
// }
// println!(
// "start {haystack:?}, {:?} {} {}",
// query.needle_chars, query.ignore_case, query.is_ascii
// );
// if self.config.use_v1 {
// if query.is_ascii && !self.config.normalize {
// self.fuzzy_matcher_v1::<true, true>(query, haystack, indices)
// } else {
// self.fuzzy_matcher_v1::<true, false>(query, haystack, indices)
// }
// } else if query.is_ascii && !self.config.normalize {
// self.fuzzy_matcher_v2::<true, true>(query, haystack, indices)
// } else {
// self.fuzzy_matcher_v2::<true, false>(query, haystack, indices)
// }
pub fn fuzzy_indices_greedy(
&mut self,
haystack: Utf32Str<'_>,
needle: Utf32Str<'_>,
indidies: &mut Vec<u32>,
) -> Option<u16> {
assert!(haystack.len() <= u32::MAX as usize);
self.fuzzy_match_greedy_impl::<true>(haystack, needle, indidies)
}
/// Shared driver behind the public greedy APIs (`fuzzy_match_greedy` and
/// `fuzzy_indices_greedy`).
///
/// Picks the prefilter that fits the haystack/needle representation, runs it
/// in greedy-only mode and forwards the window it reports to
/// `fuzzy_match_greedy_`. The `INDICES` const parameter selects whether match
/// positions are written to `indices`.
///
/// Returns `None` when the needle is longer than the haystack, when the
/// prefilter finds no candidate, or when the haystack is pure ASCII while the
/// needle is not. An empty needle matches with a score of `0` and reports no
/// indices.
///
/// # Panics
///
/// Panics if the haystack holds more than `u32::MAX` codepoints.
fn fuzzy_match_greedy_impl<const INDICES: bool>(
    &mut self,
    haystack: Utf32Str<'_>,
    needle_: Utf32Str<'_>,
    indices: &mut Vec<u32>,
) -> Option<u16> {
    if needle_.len() > haystack.len() {
        return None;
    }
    // TODO: when both lengths are equal this could short-circuit to an exact
    // match check instead of running the fuzzy matcher.
    assert!(
        haystack.len() <= u32::MAX as usize,
        "fuzzy matching is only support for up to 2^32-1 codepoints"
    );
    if needle_.len() == 0 {
        // The prefilters read the first needle char unconditionally, so an
        // empty needle must never reach them. It trivially matches anything.
        if INDICES {
            // Clear so a reused buffer does not report stale positions.
            indices.clear();
        }
        return Some(0);
    }
    match (haystack, needle_) {
        (Utf32Str::Ascii(haystack), Utf32Str::Ascii(needle)) => {
            // In greedy-only mode the third tuple element equals `greedy_end`,
            // so it carries no extra information here.
            let (start, greedy_end, _) = self.prefilter_ascii(haystack, needle, true)?;
            self.fuzzy_match_greedy_::<INDICES, AsciiChar, AsciiChar>(
                AsciiChar::cast(haystack),
                AsciiChar::cast(needle),
                start,
                greedy_end,
                indices,
            )
        }
        (Utf32Str::Ascii(_), Utf32Str::Unicode(_)) => {
            // a purely ascii haystack can never be transformed to match
            // a needle that contains non-ascii chars since we don't allow gaps
            None
        }
        (Utf32Str::Unicode(haystack), Utf32Str::Ascii(needle)) => {
            // NOTE(review): the non-ASCII prefilter only locates the first
            // needle char in greedy-only mode, so `start + 1` is passed where
            // the ASCII arm passes `greedy_end` - confirm the callee copes.
            let (start, _) = self.prefilter_non_ascii(haystack, needle_, true)?;
            self.fuzzy_match_greedy_::<INDICES, char, AsciiChar>(
                haystack,
                AsciiChar::cast(needle),
                start,
                start + 1,
                indices,
            )
        }
        (Utf32Str::Unicode(haystack), Utf32Str::Unicode(needle)) => {
            // Same window convention as the Unicode/ASCII arm above.
            let (start, _) = self.prefilter_non_ascii(haystack, needle_, true)?;
            self.fuzzy_match_greedy_::<INDICES, char, char>(
                haystack,
                needle,
                start,
                start + 1,
                indices,
            )
        }
    }
}
}

View File

@ -27,31 +27,41 @@ impl Matcher {
&self,
mut haystack: &[u8],
needle: &[u8],
only_greedy: bool,
) -> Option<(usize, usize, usize)> {
if self.config.ignore_case {
let start = find_ascii_ignore_case(needle[0], haystack)?;
let mut eager_end = start + 1;
haystack = &haystack[eager_end..];
let mut greedy_end = start + 1;
haystack = &haystack[greedy_end..];
for &c in &needle[1..] {
let idx = find_ascii_ignore_case(c, haystack)? + 1;
eager_end += idx;
greedy_end += idx;
haystack = &haystack[idx..];
}
let end = eager_end
if only_greedy {
Some((start, greedy_end, greedy_end))
} else {
let end = greedy_end
+ find_ascii_ignore_case_rev(*needle.last().unwrap(), haystack)
.map_or(0, |i| i + 1);
Some((start, eager_end, end))
Some((start, greedy_end, end))
}
} else {
let start = memchr(needle[0], haystack)?;
let mut eager_end = start + 1;
haystack = &haystack[eager_end..];
let mut greedy_end = start + 1;
haystack = &haystack[greedy_end..];
for &c in &needle[1..] {
let idx = memchr(c, haystack)? + 1;
eager_end += idx;
greedy_end += idx;
haystack = &haystack[idx..];
}
let end = eager_end + memrchr(*needle.last().unwrap(), haystack).map_or(0, |i| i + 1);
Some((start, eager_end, end))
if only_greedy {
Some((start, greedy_end, greedy_end))
} else {
let end =
greedy_end + memrchr(*needle.last().unwrap(), haystack).map_or(0, |i| i + 1);
Some((start, greedy_end, end))
}
}
}
@ -59,12 +69,16 @@ impl Matcher {
&self,
haystack: &[char],
needle: Utf32Str<'_>,
only_greedy: bool,
) -> Option<(usize, usize)> {
let needle_char = needle.get(0);
let start = haystack
.iter()
.position(|c| c.normalize(&self.config) == needle_char)?;
let needle_char = needle.last();
if only_greedy {
Some((start, start + 1))
} else {
let end = start + haystack.len()
- haystack[start..]
.iter()
@ -73,4 +87,5 @@ impl Matcher {
Some((start, end))
}
}
}

View File

@ -78,6 +78,7 @@ impl Matcher {
indices: &mut Vec<u32>,
) -> u16 {
if INDICES {
indices.clear();
indices.reserve(needle.len());
}
@ -95,15 +96,18 @@ impl Matcher {
if INDICES {
indices.push(start as u32)
}
let mut first_bonus = self.bonus_for(prev_class, haystack[0].char_class(&self.config));
let class = haystack[start].char_class(&self.config);
let mut first_bonus = self.bonus_for(prev_class, class);
let mut score = SCORE_MATCH + first_bonus * BONUS_FIRST_CHAR_MULTIPLIER;
prev_class = class;
needle_char = *needle_iter.next().unwrap_or(&needle_char);
for (i, c) in haystack[start + 1..end].iter().enumerate() {
let class = c.char_class(&self.config);
let c = c.normalize(&self.config);
if c == needle_char {
if INDICES {
indices.push(i as u32 + start as u32)
indices.push(i as u32 + start as u32 + 1)
}
let mut bonus = self.bonus_for(prev_class, class);
if consecutive == 0 {

View File

@ -35,7 +35,11 @@ pub fn assert_matches(
let haystack = Utf32Str::new(haystack, &mut haystack_buf);
score += needle.len() as u16 * SCORE_MATCH;
let res = matcher.fuzzy_indices(haystack, needle, &mut indices);
let res = if use_v1 {
matcher.fuzzy_indices_greedy(haystack, needle, &mut indices)
} else {
matcher.fuzzy_indices(haystack, needle, &mut indices)
};
let match_chars: Vec<_> = indices
.iter()
.map(|&i| haystack.get(i).normalize(&matcher.config))

View File

@ -1,5 +1,5 @@
use std::ops::{Bound, RangeBounds};
use std::slice;
use std::{fmt, slice};
/// A UTF32 encoded (char array) String that can be used as an input to fuzzy matching.
///
@ -27,7 +27,7 @@ use std::slice;
/// produce char indices (instead of utf8 offsets) anyway. With a
/// codepoint based representation like this the indices can be used
/// directly
#[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash, Debug)]
#[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash)]
pub enum Utf32Str<'a> {
/// A string represented as ASCII encoded bytes.
/// Correctness invariant: must only contain valid ASCII (<=127)
@ -116,6 +116,27 @@ impl<'a> Utf32Str<'a> {
}
}
}
/// Debug output is the string wrapped in double quotes, with every char
/// passed through `char::escape_debug`.
impl fmt::Debug for Utf32Str<'_> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("\"")?;
        // Flatten each char's escape sequence into one stream and stop at the
        // first write error.
        self.chars()
            .flat_map(|ch| ch.escape_debug())
            .try_for_each(|escaped| write!(f, "{escaped}"))?;
        f.write_str("\"")
    }
}
/// Display output is the string wrapped in double quotes, with its chars
/// written as-is (no escaping).
impl fmt::Display for Utf32Str<'_> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("\"")?;
        // Emit the chars one by one, stopping at the first write error.
        self.chars().try_for_each(|ch| write!(f, "{ch}"))?;
        f.write_str("\"")
    }
}
pub enum Chars<'a> {
Ascii(slice::Iter<'a, u8>),