mirror of
https://github.com/guillaume-be/rust-bert.git
synced 2024-09-19 16:48:02 +03:00
Update scores for padding index
This commit is contained in:
parent
d6b5abe947
commit
ff2e9f2581
@ -657,9 +657,8 @@ impl PrivateLanguageGenerator<MarianForConditionalGeneration, MarianVocab, Maria
|
||||
fn get_decoder_start_id(&self) -> Option<i64> { self.decoder_start_id }
|
||||
|
||||
fn prepare_scores_for_generation(&self, scores: &mut Tensor, current_length: i64, max_length: i64) {
|
||||
if current_length == 1 {
|
||||
self.force_token_id_generation(scores, &vec!(self.get_bos_id().unwrap()));
|
||||
} else if current_length == max_length - 1 {
|
||||
let _ = scores.index_fill_(1, &Tensor::of_slice(&[self.get_pad_id().unwrap()]).to_kind(Int64).to_device(scores.device()), std::f64::NEG_INFINITY);
|
||||
if current_length == max_length - 1 {
|
||||
self.force_token_id_generation(scores, self.get_eos_ids().as_ref().unwrap());
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user