diff --git a/Cargo.toml b/Cargo.toml index c770e25..bc33670 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "rust-bert" -version = "0.5.0" +version = "0.5.1" authors = ["Guillaume Becquin "] edition = "2018" default-run = "rust-bert" diff --git a/src/pipelines/generation.rs b/src/pipelines/generation.rs index 0a682e6..c1385a6 100644 --- a/src/pipelines/generation.rs +++ b/src/pipelines/generation.rs @@ -476,7 +476,7 @@ pub trait LanguageGenerator> { for batch_index in 0..batch_size { if done[batch_index as usize] { - assert!(hypotheses[batch_index as usize].len() > num_beams, + assert!(hypotheses[batch_index as usize].len() >= num_beams, "Batch cannot be completed if all beams have not been generated"); assert!(eos_token_ids.is_some() & pad_token_id.is_some(), "EOS and Padding tokens need to be defined if the number of generated \ @@ -603,7 +603,7 @@ pub trait LanguageGenerator> { if sentence_length < max_length { let _ = decoded .get(hypothesis_index as i64) - .index_fill_(0, &Tensor::of_slice(&[sentence_length]), eos_token_ids.as_ref().unwrap()[0]); + .index_fill_(0, &Tensor::of_slice(&[sentence_length]).to_device(input_ids.device()), eos_token_ids.as_ref().unwrap()[0]); } } decoded