diff --git a/gitbutler-diff/src/signature.rs b/gitbutler-diff/src/signature.rs index 24b3b1bf2..36c1ebb09 100644 --- a/gitbutler-diff/src/signature.rs +++ b/gitbutler-diff/src/signature.rs @@ -14,15 +14,11 @@ //! Otherwise, neither the length prefix imposed by `(de)serialize_bytes()` nor the //! terrible compaction and optimization of `(de)serialize_tuple()` are acceptable. -// FIXME(qix-): There are a ton of identifiers in here that make no sense and -// FIXME(qix-): were copied over from the exploratory data-science-ey code that -// FIXME(qix-): need to be cleaned up. PR welcome! - const BITS: usize = 3; const SHIFT: usize = 8 - BITS; -const SIG_ENTRIES: usize = (1 << BITS) * (1 << BITS); -const SIG_BYTES: usize = SIG_ENTRIES * ::core::mem::size_of::(); -const TOTAL_BYTES: usize = SIG_BYTES + 4 + 1; // we encode a 4-byte length at the beginning, along with a version byte +const FINGERPRINT_ENTRIES: usize = (1 << BITS) * (1 << BITS); +const FINGERPRINT_BYTES: usize = FINGERPRINT_ENTRIES * ::core::mem::size_of::(); +const TOTAL_BYTES: usize = 1 + 4 + FINGERPRINT_BYTES; // we encode a version byte and a 4-byte length at the beginning // NOTE: This is not efficient if `SigBucket` is 1 byte (u8). // NOTE: If `SigBucket` is changed to a u8, then the implementation @@ -80,42 +76,44 @@ impl Signature { /// about the signature or the original file contents. /// /// Do not use for any security-related purposes. - pub fn score_str>(&self, s: S) -> f64 { + pub fn score_str>(&self, other: S) -> f64 { if self.0[0] != 0 { panic!("unsupported signature version"); } - let original_length = u32::from_le_bytes(self.0[1..5].try_into().unwrap()); - - let s = s.as_ref(); - - let s_s: String = s.chars().filter(|&x| !char::is_whitespace(x)).collect(); - let s = s_s.as_bytes(); - - if original_length < 2 || s.len() < 2 { + let original_length = u32::from_le_bytes(self.0[1..5].try_into().expect("invalid length")); + if original_length < 2 { return 0.0; } - let mut intersection_size = 0usize; + let other = other.as_ref(); + let other_string: String = other.chars().filter(|&x| !char::is_whitespace(x)).collect(); + let other = other_string.as_bytes(); - let mut wb = self.bucket_iter().collect::>(); + if other.len() < 2 { + return 0.0; + } - for (b1, b2) in bigrams(s) { - let b1 = b1 >> SHIFT; - let b2 = b2 >> SHIFT; - let ix = ((b1 as usize) << BITS) | (b2 as usize); - if wb[ix] > 0 { - wb[ix] = wb[ix].saturating_sub(1); - intersection_size += 1; + let mut matching_bigrams: usize = 0; + + let mut self_buckets = self.bucket_iter().collect::>(); + + for (left, right) in bigrams(other) { + let left = left >> SHIFT; + let right = right >> SHIFT; + let index = ((left as usize) << BITS) | (right as usize); + if self_buckets[index] > 0 { + self_buckets[index] = self_buckets[index] - 1; + matching_bigrams += 1; } } - (2 * intersection_size) as f64 / (original_length as usize + s.len() - 2) as f64 + (2 * matching_bigrams) as f64 / (original_length as usize + other.len() - 2) as f64 } fn bucket_iter(&self) -> impl Iterator + '_ { unsafe { - self.0[(TOTAL_BYTES - SIG_BYTES)..] + self.0[(TOTAL_BYTES - FINGERPRINT_BYTES)..] .as_chunks_unchecked::<{ ::core::mem::size_of::() }>() .iter() .map(|ch: &[u8; ::core::mem::size_of::()]| SigBucket::from_le_bytes(*ch)) @@ -125,45 +123,50 @@ impl Signature { impl> From for Signature { #[inline] - fn from(s: S) -> Self { - let s = s.as_ref(); + fn from(source: S) -> Self { + let source = source.as_ref(); + let source_string: String = source + .chars() + .filter(|&x| !char::is_whitespace(x)) + .collect(); + let source = source_string.as_bytes(); - let a_s: String = s.chars().filter(|&x| !char::is_whitespace(x)).collect(); - let a = a_s.as_bytes(); - - let a_len: u32 = a + let source_len: u32 = source .len() .try_into() .expect("strings with a byte-length above u32::MAX are not supported"); - let mut a_res = [0; TOTAL_BYTES]; - a_res[0] = 0; // version byte - a_res[1..5].copy_from_slice(&a_len.to_le_bytes()); // length + let mut bytes = [0; TOTAL_BYTES]; + bytes[0] = 0; // version byte (0) + bytes[1..5].copy_from_slice(&source_len.to_le_bytes()); // next 4 bytes are the length - if a_len >= 2 { - let mut a_bigrams = [0 as SigBucket; SIG_ENTRIES]; + if source_len >= 2 { + let mut buckets = [0 as SigBucket; FINGERPRINT_ENTRIES]; - for (b1, b2) in bigrams(a) { - let b1 = b1 >> SHIFT; - let b2 = b2 >> SHIFT; - let encoded_bigram = ((b1 as usize) << BITS) | (b2 as usize); - a_bigrams[encoded_bigram] = a_bigrams[encoded_bigram].saturating_add(1); + for (left, right) in bigrams(source) { + let left = left >> SHIFT; + let right = right >> SHIFT; + let index = ((left as usize) << BITS) | (right as usize); + buckets[index] = buckets[index].saturating_add(1); } // NOTE: This is not efficient if `SigBucket` is 1 byte (u8). - let mut offset = TOTAL_BYTES - SIG_BYTES; - for bucket in a_bigrams { + let mut offset = TOTAL_BYTES - FINGERPRINT_BYTES; + for bucket in buckets { let start = offset; let end = start + ::core::mem::size_of::(); - a_res[start..end].copy_from_slice(&bucket.to_le_bytes()); + bytes[start..end].copy_from_slice(&bucket.to_le_bytes()); offset = end; } } - Self(a_res) + Self(bytes) } } +/// Copies the passed bytes twice and zips them together with a one-byte offset. +/// This produces an iterator of the bigrams (pairs of consecutive bytes) in the input. +/// For example, the bigrams of 1, 2, 3, 4, 5 would be (1, 2), (2, 3), (3, 4), (4, 5). #[inline] fn bigrams(s: &[u8]) -> impl Iterator + '_ { s.iter().copied().zip(s.iter().skip(1).copied()) @@ -173,61 +176,47 @@ fn bigrams(s: &[u8]) -> impl Iterator + '_ { mod tests { use super::*; + macro_rules! assert_score { + ($sig:ident, $s:expr, $e:expr) => { + let score = $sig.score_str($s); + if (score - $e).abs() >= 0.1 { + panic!( + "expected score of {} for string {:?}, got {}", + $e, $s, score + ); + } + }; + } + #[test] fn score_signature() { let sig = Signature::from("hello world"); - macro_rules! assert_score { - ($s:expr, $e:expr) => { - if (sig.score_str($s) - $e).abs() >= 0.1 { - panic!( - "expected score of {} for string {:?}, got {}", - $e, - $s, - sig.score_str($s) - ); - } - }; - } - // NOTE: The scores here are not exact, but are close enough // to be useful for testing purposes, hence why some have the same // "score" but different strings. - assert_score!("hello world", 1.0); - assert_score!("hello world!", 0.95); - assert_score!("hello world!!", 0.9); - assert_score!("hello world!!!", 0.85); - assert_score!("hello world!!!!", 0.8); - assert_score!("hello world!!!!!", 0.75); - assert_score!("hello world!!!!!!", 0.7); - assert_score!("hello world!!!!!!!", 0.65); - assert_score!("hello world!!!!!!!!", 0.62); - assert_score!("hello world!!!!!!!!!", 0.6); - assert_score!("hello world!!!!!!!!!!", 0.55); + assert_score!(sig, "hello world", 1.0); + assert_score!(sig, "hello world!", 0.95); + assert_score!(sig, "hello world!!", 0.9); + assert_score!(sig, "hello world!!!", 0.85); + assert_score!(sig, "hello world!!!!", 0.8); + assert_score!(sig, "hello world!!!!!", 0.75); + assert_score!(sig, "hello world!!!!!!", 0.7); + assert_score!(sig, "hello world!!!!!!!", 0.65); + assert_score!(sig, "hello world!!!!!!!!", 0.62); + assert_score!(sig, "hello world!!!!!!!!!", 0.6); + assert_score!(sig, "hello world!!!!!!!!!!", 0.55); } #[test] fn score_ignores_whitespace() { let sig = Signature::from("hello world"); - macro_rules! assert_score { - ($s:expr, $e:expr) => { - if (sig.score_str($s) - $e).abs() >= 0.1 { - panic!( - "expected score of {} for string {:?}, got {}", - $e, - $s, - sig.score_str($s) - ); - } - }; - } - - assert_score!("hello world", 1.0); - assert_score!("hello world ", 1.0); - assert_score!("hello\nworld ", 1.0); - assert_score!("hello\n\tworld ", 1.0); - assert_score!("\t\t hel lo\n\two rld \t\t", 1.0); + assert_score!(sig, "hello world", 1.0); + assert_score!(sig, "hello world ", 1.0); + assert_score!(sig, "hello\nworld ", 1.0); + assert_score!(sig, "hello\n\tworld ", 1.0); + assert_score!(sig, "\t\t hel lo\n\two rld \t\t", 1.0); } const TEXT1: &str = include_str!("../fixture/text1.txt");