mirror of
https://github.com/guillaume-be/rust-bert.git
synced 2024-10-26 14:07:25 +03:00
Update punctuation POS tags with low score
This commit is contained in:
parent
02819c0a71
commit
65da7afbb6
@ -20,8 +20,8 @@ fn main() -> anyhow::Result<()> {
|
||||
|
||||
// Define input
|
||||
let input = [
|
||||
"My name is Amélie. I live in Москва.",
|
||||
"Chongqing is a city in China.",
|
||||
"My name is Amélie. My email is amelie@somemail.com.",
|
||||
"A liter of milk costs 0.95 Euros!",
|
||||
];
|
||||
|
||||
// Run model
|
||||
|
@ -133,7 +133,15 @@ impl POSModel {
|
||||
self.token_classification_model
|
||||
.predict(input, true, false)
|
||||
.into_iter()
|
||||
.filter(|token| token.label != "O")
|
||||
.map(|mut token| {
|
||||
if (Self::is_punctuation(token.text.as_str()))
|
||||
& ((token.score < 0.5) | token.score.is_nan())
|
||||
{
|
||||
token.label = String::from(".");
|
||||
token.score = 1f64;
|
||||
};
|
||||
token
|
||||
})
|
||||
.map(|token| POSTag {
|
||||
word: token.text,
|
||||
score: token.score,
|
||||
@ -141,6 +149,10 @@ impl POSModel {
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Returns `true` when `string` is non-empty and every character in it is
/// ASCII punctuation (as defined by [`char::is_ascii_punctuation`]).
///
/// # Arguments
///
/// * `string` - text to inspect.
///
/// # Returns
///
/// * `bool` - `false` for the empty string or for any text containing a
///   character that is not ASCII punctuation (letters, digits, whitespace,
///   non-ASCII characters); `true` otherwise.
fn is_punctuation(string: &str) -> bool {
    // `Iterator::all` is vacuously true on an empty iterator, so the empty
    // string must be rejected explicitly: it contains no punctuation at all.
    !string.is_empty() && string.chars().all(|c| c.is_ascii_punctuation())
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
Loading…
Reference in New Issue
Block a user