Working MobileBert masked LM

This commit is contained in:
Guillaume B 2020-12-19 15:46:33 +01:00
parent 9845ce199b
commit bf00f90d55
2 changed files with 19 additions and 5 deletions

View File

@ -78,8 +78,24 @@ fn main() -> anyhow::Result<()> {
let word_1 = tokenizer.vocab().id_to_token(&index_1.int64_value(&[]));
let word_2 = tokenizer.vocab().id_to_token(&index_2.int64_value(&[]));
println!("{}", word_1); // Outputs "person" : "Looks like one [person] is missing"
println!("{}", word_2); // Outputs "pear" : "It was a very nice and [pleasant] day"
println!("{}", word_1); // Outputs "thing" : "Looks like one [thing] is missing"
println!(
"score: {}",
model_output
.logits
.get(0)
.get(4)
.double_value(&[i64::from(&index_1)])
); // 10.0558
println!("{}", word_2); // Outputs "sunny" : "It was a very nice and [sunny] day"
println!(
"score: {}",
model_output
.logits
.get(1)
.get(7)
.double_value(&[i64::from(&index_2)])
); // 14.2708
Ok(())
}

View File

@ -180,9 +180,7 @@ impl MobileBertPredictionHeadTransform {
let activation_function = config.hidden_act.get_function();
let layer_norm = NormalizationLayer::new(
p / "LayerNorm",
config
.normalization_type
.unwrap_or(NormalizationType::no_norm),
NormalizationType::layer_norm,
config.hidden_size,
config.layer_norm_eps,
);