Merge pull request #2934 from MichaelOwenDyer/Refactor-signature-module

Refactor diff/signature.rs
This commit is contained in:
Josh Junon 2024-02-29 11:49:36 +01:00 committed by GitHub
commit e66b19f1b1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -14,15 +14,11 @@
//! Otherwise, neither the length prefix imposed by `(de)serialize_bytes()` nor the
//! terrible compaction and optimization of `(de)serialize_tuple()` are acceptable.
// FIXME(qix-): There are a ton of identifiers in here that make no sense and
// FIXME(qix-): were copied over from the exploratory data-science-ey code that
// FIXME(qix-): need to be cleaned up. PR welcome!
const BITS: usize = 3;
const SHIFT: usize = 8 - BITS;
const SIG_ENTRIES: usize = (1 << BITS) * (1 << BITS);
const SIG_BYTES: usize = SIG_ENTRIES * ::core::mem::size_of::<SigBucket>();
const TOTAL_BYTES: usize = SIG_BYTES + 4 + 1; // we encode a 4-byte length at the beginning, along with a version byte
const FINGERPRINT_ENTRIES: usize = (1 << BITS) * (1 << BITS);
const FINGERPRINT_BYTES: usize = FINGERPRINT_ENTRIES * ::core::mem::size_of::<SigBucket>();
const TOTAL_BYTES: usize = 1 + 4 + FINGERPRINT_BYTES; // we encode a version byte and a 4-byte length at the beginning
// NOTE: This is not efficient if `SigBucket` is 1 byte (u8).
// NOTE: If `SigBucket` is changed to a u8, then the implementation
@ -80,42 +76,44 @@ impl Signature {
/// about the signature or the original file contents.
///
/// Do not use for any security-related purposes.
pub fn score_str<S: AsRef<str>>(&self, s: S) -> f64 {
pub fn score_str<S: AsRef<str>>(&self, other: S) -> f64 {
if self.0[0] != 0 {
panic!("unsupported signature version");
}
let original_length = u32::from_le_bytes(self.0[1..5].try_into().unwrap());
let s = s.as_ref();
let s_s: String = s.chars().filter(|&x| !char::is_whitespace(x)).collect();
let s = s_s.as_bytes();
if original_length < 2 || s.len() < 2 {
let original_length = u32::from_le_bytes(self.0[1..5].try_into().expect("invalid length"));
if original_length < 2 {
return 0.0;
}
let mut intersection_size = 0usize;
let other = other.as_ref();
let other_string: String = other.chars().filter(|&x| !char::is_whitespace(x)).collect();
let other = other_string.as_bytes();
let mut wb = self.bucket_iter().collect::<Vec<_>>();
if other.len() < 2 {
return 0.0;
}
for (b1, b2) in bigrams(s) {
let b1 = b1 >> SHIFT;
let b2 = b2 >> SHIFT;
let ix = ((b1 as usize) << BITS) | (b2 as usize);
if wb[ix] > 0 {
wb[ix] = wb[ix].saturating_sub(1);
intersection_size += 1;
let mut matching_bigrams: usize = 0;
let mut self_buckets = self.bucket_iter().collect::<Vec<_>>();
for (left, right) in bigrams(other) {
let left = left >> SHIFT;
let right = right >> SHIFT;
let index = ((left as usize) << BITS) | (right as usize);
if self_buckets[index] > 0 {
self_buckets[index] = self_buckets[index] - 1;
matching_bigrams += 1;
}
}
(2 * intersection_size) as f64 / (original_length as usize + s.len() - 2) as f64
(2 * matching_bigrams) as f64 / (original_length as usize + other.len() - 2) as f64
}
fn bucket_iter(&self) -> impl Iterator<Item = SigBucket> + '_ {
unsafe {
self.0[(TOTAL_BYTES - SIG_BYTES)..]
self.0[(TOTAL_BYTES - FINGERPRINT_BYTES)..]
.as_chunks_unchecked::<{ ::core::mem::size_of::<SigBucket>() }>()
.iter()
.map(|ch: &[u8; ::core::mem::size_of::<SigBucket>()]| SigBucket::from_le_bytes(*ch))
@ -125,45 +123,50 @@ impl Signature {
impl<S: AsRef<str>> From<S> for Signature {
#[inline]
fn from(s: S) -> Self {
let s = s.as_ref();
fn from(source: S) -> Self {
let source = source.as_ref();
let source_string: String = source
.chars()
.filter(|&x| !char::is_whitespace(x))
.collect();
let source = source_string.as_bytes();
let a_s: String = s.chars().filter(|&x| !char::is_whitespace(x)).collect();
let a = a_s.as_bytes();
let a_len: u32 = a
let source_len: u32 = source
.len()
.try_into()
.expect("strings with a byte-length above u32::MAX are not supported");
let mut a_res = [0; TOTAL_BYTES];
a_res[0] = 0; // version byte
a_res[1..5].copy_from_slice(&a_len.to_le_bytes()); // length
let mut bytes = [0; TOTAL_BYTES];
bytes[0] = 0; // version byte (0)
bytes[1..5].copy_from_slice(&source_len.to_le_bytes()); // next 4 bytes are the length
if a_len >= 2 {
let mut a_bigrams = [0 as SigBucket; SIG_ENTRIES];
if source_len >= 2 {
let mut buckets = [0 as SigBucket; FINGERPRINT_ENTRIES];
for (b1, b2) in bigrams(a) {
let b1 = b1 >> SHIFT;
let b2 = b2 >> SHIFT;
let encoded_bigram = ((b1 as usize) << BITS) | (b2 as usize);
a_bigrams[encoded_bigram] = a_bigrams[encoded_bigram].saturating_add(1);
for (left, right) in bigrams(source) {
let left = left >> SHIFT;
let right = right >> SHIFT;
let index = ((left as usize) << BITS) | (right as usize);
buckets[index] = buckets[index].saturating_add(1);
}
// NOTE: This is not efficient if `SigBucket` is 1 byte (u8).
let mut offset = TOTAL_BYTES - SIG_BYTES;
for bucket in a_bigrams {
let mut offset = TOTAL_BYTES - FINGERPRINT_BYTES;
for bucket in buckets {
let start = offset;
let end = start + ::core::mem::size_of::<SigBucket>();
a_res[start..end].copy_from_slice(&bucket.to_le_bytes());
bytes[start..end].copy_from_slice(&bucket.to_le_bytes());
offset = end;
}
}
Self(a_res)
Self(bytes)
}
}
/// Copies the passed bytes twice and zips them together with a one-byte offset.
/// This produces an iterator of the bigrams (pairs of consecutive bytes) in the input.
/// For example, the bigrams of 1, 2, 3, 4, 5 would be (1, 2), (2, 3), (3, 4), (4, 5).
#[inline]
fn bigrams(s: &[u8]) -> impl Iterator<Item = (u8, u8)> + '_ {
s.iter().copied().zip(s.iter().skip(1).copied())
@ -173,61 +176,47 @@ fn bigrams(s: &[u8]) -> impl Iterator<Item = (u8, u8)> + '_ {
mod tests {
use super::*;
macro_rules! assert_score {
($sig:ident, $s:expr, $e:expr) => {
let score = $sig.score_str($s);
if (score - $e).abs() >= 0.1 {
panic!(
"expected score of {} for string {:?}, got {}",
$e, $s, score
);
}
};
}
#[test]
fn score_signature() {
let sig = Signature::from("hello world");
macro_rules! assert_score {
($s:expr, $e:expr) => {
if (sig.score_str($s) - $e).abs() >= 0.1 {
panic!(
"expected score of {} for string {:?}, got {}",
$e,
$s,
sig.score_str($s)
);
}
};
}
// NOTE: The scores here are not exact, but are close enough
// to be useful for testing purposes, hence why some have the same
// "score" but different strings.
assert_score!("hello world", 1.0);
assert_score!("hello world!", 0.95);
assert_score!("hello world!!", 0.9);
assert_score!("hello world!!!", 0.85);
assert_score!("hello world!!!!", 0.8);
assert_score!("hello world!!!!!", 0.75);
assert_score!("hello world!!!!!!", 0.7);
assert_score!("hello world!!!!!!!", 0.65);
assert_score!("hello world!!!!!!!!", 0.62);
assert_score!("hello world!!!!!!!!!", 0.6);
assert_score!("hello world!!!!!!!!!!", 0.55);
assert_score!(sig, "hello world", 1.0);
assert_score!(sig, "hello world!", 0.95);
assert_score!(sig, "hello world!!", 0.9);
assert_score!(sig, "hello world!!!", 0.85);
assert_score!(sig, "hello world!!!!", 0.8);
assert_score!(sig, "hello world!!!!!", 0.75);
assert_score!(sig, "hello world!!!!!!", 0.7);
assert_score!(sig, "hello world!!!!!!!", 0.65);
assert_score!(sig, "hello world!!!!!!!!", 0.62);
assert_score!(sig, "hello world!!!!!!!!!", 0.6);
assert_score!(sig, "hello world!!!!!!!!!!", 0.55);
}
#[test]
fn score_ignores_whitespace() {
let sig = Signature::from("hello world");
macro_rules! assert_score {
($s:expr, $e:expr) => {
if (sig.score_str($s) - $e).abs() >= 0.1 {
panic!(
"expected score of {} for string {:?}, got {}",
$e,
$s,
sig.score_str($s)
);
}
};
}
assert_score!("hello world", 1.0);
assert_score!("hello world ", 1.0);
assert_score!("hello\nworld ", 1.0);
assert_score!("hello\n\tworld ", 1.0);
assert_score!("\t\t hel lo\n\two rld \t\t", 1.0);
assert_score!(sig, "hello world", 1.0);
assert_score!(sig, "hello world ", 1.0);
assert_score!(sig, "hello\nworld ", 1.0);
assert_score!(sig, "hello\n\tworld ", 1.0);
assert_score!(sig, "\t\t hel lo\n\two rld \t\t", 1.0);
}
const TEXT1: &str = include_str!("../fixture/text1.txt");