Updated generation (#14)

This commit is contained in:
guillaume-be 2020-03-16 21:45:20 +01:00 committed by GitHub
parent 7c6e71155a
commit 3ccfc05d92
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 3 additions and 3 deletions

View File

@ -1,6 +1,6 @@
[package]
name = "rust-bert"
version = "0.5.0"
version = "0.5.1"
authors = ["Guillaume Becquin <guillaume.becquin@gmail.com>"]
edition = "2018"
default-run = "rust-bert"

View File

@ -476,7 +476,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>> {
for batch_index in 0..batch_size {
if done[batch_index as usize] {
assert!(hypotheses[batch_index as usize].len() > num_beams,
assert!(hypotheses[batch_index as usize].len() >= num_beams,
"Batch cannot be completed if all beams have not been generated");
assert!(eos_token_ids.is_some() & pad_token_id.is_some(),
"EOS and Padding tokens need to be defined if the number of generated \
@ -603,7 +603,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>> {
if sentence_length < max_length {
let _ = decoded
.get(hypothesis_index as i64)
.index_fill_(0, &Tensor::of_slice(&[sentence_length]), eos_token_ids.as_ref().unwrap()[0]);
.index_fill_(0, &Tensor::of_slice(&[sentence_length]).to_device(input_ids.device()), eos_token_ids.as_ref().unwrap()[0]);
}
}
decoded