Fixed greedy generation method

Guillaume B 2020-10-13 18:52:09 +02:00
parent 989bcec727
commit a22c89b010
4 changed files with 45 additions and 38 deletions

View File

@@ -12,10 +12,17 @@
 extern crate anyhow;
 
-use rust_bert::pipelines::conversation::{ConversationManager, ConversationModel};
+use rust_bert::pipelines::conversation::{
+    ConversationConfig, ConversationManager, ConversationModel,
+};
 
 fn main() -> anyhow::Result<()> {
-    let conversation_model = ConversationModel::new(Default::default())?;
+    let config = ConversationConfig {
+        do_sample: false,
+        num_beams: 3,
+        ..Default::default()
+    };
+    let conversation_model = ConversationModel::new(config)?;
     let mut conversation_manager = ConversationManager::new();
 
     let conversation_1_id =
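
The switch away from Default::default() makes the example exercise exactly the path this commit repairs: do_sample: false disables sampling and num_beams: 3 enables deterministic beam search. For context, a minimal way to drive the configured model, sketched from the conversation pipeline's API of this period (treat the method names and signatures as assumptions):

    // Sketch only: `create` registers a conversation and returns its id;
    // `generate_responses` replies to every active conversation.
    let conversation_id =
        conversation_manager.create("Going to the movies tonight - any suggestions?");
    let output = conversation_model.generate_responses(&mut conversation_manager);
    println!("{:?}", output);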

View File

@@ -717,7 +717,9 @@ impl ConversationModel
             .iter()
             .position(|&r| r != pad_token)
             .unwrap();
-        sequence_history.drain(sequence_history.len() - index_end + 1..);
+        if index_end > 0 {
+            sequence_history.drain(sequence_history.len() - index_end + 1..);
+        }
         sequence_history.drain(..index_start);
         removed_tokens.push((index_start, index_end));
     }
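
Why the guard matters: index_end is evidently the offset of the last non-pad token counted from the end of the history (a reversed scan, given the len() - index_end arithmetic), so a history with no trailing padding yields index_end == 0, and the old unconditional drain would start at len + 1, past the end of the vector, and panic. A standalone illustration (the helper below is invented for this sketch; the pipeline operates on its own sequence_history field):

    // Hypothetical sketch of the fixed trimming logic.
    fn trim_trailing_padding(history: &mut Vec<i64>, pad_token: i64) {
        // Offset of the last non-pad token, counted from the end;
        // equivalently, the number of trailing pad tokens.
        if let Some(index_end) = history.iter().rev().position(|&t| t != pad_token) {
            // index_end == 0 means no trailing padding: draining from
            // len - 0 + 1 would be out of bounds, hence the guard.
            if index_end > 0 {
                history.drain(history.len() - index_end + 1..);
            }
        }
    }

    // Keeps one trailing pad (the EOS marker when pad == eos):
    // [5, 6, 0, 0, 0] -> [5, 6, 0]; [5, 6] is left untouched.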

View File

@@ -2355,40 +2355,38 @@ pub(crate) mod private_generation_utils {
                 best_ids.push(best_hyp);
             }
         }
-        let decoded = if i64::from(sentence_lengths.max()) != i64::from(sentence_lengths.min())
-        {
-            let sentence_max_length =
-                min(i64::from(sentence_lengths.max()) + 1, gen_opt.max_length);
-            let decoded: Tensor = Tensor::ones(
-                &[output_batch_size, sentence_max_length],
-                (Int64, input_ids.device()),
-            ) * gen_opt.pad_token_id.unwrap();
-            for (hypothesis_index, best_id) in best_ids.iter().enumerate() {
-                let _ = decoded.get(hypothesis_index as i64).index_copy_(
-                    0,
-                    &Tensor::arange1(
-                        0,
-                        i64::from(sentence_lengths.get(hypothesis_index as i64)),
-                        (Int64, input_ids.device()),
-                    ),
-                    &best_id,
-                );
-                let sentence_length = i64::from(sentence_lengths.get(hypothesis_index as i64));
-                if sentence_length < gen_opt.max_length {
-                    let _ = decoded.get(hypothesis_index as i64).index_fill_(
-                        0,
-                        &Tensor::of_slice(&[sentence_length]).to_device(input_ids.device()),
-                        gen_opt.eos_token_ids.as_ref().unwrap()[0],
-                    );
-                }
-            }
-            decoded
-        } else {
-            Tensor::stack(&best_ids, 0)
-                .to_kind(Int64)
-                .to(input_ids.device())
-        };
+        let sentence_max_length =
+            min(i64::from(sentence_lengths.max()) + 1, gen_opt.max_length);
+        let mut decoded = input_ids.new_empty(
+            &[output_batch_size, sentence_max_length],
+            (Int64, input_ids.device()),
+        );
+        if i64::from(sentence_lengths.max()) != i64::from(sentence_lengths.min()) {
+            let _ = decoded.fill_(
+                gen_opt
+                    .pad_token_id
+                    .unwrap_or(gen_opt.eos_token_ids.as_ref().unwrap()[0]),
+            );
+        }
+        for (hypothesis_index, best_id) in best_ids.iter().enumerate() {
+            let _ = decoded.get(hypothesis_index as i64).index_copy_(
+                0,
+                &Tensor::arange1(
+                    0,
+                    i64::from(sentence_lengths.get(hypothesis_index as i64)),
+                    (Int64, input_ids.device()),
+                ),
+                &best_id,
+            );
+            let sentence_length = i64::from(sentence_lengths.get(hypothesis_index as i64));
+            if sentence_length < gen_opt.max_length {
+                let _ = decoded.get(hypothesis_index as i64).index_fill_(
+                    0,
+                    &Tensor::of_slice(&[sentence_length]).to_device(input_ids.device()),
+                    gen_opt.eos_token_ids.as_ref().unwrap()[0],
+                );
+            }
+        }
         decoded
     }
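
This hunk is the fix the commit title refers to. Previously the padded output buffer was only allocated when hypothesis lengths differed, and that branch called gen_opt.pad_token_id.unwrap(), which panics for models that define no pad token (GPT-2-style models under greedy decoding). The rewrite always allocates the buffer with new_empty and fills it with the pad token, falling back to the first EOS token id, only when the hypotheses actually have unequal lengths. The pack-into-padded-buffer pattern in isolation (a simplified stand-in that uses tch-rs narrow/copy_ instead of the index_copy_ above; the helper name and signature are invented):

    use tch::{Device, Kind, Tensor};

    // Invented helper: pack variable-length hypotheses into one padded
    // [batch, max_len] tensor, mirroring the logic of the diff above.
    fn pack_hypotheses(best_ids: &[Tensor], lengths: &[i64], pad_value: i64) -> Tensor {
        let batch = best_ids.len() as i64;
        let target_len = lengths.iter().copied().max().unwrap_or(0);
        // Allocate once, pre-filled with the pad value; the real code
        // skips the fill when every hypothesis has the same length.
        let decoded = Tensor::full(&[batch, target_len], pad_value, (Kind::Int64, Device::Cpu));
        for (i, ids) in best_ids.iter().enumerate() {
            // Copy each hypothesis into the left-aligned slice of its
            // row; the rest of the row keeps the pad value.
            decoded
                .get(i as i64)
                .narrow(0, 0, lengths[i])
                .copy_(&ids.narrow(0, 0, lengths[i]));
        }
        decoded
    }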

View File

@@ -737,13 +737,13 @@ impl ZeroShotClassificationModel
         let mut output_labels = vec![];
         for sentence_idx in 0..num_inputs {
             let mut sentence_labels = vec![];
-            let sentence_scores = scores
+            for (label_index, score) in scores
                 .select(0, sentence_idx as i64)
                 .iter::<f64>()
                 .unwrap()
-                .collect::<Vec<f64>>();
-            for (label_index, score) in sentence_scores.into_iter().enumerate() {
+                .enumerate()
+            {
                 let label_string = labels[label_index].to_string();
                 let label = Label {
                     text: label_string,
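
The zero-shot classification change is a small cleanup bundled into the same commit: per-sentence scores are now consumed lazily through the tensor's f64 iterator with enumerate, rather than first collected into an intermediate Vec<f64>. The pattern in miniature (self-contained sketch assuming the tch-rs Tensor API; the values are made up):

    use tch::Tensor;

    fn main() {
        let scores = Tensor::of_slice(&[0.1f64, 0.7, 0.2]);
        // Iterate values straight off the tensor; no intermediate Vec.
        for (label_index, score) in scores.iter::<f64>().unwrap().enumerate() {
            println!("label {}: {:.3}", label_index, score);
        }
    }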