Update punctuation POS tags with low score

This commit is contained in:
Guillaume B 2021-03-15 16:41:00 +01:00
parent 02819c0a71
commit 65da7afbb6
2 changed files with 15 additions and 3 deletions

View File

@ -20,8 +20,8 @@ fn main() -> anyhow::Result<()> {
// Define input
let input = [
"My name is Amélie. I live in Москва.",
"Chongqing is a city in China.",
"My name is Amélie. My email is amelie@somemail.com.",
"A liter of milk costs 0.95 Euros!",
];
// Run model

View File

@ -133,7 +133,15 @@ impl POSModel {
self.token_classification_model
.predict(input, true, false)
.into_iter()
.filter(|token| token.label != "O")
.map(|mut token| {
if (Self::is_punctuation(token.text.as_str()))
& ((token.score < 0.5) | token.score.is_nan())
{
token.label = String::from(".");
token.score = 1f64;
};
token
})
.map(|token| POSTag {
word: token.text,
score: token.score,
@ -141,6 +149,10 @@ impl POSModel {
})
.collect()
}
/// Returns `true` when `string` is non-empty and every character in it is an
/// ASCII punctuation character (as defined by `char::is_ascii_punctuation`).
///
/// Non-ASCII punctuation such as `«` or `…` is not recognised and yields `false`.
fn is_punctuation(string: &str) -> bool {
    // `Iterator::all` is vacuously true for an empty iterator, so an empty
    // string must be rejected explicitly to avoid classifying it as punctuation.
    !string.is_empty() && string.chars().all(|c| c.is_ascii_punctuation())
}
}
#[cfg(test)]