mirror of
https://github.com/guillaume-be/rust-bert.git
synced 2024-09-20 00:57:43 +03:00
Updated generation (#14)
This commit is contained in:
parent
7c6e71155a
commit
3ccfc05d92
@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "rust-bert"
|
||||
version = "0.5.0"
|
||||
version = "0.5.1"
|
||||
authors = ["Guillaume Becquin <guillaume.becquin@gmail.com>"]
|
||||
edition = "2018"
|
||||
default-run = "rust-bert"
|
||||
|
@ -476,7 +476,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>> {
|
||||
|
||||
for batch_index in 0..batch_size {
|
||||
if done[batch_index as usize] {
|
||||
assert!(hypotheses[batch_index as usize].len() > num_beams,
|
||||
assert!(hypotheses[batch_index as usize].len() >= num_beams,
|
||||
"Batch cannot be completed if all beams have not been generated");
|
||||
assert!(eos_token_ids.is_some() & pad_token_id.is_some(),
|
||||
"EOS and Padding tokens need to be defined if the number of generated \
|
||||
@ -603,7 +603,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>> {
|
||||
if sentence_length < max_length {
|
||||
let _ = decoded
|
||||
.get(hypothesis_index as i64)
|
||||
.index_fill_(0, &Tensor::of_slice(&[sentence_length]), eos_token_ids.as_ref().unwrap()[0]);
|
||||
.index_fill_(0, &Tensor::of_slice(&[sentence_length]).to_device(input_ids.device()), eos_token_ids.as_ref().unwrap()[0]);
|
||||
}
|
||||
}
|
||||
decoded
|
||||
|
Loading…
Reference in New Issue
Block a user