Updated generation (clippy)

Guillaume B 2020-09-13 11:08:49 +02:00
parent 5aa1e635ba
commit 6cef45787d
16 changed files with 494 additions and 428 deletions

View File

@ -66,12 +66,12 @@ fn main() -> anyhow::Result<()> {
let input_tensor = Tensor::stack(encoded_input.as_slice(), 0).to(device);
// Forward pass
let (output, _, _) =
let model_output =
no_grad(|| electra_model.forward_t(Some(input_tensor), None, None, None, None, false));
// Print model predictions
for (position, token) in tokenized_input[0].token_ids.iter().enumerate() {
let probability = output.double_value(&[position as i64]);
let probability = model_output.probabilities.double_value(&[position as i64]);
let generated = if probability > 0.5 {
"generated"
} else {

View File

@ -69,12 +69,20 @@ fn main() -> anyhow::Result<()> {
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let (output, _, _) =
let model_output =
no_grad(|| electra_model.forward_t(Some(input_tensor), None, None, None, None, false));
// Print masked tokens
let index_1 = output.get(0).get(4).argmax(0, false);
let index_2 = output.get(1).get(7).argmax(0, false);
let index_1 = model_output
.prediction_scores
.get(0)
.get(4)
.argmax(0, false);
let index_2 = model_output
.prediction_scores
.get(1)
.get(7)
.argmax(0, false);
let word_1 = tokenizer.vocab().id_to_token(&index_1.int64_value(&[]));
let word_2 = tokenizer.vocab().id_to_token(&index_2.int64_value(&[]));

View File

@ -70,4 +70,5 @@ pub use bart_model::{
BartVocabResources,
};
pub(crate) use bart_model::BartModelOutput;
pub(crate) use encoder::BartEncoderOutput;

View File

@ -212,7 +212,7 @@ impl ElectraModel {
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
/// .expand(&[batch_size, sequence_length], true);
///
/// let (output, all_hidden_states, all_attentions) = no_grad(|| {
/// let model_output = no_grad(|| {
/// electra_model
/// .forward_t(
/// Some(input_tensor),
@ -233,7 +233,7 @@ impl ElectraModel {
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool,
) -> Result<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
) -> Result<ElectraModelOutput, &'static str> {
let (input_shape, device) = match &input_ids {
Some(input_value) => match &input_embeds {
Some(_) => {
@ -288,7 +288,11 @@ impl ElectraModel {
train,
);
Ok((hidden_state, all_hidden_states, all_attentions))
Ok(ElectraModelOutput {
hidden_state,
all_hidden_states,
all_attentions,
})
}
}
@ -590,7 +594,7 @@ impl ElectraForMaskedLM {
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
/// .expand(&[batch_size, sequence_length], true);
///
/// let (output, all_hidden_states, all_attentions) = no_grad(|| {
/// let model_output = no_grad(|| {
/// electra_model.forward_t(
/// Some(input_tensor),
/// Some(mask),
@ -609,8 +613,8 @@ impl ElectraForMaskedLM {
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool,
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
let (hidden_states, all_hidden_states, all_attentions) = self
) -> ElectraMaskedLMOutput {
let model_output = self
.electra
.forward_t(
input_ids,
@ -621,9 +625,13 @@ impl ElectraForMaskedLM {
train,
)
.unwrap();
let hidden_states = self.generator_head.forward(&hidden_states);
let hidden_states = hidden_states.apply(&self.lm_head);
(hidden_states, all_hidden_states, all_attentions)
let hidden_states = self.generator_head.forward(&model_output.hidden_state);
let prediction_scores = hidden_states.apply(&self.lm_head);
ElectraMaskedLMOutput {
prediction_scores,
all_hidden_states: model_output.all_hidden_states,
all_attentions: model_output.all_attentions,
}
}
}
@ -712,7 +720,7 @@ impl ElectraDiscriminator {
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true);
///
/// let (output, all_hidden_states, all_attentions) = no_grad(|| {
/// let model_output = no_grad(|| {
/// electra_model
/// .forward_t(Some(input_tensor),
/// Some(mask),
@ -730,8 +738,8 @@ impl ElectraDiscriminator {
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool,
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
let (hidden_states, all_hidden_states, all_attentions) = self
) -> ElectraDiscriminatorOutput {
let model_output = self
.electra
.forward_t(
input_ids,
@ -742,8 +750,15 @@ impl ElectraDiscriminator {
train,
)
.unwrap();
let probabilities = self.discriminator_head.forward(&hidden_states).sigmoid();
(probabilities, all_hidden_states, all_attentions)
let probabilities = self
.discriminator_head
.forward(&model_output.hidden_state)
.sigmoid();
ElectraDiscriminatorOutput {
probabilities,
all_hidden_states: model_output.all_hidden_states,
all_attentions: model_output.all_attentions,
}
}
}
@ -845,7 +860,7 @@ impl ElectraForTokenClassification {
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true);
///
/// let (output, all_hidden_states, all_attentions) = no_grad(|| {
/// let model_output = no_grad(|| {
/// electra_model
/// .forward_t(Some(input_tensor),
/// Some(mask),
@ -863,8 +878,8 @@ impl ElectraForTokenClassification {
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool,
) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
let (hidden_states, all_hidden_states, all_attentions) = self
) -> ElectraTokenClassificationOutput {
let model_output = self
.electra
.forward_t(
input_ids,
@ -875,9 +890,38 @@ impl ElectraForTokenClassification {
train,
)
.unwrap();
let output = hidden_states
let logits = model_output
.hidden_state
.apply_t(&self.dropout, train)
.apply(&self.classifier);
(output, all_hidden_states, all_attentions)
ElectraTokenClassificationOutput {
logits,
all_hidden_states: model_output.all_hidden_states,
all_attentions: model_output.all_attentions,
}
}
}
pub struct ElectraModelOutput {
pub hidden_state: Tensor,
pub all_hidden_states: Option<Vec<Tensor>>,
pub all_attentions: Option<Vec<Tensor>>,
}
pub struct ElectraDiscriminatorOutput {
pub probabilities: Tensor,
pub all_hidden_states: Option<Vec<Tensor>>,
pub all_attentions: Option<Vec<Tensor>>,
}
pub struct ElectraMaskedLMOutput {
pub prediction_scores: Tensor,
pub all_hidden_states: Option<Vec<Tensor>>,
pub all_attentions: Option<Vec<Tensor>>,
}
pub struct ElectraTokenClassificationOutput {
pub logits: Tensor,
pub all_hidden_states: Option<Vec<Tensor>>,
pub all_attentions: Option<Vec<Tensor>>,
}
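The forward passes in this file now return the named output structs defined above instead of positional tuples. A minimal, standalone sketch of that pattern (all type and field names here are illustrative stand-ins, not the crate's actual API):

    // Stand-in for tch::Tensor; purely illustrative.
    struct HiddenState(Vec<f32>);

    struct ModelOutput {
        hidden_state: HiddenState,
        all_hidden_states: Option<Vec<HiddenState>>,
        all_attentions: Option<Vec<HiddenState>>,
    }

    // Before: fn forward() -> (HiddenState, Option<Vec<HiddenState>>, Option<Vec<HiddenState>>)
    // After: callers read named fields instead of remembering tuple positions.
    fn forward() -> ModelOutput {
        ModelOutput {
            hidden_state: HiddenState(vec![0.0; 4]),
            all_hidden_states: None,
            all_attentions: None,
        }
    }

    fn main() {
        let output = forward();
        println!("hidden state length: {}", output.hidden_state.0.len());
        assert!(output.all_hidden_states.is_none());
        assert!(output.all_attentions.is_none());
    }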

View File

@ -128,7 +128,7 @@ impl Attention {
fn flatten(&self, x: Tensor) -> Tensor {
x.transpose(1, 2)
.contiguous()
.view((x.size()[0], -1, &self.n_head * self.dim_per_head))
.view((x.size()[0], -1, self.n_head * self.dim_per_head))
}
fn attention(
@ -141,7 +141,7 @@ impl Attention {
) -> (Tensor, Option<Tensor>) {
let mut w = query.matmul(&key);
if self.scale {
w = w / (*value.size().last().unwrap() as f64).sqrt();
w /= (*value.size().last().unwrap() as f64).sqrt();
}
let (nd, ns) = (w.size()[2], w.size()[3]);
@ -149,7 +149,7 @@ impl Attention {
let mut w: Tensor = w * &b + 1e4 * (&b - 1);
if let Some(mask) = attention_mask {
w = w + mask;
w += mask;
}
w = w.softmax(-1, Float).apply_t(&self.attn_dropout, train);
let output = w.matmul(&value);
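These attention hunks drop a needless borrow in the head-dimension product and replace rebinding assignments with compound operators (`w /= ...`, `w += ...`); the tensor code relies on the operator-assignment impls provided by the tensor type. A self-contained sketch of the same idea using plain floats:

    fn main() {
        let scale = 4.0_f64;
        let mask = -1.0_f64;

        // Before: w = w / scale.sqrt(); w = w + mask;
        let mut w = 10.0_f64;
        w /= scale.sqrt(); // compound assignment, same result, no rebinding
        w += mask;

        assert!((w - 4.0).abs() < 1e-12);
        println!("w = {}", w);
    }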

View File

@ -352,7 +352,7 @@ impl Gpt2Model {
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
/// .expand(&[batch_size, sequence_length], true);
///
/// let (output, past, hidden_states, attentions) = no_grad(|| {
/// let model_output = no_grad(|| {
/// gpt2_model
/// .forward_t(
/// &Some(input_tensor),
@ -375,15 +375,7 @@ impl Gpt2Model {
position_ids: &Option<Tensor>,
input_embeds: &Option<Tensor>,
train: bool,
) -> Result<
(
Tensor,
Option<Vec<Tensor>>,
Option<Vec<Tensor>>,
Option<Vec<Tensor>>,
),
&'static str,
> {
) -> Result<Gpt2ModelOutput, &'static str> {
let (input_embeddings, seq_length) = match input_ids {
Some(input_value) => match input_embeds {
Some(_) => {
@ -466,34 +458,29 @@ impl Gpt2Model {
None
};
let mut layer_iter = self.h.iter().zip(layer_past);
loop {
match layer_iter.next() {
Some(layer_values) => {
let (layer, past) = layer_values;
if let Some(hidden_states) = all_hidden_states.borrow_mut() {
hidden_states.push(hidden_state.as_ref().copy());
};
let layer_iter = self.h.iter().zip(layer_past);
for layer_values in layer_iter {
let (layer, past) = layer_values;
if let Some(hidden_states) = all_hidden_states.borrow_mut() {
hidden_states.push(hidden_state.as_ref().copy());
};
let temp = layer.forward_t(&hidden_state, &past, &attention_mask, train);
hidden_state = temp.0;
if let Some(presents) = all_presents.borrow_mut() {
presents.push(temp.1.as_ref().copy());
};
if let Some(attentions) = all_attentions.borrow_mut() {
attentions.push(temp.2.as_ref().unwrap().copy());
};
}
None => break,
let temp = layer.forward_t(&hidden_state, &past, &attention_mask, train);
hidden_state = temp.0;
if let Some(presents) = all_presents.borrow_mut() {
presents.push(temp.1.as_ref().copy());
};
if let Some(attentions) = all_attentions.borrow_mut() {
attentions.push(temp.2.as_ref().unwrap().copy());
};
}
Ok((
hidden_state.apply(&self.ln_f),
all_presents,
Ok(Gpt2ModelOutput {
hidden_state: hidden_state.apply(&self.ln_f),
cache: all_presents,
all_hidden_states,
all_attentions,
))
})
}
}
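The block iteration above is rewritten from a `loop` over `match layer_iter.next()` into a plain `for` loop over the zipped iterator. A standalone sketch of that transformation, with hypothetical layer and cache types:

    // Hypothetical stand-ins for transformer layers and their cached state.
    struct Layer(i64);
    struct Past(i64);

    fn main() {
        let layers = vec![Layer(1), Layer(2), Layer(3)];
        let pasts = vec![Past(10), Past(20), Past(30)];
        let mut hidden_state = 0i64;

        // Before:
        // let mut layer_iter = layers.iter().zip(pasts.iter());
        // loop {
        //     match layer_iter.next() {
        //         Some((layer, past)) => { /* ... */ }
        //         None => break,
        //     }
        // }

        // After: the same zip, consumed by a `for` loop.
        for (layer, past) in layers.iter().zip(pasts.iter()) {
            hidden_state += layer.0 + past.0;
        }

        assert_eq!(hidden_state, 1 + 10 + 2 + 20 + 3 + 30);
        println!("hidden_state = {}", hidden_state);
    }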
@ -636,7 +623,7 @@ impl LMHeadModel for GPT2LMHeadModel {
_decoder_input_ids: &Option<Tensor>,
train: bool,
) -> Result<LMModelOutput, &'static str> {
let (output, past, all_hidden_states, all_attentions) = match layer_past {
let model_output = match layer_past {
Cache::GPT2Cache(layer_past) => Ok(self.transformer.forward_t(
input_ids,
&layer_past,
@ -658,13 +645,20 @@ impl LMHeadModel for GPT2LMHeadModel {
_ => Err("Cache not compatible with GPT2 model"),
}?;
let lm_logits = output.apply(&self.lm_head);
let lm_logits = model_output.hidden_state.apply(&self.lm_head);
Ok(LMModelOutput {
lm_logits,
encoder_hidden_state: None,
cache: Cache::GPT2Cache(past),
all_hidden_states,
all_attentions,
cache: Cache::GPT2Cache(model_output.cache),
all_hidden_states: model_output.all_hidden_states,
all_attentions: model_output.all_attentions,
})
}
}
pub struct Gpt2ModelOutput {
pub hidden_state: Tensor,
pub cache: Option<Vec<Tensor>>,
pub all_hidden_states: Option<Vec<Tensor>>,
pub all_attentions: Option<Vec<Tensor>>,
}

View File

@ -11,7 +11,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use crate::bart::{BartConfig, BartEncoderOutput, BartModel, LayerState};
use crate::bart::{BartConfig, BartEncoderOutput, BartModel, BartModelOutput, LayerState};
use crate::pipelines::generation::{Cache, LMHeadModel, LMModelOutput};
use std::borrow::Borrow;
use tch::nn::Init;
@ -325,15 +325,7 @@ impl MarianForConditionalGeneration {
/// let decoder_attention_mask =
/// Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
///
/// let (
/// decoder_output,
/// encoder_hidden_states,
/// cache,
/// all_encoder_hidden_states,
/// all_encoder_attentions,
/// all_decoder_hidden_states,
/// all_decoder_attentions,
/// ) = no_grad(|| {
/// let model_output = no_grad(|| {
/// marian_model.forward_t(
/// Some(&input_tensor),
/// Some(&encoder_attention_mask),
@ -354,15 +346,7 @@ impl MarianForConditionalGeneration {
decoder_attention_mask: Option<&Tensor>,
old_layer_states: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
train: bool,
) -> (
Tensor,
Tensor,
Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
Option<Vec<Tensor>>,
Option<Vec<Tensor>>,
Option<Vec<Tensor>>,
Option<Vec<Tensor>>,
) {
) -> BartModelOutput {
let base_model_output = self.base_model.forward_t(
input_ids,
attention_mask,
@ -376,20 +360,14 @@ impl MarianForConditionalGeneration {
let lm_logits = base_model_output
.decoder_hidden_state
.linear::<Tensor>(&self.base_model.embeddings.ws, None);
(
lm_logits,
base_model_output.encoder_hidden_state,
base_model_output.cache,
base_model_output.all_decoder_hidden_states,
base_model_output.all_decoder_attentions,
base_model_output.all_encoder_hidden_states,
base_model_output.all_encoder_attentions,
)
BartModelOutput {
decoder_hidden_state: lm_logits,
..base_model_output
}
}
pub fn encode(&self, input_ids: &Tensor, attention_mask: Option<&Tensor>) -> Tensor {
let encoder_hidden_states = self
.base_model
self.base_model
.encoder
.forward_t(
input_ids,
@ -397,8 +375,7 @@ impl MarianForConditionalGeneration {
&self.base_model.embeddings,
false,
)
.hidden_state;
encoder_hidden_states
.hidden_state
}
}
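The conditional-generation head above now builds its return value with struct update syntax, overriding one field and taking the rest from the base model output. A minimal sketch of `..base` update syntax with hypothetical fields:

    // Hypothetical output type; the real BartModelOutput has more fields.
    #[derive(Debug)]
    struct Output {
        decoder_hidden_state: i64,
        encoder_hidden_state: i64,
        cache: Option<Vec<i64>>,
    }

    fn main() {
        let base = Output {
            decoder_hidden_state: 1,
            encoder_hidden_state: 2,
            cache: Some(vec![3, 4]),
        };

        // Override one field, take every other field from `base`.
        let updated = Output {
            decoder_hidden_state: 42,
            ..base
        };

        assert_eq!(updated.decoder_hidden_state, 42);
        assert_eq!(updated.encoder_hidden_state, 2);
        println!("{:?}", updated);
    }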
@ -450,15 +427,7 @@ impl LMHeadModel for MarianForConditionalGeneration {
/// let decoder_attention_mask =
/// Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
///
/// let (
/// decoder_output,
/// encoder_hidden_states,
/// cache,
/// all_encoder_hidden_states,
/// all_encoder_attentions,
/// all_decoder_hidden_states,
/// all_decoder_attentions,
/// ) = no_grad(|| {
/// let model_output = no_grad(|| {
/// marian_model.forward_t(
/// Some(&input_tensor),
/// Some(&encoder_attention_mask),
@ -509,7 +478,7 @@ impl LMHeadModel for MarianForConditionalGeneration {
None,
train,
),
_ => Err("Cache not compatible with Marian Model")?,
_ => return Err("Cache not compatible with Marian Model"),
};
let lm_logits = base_model_output

View File

@ -192,7 +192,7 @@ impl OpenAiGptModel {
/// let position_ids = Tensor::arange(sequence_length, (Int64, device))
/// .expand(&[batch_size, sequence_length], true);
///
/// let (output, hidden_states, attentions) = no_grad(|| {
/// let model_output = no_grad(|| {
/// gpt_model
/// .forward_t(
/// &Some(input_tensor),
@ -213,7 +213,7 @@ impl OpenAiGptModel {
position_ids: &Option<Tensor>,
input_embeds: &Option<Tensor>,
train: bool,
) -> Result<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
) -> Result<OpenAiGptOutput, &'static str> {
let (input_embeddings, seq_length) = match input_ids {
Some(input_value) => match input_embeds {
Some(_) => {
@ -267,25 +267,23 @@ impl OpenAiGptModel {
None
};
let mut layers = self.h.iter();
loop {
match layers.next() {
Some(layer) => {
if let Some(hidden_states) = all_hidden_states.borrow_mut() {
hidden_states.push(hidden_state.as_ref().copy());
};
for layer in &self.h {
if let Some(hidden_states) = all_hidden_states.borrow_mut() {
hidden_states.push(hidden_state.as_ref().copy());
};
let temp = layer.forward_t(&hidden_state, &attention_mask, train);
hidden_state = temp.0;
if let Some(attentions) = all_attentions.borrow_mut() {
attentions.push(temp.1.as_ref().unwrap().copy());
};
}
None => break,
let temp = layer.forward_t(&hidden_state, &attention_mask, train);
hidden_state = temp.0;
if let Some(attentions) = all_attentions.borrow_mut() {
attentions.push(temp.1.as_ref().unwrap().copy());
};
}
Ok((hidden_state, all_hidden_states, all_attentions))
Ok(OpenAiGptOutput {
hidden_state,
all_hidden_states,
all_attentions,
})
}
}
@ -413,7 +411,7 @@ impl LMHeadModel for OpenAIGPTLMHeadModel {
_decoder_input_ids: &Option<Tensor>,
train: bool,
) -> Result<LMModelOutput, &'static str> {
let (output, all_hidden_states, all_attentions) = self.transformer.forward_t(
let model_output = self.transformer.forward_t(
input_ids,
attention_mask,
token_type_ids,
@ -422,13 +420,19 @@ impl LMHeadModel for OpenAIGPTLMHeadModel {
train,
)?;
let lm_logits = output.apply(&self.lm_head);
let lm_logits = model_output.hidden_state.apply(&self.lm_head);
Ok(LMModelOutput {
lm_logits,
encoder_hidden_state: None,
cache: Cache::None,
all_hidden_states,
all_attentions,
all_hidden_states: model_output.all_hidden_states,
all_attentions: model_output.all_attentions,
})
}
}
pub struct OpenAiGptOutput {
pub hidden_state: Tensor,
pub all_hidden_states: Option<Vec<Tensor>>,
pub all_attentions: Option<Vec<Tensor>>,
}

View File

@ -16,7 +16,6 @@
//! generic pipelines. The model component is defined in the generic pipeline itself as the
//! pre-processing, forward pass and postprocessing differs between pipelines while basic config and
//! tokenization objects don't.
//!
use crate::albert::AlbertConfig;
use crate::bart::BartConfig;
use crate::bert::BertConfig;
@ -335,87 +334,91 @@ impl TokenizerOption {
original_offsets_2: Option<Vec<Vec<OffsetSize>>>,
mask_1: Vec<Mask>,
mask_2: Option<Vec<Mask>>,
) -> (
Vec<i64>,
Vec<i8>,
Vec<i8>,
Vec<Option<Offset>>,
Vec<Vec<OffsetSize>>,
Vec<Mask>,
) {
match *self {
Self::Bert(ref tokenizer) => tokenizer.build_input_with_special_tokens(
tokens_1,
tokens_2,
offsets_1,
offsets_2,
original_offsets_1,
original_offsets_2,
mask_1,
mask_2,
),
Self::Roberta(ref tokenizer) => tokenizer.build_input_with_special_tokens(
tokens_1,
tokens_2,
offsets_1,
offsets_2,
original_offsets_1,
original_offsets_2,
mask_1,
mask_2,
),
Self::XLMRoberta(ref tokenizer) => tokenizer.build_input_with_special_tokens(
tokens_1,
tokens_2,
offsets_1,
offsets_2,
original_offsets_1,
original_offsets_2,
mask_1,
mask_2,
),
Self::Marian(ref tokenizer) => tokenizer.build_input_with_special_tokens(
tokens_1,
tokens_2,
offsets_1,
offsets_2,
original_offsets_1,
original_offsets_2,
mask_1,
mask_2,
),
Self::T5(ref tokenizer) => tokenizer.build_input_with_special_tokens(
tokens_1,
tokens_2,
offsets_1,
offsets_2,
original_offsets_1,
original_offsets_2,
mask_1,
mask_2,
),
Self::Albert(ref tokenizer) => tokenizer.build_input_with_special_tokens(
tokens_1,
tokens_2,
offsets_1,
offsets_2,
original_offsets_1,
original_offsets_2,
mask_1,
mask_2,
),
) -> TokenizedInput {
let (token_ids, segment_ids, special_tokens_mask, token_offsets, reference_offsets, mask) =
match *self {
Self::Bert(ref tokenizer) => tokenizer.build_input_with_special_tokens(
tokens_1,
tokens_2,
offsets_1,
offsets_2,
original_offsets_1,
original_offsets_2,
mask_1,
mask_2,
),
Self::Roberta(ref tokenizer) => tokenizer.build_input_with_special_tokens(
tokens_1,
tokens_2,
offsets_1,
offsets_2,
original_offsets_1,
original_offsets_2,
mask_1,
mask_2,
),
Self::XLMRoberta(ref tokenizer) => tokenizer.build_input_with_special_tokens(
tokens_1,
tokens_2,
offsets_1,
offsets_2,
original_offsets_1,
original_offsets_2,
mask_1,
mask_2,
),
Self::Marian(ref tokenizer) => tokenizer.build_input_with_special_tokens(
tokens_1,
tokens_2,
offsets_1,
offsets_2,
original_offsets_1,
original_offsets_2,
mask_1,
mask_2,
),
Self::T5(ref tokenizer) => tokenizer.build_input_with_special_tokens(
tokens_1,
tokens_2,
offsets_1,
offsets_2,
original_offsets_1,
original_offsets_2,
mask_1,
mask_2,
),
Self::Albert(ref tokenizer) => tokenizer.build_input_with_special_tokens(
tokens_1,
tokens_2,
offsets_1,
offsets_2,
original_offsets_1,
original_offsets_2,
mask_1,
mask_2,
),
};
TokenizedInput {
token_ids,
segment_ids,
special_tokens_mask,
overflowing_tokens: vec![],
num_truncated_tokens: 0,
token_offsets,
reference_offsets,
mask,
}
}
/// Interface method to convert tokens to ids
pub fn convert_tokens_to_ids(&self, tokens: &Vec<String>) -> Vec<i64> {
pub fn convert_tokens_to_ids(&self, tokens: &[String]) -> Vec<i64> {
match *self {
Self::Bert(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),
Self::Roberta(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),
Self::Marian(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),
Self::T5(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),
Self::XLMRoberta(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),
Self::Albert(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),
Self::Bert(ref tokenizer) => tokenizer.convert_tokens_to_ids(&tokens.into()),
Self::Roberta(ref tokenizer) => tokenizer.convert_tokens_to_ids(&tokens.into()),
Self::Marian(ref tokenizer) => tokenizer.convert_tokens_to_ids(&tokens.into()),
Self::T5(ref tokenizer) => tokenizer.convert_tokens_to_ids(&tokens.into()),
Self::XLMRoberta(ref tokenizer) => tokenizer.convert_tokens_to_ids(&tokens.into()),
Self::Albert(ref tokenizer) => tokenizer.convert_tokens_to_ids(&tokens.into()),
}
}
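`convert_tokens_to_ids` now takes a `&[String]` slice rather than `&Vec<String>` (clippy's `ptr_arg`), rebuilding an owned vector only at the boundary that still needs one. A standalone sketch of the signature change with hypothetical helpers:

    // Stand-in for an external API we cannot change, which still wants &Vec<String>.
    #[allow(clippy::ptr_arg)]
    fn lookup_ids(tokens: &Vec<String>) -> Vec<i64> {
        tokens.iter().map(|t| t.len() as i64).collect()
    }

    // Before: fn convert(tokens: &Vec<String>) -> Vec<i64>
    // After: a slice parameter accepts Vec<String>, arrays, and sub-slices alike.
    fn convert(tokens: &[String]) -> Vec<i64> {
        // Re-create an owned Vec only where the lower-level API requires it.
        lookup_ids(&tokens.to_vec())
    }

    fn main() {
        let owned = vec!["hello".to_string(), "world".to_string()];
        assert_eq!(convert(&owned), vec![5, 5]);
        assert_eq!(convert(&owned[..1]), vec![5]);
        println!("ok");
    }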

View File

@ -70,21 +70,21 @@ pub struct ConversationConfig {
/// Merges resource (default: DialoGPT-medium)
pub merges_resource: Resource,
/// Minimum sequence length (default: 0)
pub min_length: u64,
pub min_length: i64,
/// Maximum sequence length (default: 20)
pub max_length: u64,
pub max_length: i64,
/// Minimum free length available for generated responses (default: 32)
pub min_length_for_response: u64,
pub min_length_for_response: i64,
/// Sampling flag. If true, will perform top-k and/or nucleus sampling on generated tokens, otherwise greedy (deterministic) decoding (default: true)
pub do_sample: bool,
/// Early stopping flag indicating if the beam search should stop as soon as `num_beam` hypotheses have been generated (default: false)
pub early_stopping: bool,
/// Number of beams for beam search (default: 5)
pub num_beams: u64,
pub num_beams: i64,
/// Temperature setting. Values higher than 1 will improve originality at the risk of reducing relevance (default: 1.0)
pub temperature: f64,
/// Top_k values for sampling tokens. Value higher than 0 will enable the feature (default: 0)
pub top_k: u64,
pub top_k: i64,
/// Top_p value for [Nucleus sampling, Holtzman et al.](http://arxiv.org/abs/1904.09751). Keep top tokens until cumulative probability reaches top_p (default: 0.9)
pub top_p: f64,
/// Repetition penalty (mostly useful for CTRL decoders). Values higher than 1 will penalize tokens that have been already generated. (default: 1.0)
@ -92,9 +92,9 @@ pub struct ConversationConfig {
/// Exponential penalty based on the length of the hypotheses generated (default: 1.0)
pub length_penalty: f64,
/// Number of allowed repetitions of n-grams. Values higher than 0 turn on this feature (default: 3)
pub no_repeat_ngram_size: u64,
pub no_repeat_ngram_size: i64,
/// Number of sequences to return for each prompt text (default: 1)
pub num_return_sequences: u64,
pub num_return_sequences: i64,
/// Device to place the model on (default: CUDA/GPU when available)
pub device: Device,
}
@ -306,7 +306,7 @@ impl Conversation {
pub fn get_last_input(&self) -> Option<&str> {
if self.new_user_input.is_some() {
Some(self.new_user_input.as_ref().unwrap().as_str())
} else if self.past_user_inputs.len() > 0 {
} else if !self.past_user_inputs.is_empty() {
Some(self.past_user_inputs.last().unwrap().as_str())
} else {
None
@ -564,12 +564,18 @@ impl ConversationManager {
}
}
impl Default for ConversationManager {
fn default() -> Self {
Self::new()
}
}
/// # Conversation model
/// Processes a ConversationManager and generate system responses for active conversations.
pub struct ConversationModel {
model: GPT2Generator,
eos_token_id: i64,
max_allowed_context_length: u64,
max_allowed_context_length: i64,
}
impl ConversationModel {
@ -615,7 +621,7 @@ impl ConversationModel {
let model = GPT2Generator::new(generate_config)?;
let eos_token_id = *model.get_eos_ids().as_ref().unwrap().first().unwrap();
let max_allowed_length =
conversation_config.max_length as u64 - conversation_config.min_length_for_response;
conversation_config.max_length - conversation_config.min_length_for_response;
Ok(ConversationModel {
model,
eos_token_id,

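Elsewhere in this file, `ConversationManager` gains a `Default` implementation that simply delegates to `new()` (clippy's `new_without_default`). A minimal sketch of the pattern with a hypothetical manager type:

    use std::collections::HashMap;

    // Hypothetical stand-in for a conversation registry.
    pub struct Manager {
        conversations: HashMap<u64, String>,
    }

    impl Manager {
        pub fn new() -> Self {
            Manager {
                conversations: HashMap::new(),
            }
        }
    }

    // A no-argument `new()` should come with a matching `Default`.
    impl Default for Manager {
        fn default() -> Self {
            Self::new()
        }
    }

    fn main() {
        let manager = Manager::default();
        assert!(manager.conversations.is_empty());
        println!("empty manager created");
    }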
View File

@ -73,7 +73,9 @@ use crate::openai_gpt::{
OpenAIGPTLMHeadModel, OpenAiGptConfigResources, OpenAiGptMergesResources,
OpenAiGptModelResources, OpenAiGptVocabResources,
};
use crate::pipelines::generation::private_generation_utils::PrivateLanguageGenerator;
use crate::pipelines::generation::private_generation_utils::{
GenerateOptions, PrivateLanguageGenerator,
};
use crate::t5::{
LayerState as T5LayerState, T5Config, T5ConfigResources, T5ForConditionalGeneration,
T5ModelResources, T5VocabResources,
@ -104,19 +106,19 @@ pub struct GenerateConfig {
/// Merges resource (default: pretrained GPT2 model)
pub merges_resource: Resource,
/// Minimum sequence length (default: 0)
pub min_length: u64,
pub min_length: i64,
/// Maximum sequence length (default: 20)
pub max_length: u64,
pub max_length: i64,
/// Sampling flag. If true, will perform top-k and/or nucleus sampling on generated tokens, otherwise greedy (deterministic) decoding (default: true)
pub do_sample: bool,
/// Early stopping flag indicating if the beam search should stop as soon as `num_beam` hypotheses have been generated (default: false)
pub early_stopping: bool,
/// Number of beams for beam search (default: 5)
pub num_beams: u64,
pub num_beams: i64,
/// Temperature setting. Values higher than 1 will improve originality at the risk of reducing relevance (default: 1.0)
pub temperature: f64,
/// Top_k values for sampling tokens. Value higher than 0 will enable the feature (default: 0)
pub top_k: u64,
pub top_k: i64,
/// Top_p value for [Nucleus sampling, Holtzman et al.](http://arxiv.org/abs/1904.09751). Keep top tokens until cumulative probability reaches top_p (default: 0.9)
pub top_p: f64,
/// Repetition penalty (mostly useful for CTRL decoders). Values higher than 1 will penalize tokens that have been already generated. (default: 1.0)
@ -124,9 +126,9 @@ pub struct GenerateConfig {
/// Exponential penalty based on the length of the hypotheses generated (default: 1.0)
pub length_penalty: f64,
/// Number of allowed repetitions of n-grams. Values higher than 0 turn on this feature (default: 3)
pub no_repeat_ngram_size: u64,
pub no_repeat_ngram_size: i64,
/// Number of sequences to return for each prompt text (default: 1)
pub num_return_sequences: u64,
pub num_return_sequences: i64,
/// Device to place the model on (default: CUDA/GPU when available)
pub device: Device,
}
@ -179,11 +181,11 @@ impl GenerateConfig {
"length_penalty must be strictly greater than 0"
);
assert!(
self.num_return_sequences > 0u64,
self.num_return_sequences > 0i64,
"num_return_sequences must be strictly greater than 0"
);
assert!(
self.num_beams > 0u64,
self.num_beams > 0i64,
"num_beams must be strictly greater than 0"
);
@ -245,8 +247,8 @@ impl OpenAIGenerator {
generate_config.validate();
// The following allow keeping the same GenerationConfig Default for GPT, GPT2 and BART models
let model_resource = if &generate_config.model_resource
== &Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2))
let model_resource = if generate_config.model_resource
== Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2))
{
Resource::Remote(RemoteResource::from_pretrained(
OpenAiGptModelResources::GPT,
@ -255,8 +257,8 @@ impl OpenAIGenerator {
generate_config.model_resource.clone()
};
let config_resource = if &generate_config.config_resource
== &Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2))
let config_resource = if generate_config.config_resource
== Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2))
{
Resource::Remote(RemoteResource::from_pretrained(
OpenAiGptConfigResources::GPT,
@ -265,8 +267,8 @@ impl OpenAIGenerator {
generate_config.config_resource.clone()
};
let vocab_resource = if &generate_config.vocab_resource
== &Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2))
let vocab_resource = if generate_config.vocab_resource
== Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2))
{
Resource::Remote(RemoteResource::from_pretrained(
OpenAiGptVocabResources::GPT,
@ -275,8 +277,8 @@ impl OpenAIGenerator {
generate_config.vocab_resource.clone()
};
let merges_resource = if &generate_config.merges_resource
== &Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2))
let merges_resource = if generate_config.merges_resource
== Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2))
{
Resource::Remote(RemoteResource::from_pretrained(
OpenAiGptMergesResources::GPT,
@ -575,32 +577,32 @@ impl BartGenerator {
/// ```
pub fn new(generate_config: GenerateConfig) -> Result<BartGenerator, RustBertError> {
// The following allow keeping the same GenerationConfig Default for GPT, GPT2 and BART models
let model_resource = if &generate_config.model_resource
== &Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2))
let model_resource = if generate_config.model_resource
== Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2))
{
Resource::Remote(RemoteResource::from_pretrained(BartModelResources::BART))
} else {
generate_config.model_resource.clone()
};
let config_resource = if &generate_config.config_resource
== &Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2))
let config_resource = if generate_config.config_resource
== Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2))
{
Resource::Remote(RemoteResource::from_pretrained(BartConfigResources::BART))
} else {
generate_config.config_resource.clone()
};
let vocab_resource = if &generate_config.vocab_resource
== &Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2))
let vocab_resource = if generate_config.vocab_resource
== Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2))
{
Resource::Remote(RemoteResource::from_pretrained(BartVocabResources::BART))
} else {
generate_config.vocab_resource.clone()
};
let merges_resource = if &generate_config.merges_resource
== &Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2))
let merges_resource = if generate_config.merges_resource
== Resource::Remote(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2))
{
Resource::Remote(RemoteResource::from_pretrained(BartMergesResources::BART))
} else {
@ -702,7 +704,7 @@ impl PrivateLanguageGenerator<BartForConditionalGeneration, RobertaVocab, Robert
max_length: i64,
) {
if current_length == 1 {
self.force_token_id_generation(scores, &vec![self.get_bos_id().unwrap()]);
self.force_token_id_generation(scores, &[self.get_bos_id().unwrap()]);
} else if current_length == max_length - 1 {
self.force_token_id_generation(scores, self.get_eos_ids().as_ref().unwrap());
}
@ -747,7 +749,7 @@ impl PrivateLanguageGenerator<BartForConditionalGeneration, RobertaVocab, Robert
fn encode_prompt_text(
&self,
prompt_text: Vec<&str>,
max_len: u64,
max_len: i64,
pad_token_id: Option<i64>,
) -> Tensor {
let tokens = self.get_tokenizer().encode_list(
@ -797,7 +799,7 @@ impl PrivateLanguageGenerator<BartForConditionalGeneration, RobertaVocab, Robert
match past {
Cache::BARTCache(old_cache_option) => match old_cache_option {
Some(old_cache) => {
for (self_layer_state, encoder_layer_state) in old_cache.into_iter() {
for (self_layer_state, encoder_layer_state) in old_cache.iter_mut() {
if self_layer_state.is_some() {
self_layer_state
.as_mut()
@ -1022,7 +1024,7 @@ impl PrivateLanguageGenerator<MarianForConditionalGeneration, MarianVocab, Maria
fn encode_prompt_text(
&self,
prompt_text: Vec<&str>,
max_len: u64,
max_len: i64,
pad_token_id: Option<i64>,
) -> Tensor {
let tokens = self.get_tokenizer().encode_list(
@ -1072,7 +1074,7 @@ impl PrivateLanguageGenerator<MarianForConditionalGeneration, MarianVocab, Maria
match past {
Cache::BARTCache(old_cache_option) => match old_cache_option {
Some(old_cache) => {
for (self_layer_state, encoder_layer_state) in old_cache.into_iter() {
for (self_layer_state, encoder_layer_state) in old_cache.iter_mut() {
if self_layer_state.is_some() {
self_layer_state
.as_mut()
@ -1119,24 +1121,24 @@ pub struct T5Generator {
impl T5Generator {
pub fn new(generate_config: GenerateConfig) -> Result<T5Generator, RustBertError> {
// The following allow keeping the same GenerationConfig Default for GPT, GPT2 and BART models
let model_resource = if &generate_config.model_resource
== &Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2))
let model_resource = if generate_config.model_resource
== Resource::Remote(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2))
{
Resource::Remote(RemoteResource::from_pretrained(T5ModelResources::T5_SMALL))
} else {
generate_config.model_resource.clone()
};
let config_resource = if &generate_config.config_resource
== &Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2))
let config_resource = if generate_config.config_resource
== Resource::Remote(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2))
{
Resource::Remote(RemoteResource::from_pretrained(T5ConfigResources::T5_SMALL))
} else {
generate_config.config_resource.clone()
};
let vocab_resource = if &generate_config.vocab_resource
== &Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2))
let vocab_resource = if generate_config.vocab_resource
== Resource::Remote(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2))
{
Resource::Remote(RemoteResource::from_pretrained(T5VocabResources::T5_SMALL))
} else {
@ -1254,7 +1256,7 @@ impl PrivateLanguageGenerator<T5ForConditionalGeneration, T5Vocab, T5Tokenizer>
fn encode_prompt_text(
&self,
prompt_text: Vec<&str>,
max_len: u64,
max_len: i64,
pad_token_id: Option<i64>,
) -> Tensor {
let tokens = self.get_tokenizer().encode_list(
@ -1304,7 +1306,7 @@ impl PrivateLanguageGenerator<T5ForConditionalGeneration, T5Vocab, T5Tokenizer>
match past {
Cache::T5Cache(old_cache_option) => match old_cache_option {
Some(old_cache) => {
for (self_layer_state, encoder_layer_state) in old_cache.into_iter() {
for (self_layer_state, encoder_layer_state) in old_cache.iter_mut() {
if self_layer_state.is_some() {
self_layer_state
.as_mut()
@ -1351,6 +1353,23 @@ pub(crate) mod private_generation_utils {
use tch::kind::Kind::{Bool, Float, Int64};
use tch::{nn, Device, Tensor};
pub struct GenerateOptions {
pub min_length: i64,
pub max_length: i64,
pub do_sample: bool,
pub temperature: f64,
pub top_k: i64,
pub top_p: f64,
pub repetition_penalty: f64,
pub no_repeat_ngram_size: i64,
pub pad_token_id: Option<i64>,
pub eos_token_ids: Option<Vec<i64>>,
pub num_return_sequences: i64,
pub early_stopping: bool,
pub num_beams: i64,
pub length_penalty: f64,
}
pub trait PrivateLanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>> {
fn get_model(&self) -> &T;
fn get_tokenizer(&self) -> &U;
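The new `GenerateOptions` struct above bundles the long list of decoding parameters that `generate_beam_search` and `generate_no_beam_search` previously took one by one (clippy's `too_many_arguments`). A self-contained sketch of the same refactoring with hypothetical names:

    // Hypothetical options bundle; field names mirror the kind of knobs involved.
    struct SampleOptions {
        min_length: i64,
        max_length: i64,
        do_sample: bool,
        temperature: f64,
        top_k: i64,
        top_p: f64,
    }

    // Before: fn generate(min_length: i64, max_length: i64, do_sample: bool,
    //                     temperature: f64, top_k: i64, top_p: f64, ...) -> String
    // After: one struct parameter keeps the signature short and call sites readable.
    fn generate(prompt: &str, opt: SampleOptions) -> String {
        let clipped = opt.max_length.max(opt.min_length);
        format!(
            "{} (len<={}, sample={}, T={}, k={}, p={})",
            prompt, clipped, opt.do_sample, opt.temperature, opt.top_k, opt.top_p
        )
    }

    fn main() {
        let opt = SampleOptions {
            min_length: 0,
            max_length: 20,
            do_sample: true,
            temperature: 1.0,
            top_k: 50,
            top_p: 0.9,
        };
        println!("{}", generate("Hello", opt));
    }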
@ -1394,7 +1413,7 @@ pub(crate) mod private_generation_utils {
fn encode_prompt_text(
&self,
prompt_text: Vec<&str>,
max_len: u64,
max_len: i64,
pad_token_id: Option<i64>,
) -> Tensor {
let tokens = self.get_tokenizer().tokenize_list(prompt_text);
@ -1460,7 +1479,7 @@ pub(crate) mod private_generation_utils {
&self,
next_token_logits: &mut Tensor,
batch_size: i64,
num_beams: u64,
num_beams: i64,
prev_output_tokens: &Tensor,
repetition_penalty: f64,
) {
@ -1469,7 +1488,7 @@ pub(crate) mod private_generation_utils {
let token = prev_output_tokens.get(i).int64_value(&[token_position]);
let updated_value = &next_token_logits.double_value(&[i, token]);
if updated_value < &0f64 {
&next_token_logits.get(i).index_fill_(
let _ = next_token_logits.get(i).index_fill_(
0,
&Tensor::of_slice(&[token])
.to_kind(Int64)
@ -1477,7 +1496,7 @@ pub(crate) mod private_generation_utils {
updated_value * repetition_penalty,
);
} else {
&next_token_logits.get(i).index_fill_(
let _ = next_token_logits.get(i).index_fill_(
0,
&Tensor::of_slice(&[token])
.to_kind(Int64)
@ -1521,11 +1540,10 @@ pub(crate) mod private_generation_utils {
let ngram = &hypothesis_input_ids[ngram.0 as usize..ngram.1 as usize + 1];
let key = ngram[..no_repeat_ngram_size as usize - 1].to_vec();
let value = *ngram.last().unwrap();
if generated_ngram.contains_key(&key) {
generated_ngram.get_mut(&key).unwrap().push(value)
} else {
generated_ngram.insert(key, vec![value]);
}
generated_ngram
.entry(key)
.or_insert_with(|| vec![value])
.push(value);
}
let hypothesis_banned_tokens = match generated_ngram.get(query) {
Some(banned_tokens) => banned_tokens.clone(),
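The n-gram bookkeeping above moves from a `contains_key`/`insert` branch to the `HashMap` entry API. A minimal sketch of grouping values by key with `entry`, using hypothetical token data:

    use std::collections::HashMap;

    fn main() {
        // Hypothetical (key, value) pairs, e.g. n-gram prefixes and their next tokens.
        let pairs: Vec<(Vec<i64>, i64)> = vec![
            (vec![1, 2], 3),
            (vec![1, 2], 4),
            (vec![5, 6], 7),
        ];

        let mut generated_ngrams: HashMap<Vec<i64>, Vec<i64>> = HashMap::new();
        for (key, value) in pairs {
            // One lookup instead of contains_key followed by get_mut/insert.
            generated_ngrams.entry(key).or_default().push(value);
        }

        assert_eq!(generated_ngrams[&vec![1, 2]], vec![3, 4]);
        assert_eq!(generated_ngrams[&vec![5, 6]], vec![7]);
        println!("{:?}", generated_ngrams);
    }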
@ -1551,21 +1569,20 @@ pub(crate) mod private_generation_utils {
let top_k = vocab_size - min(max(top_k, min_tokens_to_keep), vocab_size);
let (_, indices_to_remove) = logits.topk(top_k, -1, false, false);
for index in 0..*logits.size().first().unwrap() {
&logits.get(index).index_fill_(
let _ = logits.get(index).index_fill_(
0,
&indices_to_remove.get(index),
std::f64::NEG_INFINITY,
);
}
}
if top_p < 1f64 {
let (sorted_logits, sorted_indices) = logits.sort(-1, true);
let cumulative_probabilities = sorted_logits.softmax(-1, Float).cumsum(-1, Float);
let mut sorted_indices_to_remove =
cumulative_probabilities.ge(top_p).to_kind(Int64);
if min_tokens_to_keep > 1 {
&sorted_indices_to_remove.index_fill_(
let _ = sorted_indices_to_remove.index_fill_(
1,
&Tensor::arange1(0, min_tokens_to_keep + 1, (Int64, logits.device())),
0,
@ -1597,31 +1614,22 @@ pub(crate) mod private_generation_utils {
input_ids: Tensor,
encoder_outputs: Option<Tensor>,
cur_len: i64,
min_length: i64,
max_length: i64,
do_sample: bool,
temperature: f64,
top_k: i64,
top_p: f64,
repetition_penalty: f64,
no_repeat_ngram_size: i64,
pad_token_id: Option<i64>,
eos_token_ids: Option<Vec<i64>>,
batch_size: i64,
attention_mask: Tensor,
gen_opt: GenerateOptions,
) -> Tensor {
let mut unfinished_sentences =
Tensor::ones(&[batch_size], (Int64, self.get_var_store().device()));
let mut sentence_lengths: Tensor =
Tensor::ones(&[batch_size], (Int64, self.get_var_store().device()))
* max_length as i64;
* gen_opt.max_length as i64;
let mut attention_mask = attention_mask.copy();
let mut input_ids = input_ids.copy();
let mut past: Cache = Cache::None;
let mut outputs: Tensor;
let mut current_length = cur_len;
while current_length < max_length {
while current_length < gen_opt.max_length {
let (
prepared_input,
prepared_attention_mask,
@ -1653,26 +1661,26 @@ pub(crate) mod private_generation_utils {
let mut next_token_logits = outputs.select(1, -1);
// Reduce probability for repeated inputs
if repetition_penalty > 1f64 {
if gen_opt.repetition_penalty > 1f64 {
self.enforce_repetition_penalty(
&mut next_token_logits,
batch_size,
1,
&input_ids,
repetition_penalty,
gen_opt.repetition_penalty,
)
}
// Get banned tokens and set their probability to 0
if no_repeat_ngram_size > 0 {
if gen_opt.no_repeat_ngram_size > 0 {
let banned_tokens = self.get_banned_tokens(
&input_ids,
no_repeat_ngram_size as i64,
gen_opt.no_repeat_ngram_size as i64,
current_length as i64,
);
for (batch_index, index_banned_token) in
(0..banned_tokens.len() as i64).zip(banned_tokens)
{
&next_token_logits.get(batch_index).index_fill_(
let _ = next_token_logits.get(batch_index).index_fill_(
0,
&Tensor::of_slice(&index_banned_token)
.to_device(next_token_logits.device()),
@ -1682,21 +1690,26 @@ pub(crate) mod private_generation_utils {
}
// Do not allow eos token if min length is not reached
if (&eos_token_ids.is_some()) & (current_length < min_length) {
&next_token_logits.index_fill_(
if (gen_opt.eos_token_ids.is_some()) & (current_length < gen_opt.min_length) {
let _ = next_token_logits.index_fill_(
1,
&Tensor::of_slice(eos_token_ids.as_ref().unwrap())
&Tensor::of_slice(gen_opt.eos_token_ids.as_ref().unwrap())
.to(next_token_logits.device()),
std::f64::NEG_INFINITY,
);
}
// Top-k and top-p sampling
let next_token = if do_sample {
if temperature > 1f64 {
next_token_logits = next_token_logits / temperature;
let next_token = if gen_opt.do_sample {
if gen_opt.temperature > 1f64 {
next_token_logits /= gen_opt.temperature;
}
self.top_k_top_p_filtering(&mut next_token_logits, top_k as i64, top_p, 1);
self.top_k_top_p_filtering(
&mut next_token_logits,
gen_opt.top_k as i64,
gen_opt.top_p,
1,
);
let probabilities = next_token_logits.softmax(-1, Float);
probabilities.multinomial(1, false).squeeze1(1)
} else {
@ -1704,17 +1717,17 @@ pub(crate) mod private_generation_utils {
};
// Add tokens to unfinished sentences
let tokens_to_add = match &eos_token_ids {
let tokens_to_add = match &gen_opt.eos_token_ids {
Some(_) => {
next_token * &unfinished_sentences
- pad_token_id.unwrap() * (&unfinished_sentences - 1)
- gen_opt.pad_token_id.unwrap() * (&unfinished_sentences - 1)
}
None => next_token,
};
input_ids = Tensor::cat(&[input_ids, tokens_to_add.unsqueeze(-1)], -1);
if eos_token_ids.is_some() {
for eos_token_id in eos_token_ids.as_ref().unwrap() {
if gen_opt.eos_token_ids.is_some() {
for eos_token_id in gen_opt.eos_token_ids.as_ref().unwrap() {
let sentence_with_eos = tokens_to_add.eq(*eos_token_id).to_kind(Int64);
let sentence_with_eos: Tensor = sentence_with_eos * &unfinished_sentences;
let _ = sentence_lengths.masked_fill_(
@ -1746,7 +1759,7 @@ pub(crate) mod private_generation_utils {
}
let decoded = if i64::from(&sentence_lengths.min().ne1(&sentence_lengths.max())) > 0 {
match pad_token_id {
match gen_opt.pad_token_id {
Some(pad_value) => {
let decoded: Tensor = Tensor::ones(
&[batch_size, i64::from(sentence_lengths.max())],
@ -1783,33 +1796,27 @@ pub(crate) mod private_generation_utils {
input_ids: Tensor,
encoder_outputs: Option<Tensor>,
cur_len: i64,
min_length: i64,
max_length: i64,
do_sample: bool,
early_stopping: bool,
temperature: f64,
top_k: i64,
top_p: f64,
repetition_penalty: f64,
no_repeat_ngram_size: i64,
pad_token_id: Option<i64>,
eos_token_ids: Option<Vec<i64>>,
batch_size: i64,
num_return_sequences: i64,
length_penalty: f64,
num_beams: i64,
attention_mask: Tensor,
gen_opt: GenerateOptions,
) -> Tensor {
let mut hypotheses = (0..batch_size)
.map(|_| BeamHypotheses::new(num_beams, max_length, length_penalty, early_stopping))
.map(|_| {
BeamHypotheses::new(
gen_opt.num_beams,
gen_opt.max_length,
gen_opt.length_penalty,
gen_opt.early_stopping,
)
})
.collect::<Vec<BeamHypotheses>>();
let vocab_size = self.get_vocab_size();
let beam_scores = Tensor::zeros(
&[batch_size, num_beams],
&[batch_size, gen_opt.num_beams],
(Float, self.get_var_store().device()),
);
if !do_sample {
if !gen_opt.do_sample {
let _ = beam_scores
.slice(1, 1, *beam_scores.size().last().unwrap(), 1)
.fill_(-1e9);
@ -1827,7 +1834,7 @@ pub(crate) mod private_generation_utils {
let mut encoder_outputs = encoder_outputs;
let mut current_length = cur_len;
while current_length < max_length {
while current_length < gen_opt.max_length {
let (
prepared_input,
prepared_attention_mask,
@ -1860,42 +1867,47 @@ pub(crate) mod private_generation_utils {
let mut next_token_logits = outputs.select(1, -1);
// Reduce probability for repeated inputs
if repetition_penalty > 1f64 {
if gen_opt.repetition_penalty > 1f64 {
self.enforce_repetition_penalty(
&mut next_token_logits,
batch_size,
1,
&input_ids,
repetition_penalty,
gen_opt.repetition_penalty,
)
}
if temperature > 1f64 {
next_token_logits = next_token_logits / temperature;
if gen_opt.temperature > 1f64 {
next_token_logits /= gen_opt.temperature;
}
let mut scores = next_token_logits.log_softmax(-1, Float);
if self.is_encoder_decoder() & !do_sample {
self.prepare_scores_for_generation(&mut scores, current_length, max_length);
if self.is_encoder_decoder() & !gen_opt.do_sample {
self.prepare_scores_for_generation(
&mut scores,
current_length,
gen_opt.max_length,
);
}
// Do not allow eos token if min length is not reached
if (&eos_token_ids.is_some()) & (current_length < min_length) {
&scores.index_fill_(
if (gen_opt.eos_token_ids.is_some()) & (current_length < gen_opt.min_length) {
let _ = scores.index_fill_(
1,
&Tensor::of_slice(eos_token_ids.as_ref().unwrap()).to(scores.device()),
&Tensor::of_slice(gen_opt.eos_token_ids.as_ref().unwrap())
.to(scores.device()),
std::f64::NEG_INFINITY,
);
}
// Get banned tokens and set their probability to 0
if no_repeat_ngram_size > 0 {
if gen_opt.no_repeat_ngram_size > 0 {
let banned_tokens = self.get_banned_tokens(
&input_ids,
no_repeat_ngram_size as i64,
current_length as i64,
gen_opt.no_repeat_ngram_size,
current_length,
);
for (batch_index, index_banned_token) in
(0..banned_tokens.len() as i64).zip(banned_tokens)
{
&scores.get(batch_index).index_fill_(
let _ = scores.get(batch_index).index_fill_(
0,
&Tensor::of_slice(&index_banned_token)
.to_device(next_token_logits.device()),
@ -1904,16 +1916,16 @@ pub(crate) mod private_generation_utils {
}
}
let (next_scores, next_tokens) = if do_sample {
let (next_scores, next_tokens) = if gen_opt.do_sample {
let mut _scores: Tensor =
&scores + &beam_scores.unsqueeze(-1).expand_as(&scores);
self.top_k_top_p_filtering(&mut _scores, top_k as i64, top_p, 2);
self.top_k_top_p_filtering(&mut _scores, gen_opt.top_k, gen_opt.top_p, 2);
let _scores = _scores
.contiguous()
.view((batch_size, num_beams * vocab_size));
.view((batch_size, gen_opt.num_beams * vocab_size));
let probabilities = _scores.softmax(-1, Float);
let next_tokens = probabilities.multinomial(2 * num_beams, false);
let next_tokens = probabilities.multinomial(2 * gen_opt.num_beams, false);
let next_scores = _scores.gather(-1, &next_tokens, false);
let (next_scores, next_scores_indices) = next_scores.sort(1, true);
let next_tokens = next_tokens.gather(-1, &next_scores_indices, false);
@ -1923,25 +1935,25 @@ pub(crate) mod private_generation_utils {
&scores + &beam_scores.unsqueeze(-1).expand_as(&scores);
let next_scores = next_scores
.contiguous()
.view((batch_size, num_beams * vocab_size));
next_scores.topk(2 * num_beams, 1, true, true)
.view((batch_size, gen_opt.num_beams * vocab_size));
next_scores.topk(2 * gen_opt.num_beams, 1, true, true)
};
let mut next_batch_beam: Vec<(f64, i64, i64)> = vec![];
for batch_index in 0..batch_size {
if done[batch_index as usize] {
assert!(
hypotheses[batch_index as usize].len() >= num_beams,
hypotheses[batch_index as usize].len() >= gen_opt.num_beams,
"Batch cannot be completed if all beams have not been generated"
);
assert!(
eos_token_ids.is_some() & pad_token_id.is_some(),
gen_opt.eos_token_ids.is_some() & gen_opt.pad_token_id.is_some(),
"EOS and Padding tokens need to be defined if the number of generated \
beams is greater than the target number of beams"
);
next_batch_beam.append(
&mut (0..num_beams)
.map(|_| (0f64, pad_token_id.unwrap(), 0i64))
&mut (0..gen_opt.num_beams)
.map(|_| (0f64, gen_opt.pad_token_id.unwrap(), 0i64))
.collect::<Vec<(f64, i64, i64)>>(),
);
continue;
@ -1960,11 +1972,11 @@ pub(crate) mod private_generation_utils {
let beam_id = beam_token_id / vocab_size;
let token_id = beam_token_id % vocab_size;
let effective_beam_id = batch_index * num_beams + beam_id;
let effective_beam_id = batch_index * gen_opt.num_beams + beam_id;
if eos_token_ids.as_ref().is_some() {
if eos_token_ids.as_ref().unwrap().contains(&token_id) {
if beam_token_rank >= num_beams {
if gen_opt.eos_token_ids.as_ref().is_some() {
if gen_opt.eos_token_ids.as_ref().unwrap().contains(&token_id) {
if beam_token_rank >= gen_opt.num_beams {
beam_token_rank += 1;
continue;
}
@ -1985,7 +1997,7 @@ pub(crate) mod private_generation_utils {
));
}
if (next_sentence_beam.len() as i64 == num_beams)
if (next_sentence_beam.len() as i64 == gen_opt.num_beams)
| (beam_token_rank == beam_token_rank_max_value)
{
break;
@ -1993,15 +2005,14 @@ pub(crate) mod private_generation_utils {
beam_token_rank += 1;
}
done[batch_index as usize] = done[batch_index as usize]
| hypotheses[batch_index as usize].is_done(
f64::from(next_scores.get(batch_index).max()),
current_length,
);
done[batch_index as usize] |= hypotheses[batch_index as usize].is_done(
f64::from(next_scores.get(batch_index).max()),
current_length,
);
assert_eq!(
next_sentence_beam.len() as i64,
num_beams,
gen_opt.num_beams,
"Beam incomplete"
);
next_batch_beam.append(&mut next_sentence_beam);
@ -2062,8 +2073,8 @@ pub(crate) mod private_generation_utils {
batch_index += 1;
continue;
}
for beam_index in 0..num_beams {
let effective_beam_id = batch_index * num_beams + beam_index;
for beam_index in 0..gen_opt.num_beams {
let effective_beam_id = batch_index * gen_opt.num_beams + beam_index;
let final_score = f64::from(beam_scores.get(effective_beam_id));
let final_tokens = input_ids.get(effective_beam_id);
hypotheses[batch_index as usize].add(final_tokens, final_score);
@ -2071,10 +2082,13 @@ pub(crate) mod private_generation_utils {
batch_index += 1;
}
let (output_batch_size, output_num_return_sequences_per_batch) = if do_sample {
let (output_batch_size, output_num_return_sequences_per_batch) = if gen_opt.do_sample {
(batch_size, 1)
} else {
(batch_size * num_return_sequences, num_return_sequences)
(
batch_size * gen_opt.num_return_sequences,
gen_opt.num_return_sequences,
)
};
let mut sentence_lengths =
@ -2083,7 +2097,7 @@ pub(crate) mod private_generation_utils {
for (hypothesis_index, hypothesis) in hypotheses.iter().enumerate() {
let mut sorted_hypotheses = hypothesis.clone();
&sorted_hypotheses
sorted_hypotheses
.beams
.sort_by_key(|(score, _)| OrderedFloat(*score));
for j in 0..output_num_return_sequences_per_batch {
@ -2101,12 +2115,13 @@ pub(crate) mod private_generation_utils {
let decoded = if i64::from(sentence_lengths.max()) != i64::from(sentence_lengths.min())
{
let sentence_max_length = min(i64::from(sentence_lengths.max()) + 1, max_length);
let sentence_max_length =
min(i64::from(sentence_lengths.max()) + 1, gen_opt.max_length);
let decoded: Tensor = Tensor::ones(
&[output_batch_size, sentence_max_length],
(Int64, input_ids.device()),
) * pad_token_id.unwrap();
for hypothesis_index in 0..best_ids.len() {
) * gen_opt.pad_token_id.unwrap();
for (hypothesis_index, best_id) in best_ids.iter().enumerate() {
let _ = decoded.get(hypothesis_index as i64).index_copy_(
0,
&Tensor::arange1(
@ -2114,14 +2129,14 @@ pub(crate) mod private_generation_utils {
i64::from(sentence_lengths.get(hypothesis_index as i64)),
(Int64, input_ids.device()),
),
&best_ids[hypothesis_index],
&best_id,
);
let sentence_length = i64::from(sentence_lengths.get(hypothesis_index as i64));
if sentence_length < max_length {
if sentence_length < gen_opt.max_length {
let _ = decoded.get(hypothesis_index as i64).index_fill_(
0,
&Tensor::of_slice(&[sentence_length]).to_device(input_ids.device()),
eos_token_ids.as_ref().unwrap()[0],
gen_opt.eos_token_ids.as_ref().unwrap()[0],
);
}
}
@ -2231,7 +2246,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
let config = PrivateLanguageGenerator::get_config(self);
let max_length = config.max_length;
let encoding_max_len = if self.is_encoder_decoder() {
1024u64
1024i64
} else {
max_length
};
@ -2377,49 +2392,45 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
(input_ids, attention_mask)
};
let gen_opt = GenerateOptions {
min_length,
max_length,
do_sample,
temperature,
top_k,
top_p,
repetition_penalty,
no_repeat_ngram_size,
pad_token_id,
eos_token_ids,
num_return_sequences,
early_stopping,
num_beams,
length_penalty,
};
let decoded = no_grad(|| {
if num_beams > 1 {
self.generate_beam_search(
input_ids,
encoder_outputs,
cur_len,
min_length as i64,
max_length as i64,
do_sample,
early_stopping,
temperature,
top_k as i64,
top_p,
repetition_penalty,
no_repeat_ngram_size as i64,
pad_token_id,
eos_token_ids,
effective_batch_size,
num_return_sequences as i64,
length_penalty,
num_beams as i64,
attention_mask,
gen_opt,
)
} else {
self.generate_no_beam_search(
input_ids,
encoder_outputs,
cur_len,
min_length as i64,
max_length as i64,
do_sample,
temperature,
top_k as i64,
top_p,
repetition_penalty,
no_repeat_ngram_size as i64,
pad_token_id,
eos_token_ids,
effective_batch_size,
attention_mask,
gen_opt,
)
}
});
let num_sequences = *decoded.size().first().unwrap();
let mut output_ids = Vec::with_capacity(num_sequences as usize);
for sequence_index in 0..num_sequences {

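Throughout this file, statements of the form `&tensor.index_fill_(...)` become `let _ = tensor.index_fill_(...)`: the old leading `&` borrowed the returned tensor only to drop it immediately, which clippy flags. A small sketch with an ordinary method that both mutates its receiver and returns a value:

    struct Buffer(Vec<i64>);

    impl Buffer {
        // Mutates in place and also returns the pushed value, similar in spirit to
        // in-place tensor ops that return a handle to the modified tensor.
        fn push_value(&mut self, value: i64) -> i64 {
            self.0.push(value);
            value
        }
    }

    fn main() {
        let mut buffer = Buffer(vec![1, 2, 3]);

        // Before (lint-prone): `&buffer.push_value(4);` builds and drops a useless reference.
        // After: explicitly discard the unused return value.
        let _ = buffer.push_value(4);

        assert_eq!(buffer.0, vec![1, 2, 3, 4]);
        println!("{:?}", buffer.0);
    }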
View File

@ -684,7 +684,7 @@ impl QuestionAnsweringModel {
vec![],
None,
)
.0
.token_ids
.len()
+ 1
}
@ -700,7 +700,7 @@ impl QuestionAnsweringModel {
vec![],
None,
)
.0
.token_ids
.len(),
};
@ -716,7 +716,7 @@ impl QuestionAnsweringModel {
vec![],
Some(vec![]),
)
.0
.token_ids
.len();
let mut spans: Vec<QaFeature> = vec![];
@ -823,14 +823,7 @@ impl QuestionAnsweringModel {
)
.unwrap();
let (
mut token_ids,
mut segment_ids,
special_tokens_mask,
mut token_offsets,
mut reference_offsets,
mut mask,
) = self.tokenizer.build_input_with_special_tokens(
let mut tokenized_input = self.tokenizer.build_input_with_special_tokens(
truncated_query,
truncated_context,
vec![],
@ -840,25 +833,43 @@ impl QuestionAnsweringModel {
vec![],
None,
);
let mut attention_mask = vec![1; token_ids.len()];
if token_ids.len() < max_seq_length {
token_ids.append(&mut vec![self.pad_idx; max_seq_length - token_ids.len()]);
segment_ids.append(&mut vec![0; max_seq_length - segment_ids.len()]);
let mut attention_mask = vec![1; tokenized_input.token_ids.len()];
if tokenized_input.token_ids.len() < max_seq_length {
tokenized_input.token_ids.append(&mut vec![
self.pad_idx;
max_seq_length
- tokenized_input.token_ids.len()
]);
tokenized_input.segment_ids.append(&mut vec![
0;
max_seq_length
- tokenized_input.segment_ids.len()
]);
attention_mask.append(&mut vec![0; max_seq_length - attention_mask.len()]);
token_offsets.append(&mut vec![None; max_seq_length - token_offsets.len()]);
reference_offsets.append(&mut vec![vec!(); max_seq_length - token_offsets.len()]);
mask.append(&mut vec![Mask::Special; max_seq_length - mask.len()]);
tokenized_input.token_offsets.append(&mut vec![
None;
max_seq_length
- tokenized_input
.token_offsets
.len()
]);
tokenized_input.reference_offsets.append(&mut vec![
vec!();
max_seq_length
- tokenized_input
.token_offsets
.len()
]);
tokenized_input.mask.append(&mut vec![
Mask::Special;
max_seq_length - tokenized_input.mask.len()
]);
}
(
TokenizedInput {
token_ids,
segment_ids,
special_tokens_mask,
overflowing_tokens,
num_truncated_tokens,
token_offsets,
reference_offsets,
mask,
..tokenized_input
},
attention_mask,
)

View File

@ -83,19 +83,19 @@ pub struct SummarizationConfig {
/// Merges resource (default: pretrained BART model on CNN-DM)
pub merges_resource: Resource,
/// Minimum sequence length (default: 0)
pub min_length: u64,
pub min_length: i64,
/// Maximum sequence length (default: 20)
pub max_length: u64,
pub max_length: i64,
/// Sampling flag. If true, will perform top-k and/or nucleus sampling on generated tokens, otherwise greedy (deterministic) decoding (default: true)
pub do_sample: bool,
/// Early stopping flag indicating if the beam search should stop as soon as `num_beam` hypotheses have been generated (default: false)
pub early_stopping: bool,
/// Number of beams for beam search (default: 5)
pub num_beams: u64,
pub num_beams: i64,
/// Temperature setting. Values higher than 1 will improve originality at the risk of reducing relevance (default: 1.0)
pub temperature: f64,
/// Top_k values for sampling tokens. Value higher than 0 will enable the feature (default: 0)
pub top_k: u64,
pub top_k: i64,
/// Top_p value for [Nucleus sampling, Holtzman et al.](http://arxiv.org/abs/1904.09751). Keep top tokens until cumulative probability reaches top_p (default: 0.9)
pub top_p: f64,
/// Repetition penalty (mostly useful for CTRL decoders). Values higher than 1 will penalize tokens that have been already generated. (default: 1.0)
@ -103,9 +103,9 @@ pub struct SummarizationConfig {
/// Exponential penalty based on the length of the hypotheses generated (default: 1.0)
pub length_penalty: f64,
/// Number of allowed repetitions of n-grams. Values higher than 0 turn on this feature (default: 3)
pub no_repeat_ngram_size: u64,
pub no_repeat_ngram_size: i64,
/// Number of sequences to return for each prompt text (default: 1)
pub num_return_sequences: u64,
pub num_return_sequences: i64,
/// Device to place the model on (default: CUDA/GPU when available)
pub device: Device,
}

View File

@ -438,7 +438,7 @@ impl TokenClassificationOption {
input_embeds,
train,
)
.0
.logits
}
Self::Albert(ref model) => {
model

View File

@ -411,19 +411,19 @@ pub struct TranslationConfig {
/// Merges resource (default: pretrained BART model on CNN-DM)
pub merges_resource: Resource,
/// Minimum sequence length (default: 0)
pub min_length: u64,
pub min_length: i64,
/// Maximum sequence length (default: 20)
pub max_length: u64,
pub max_length: i64,
/// Sampling flag. If true, will perform top-k and/or nucleus sampling on generated tokens, otherwise greedy (deterministic) decoding (default: true)
pub do_sample: bool,
/// Early stopping flag indicating if the beam search should stop as soon as `num_beam` hypotheses have been generated (default: false)
pub early_stopping: bool,
/// Number of beams for beam search (default: 5)
pub num_beams: u64,
pub num_beams: i64,
/// Temperature setting. Values higher than 1 will improve originality at the risk of reducing relevance (default: 1.0)
pub temperature: f64,
/// Top_k values for sampling tokens. Value higher than 0 will enable the feature (default: 0)
pub top_k: u64,
pub top_k: i64,
/// Top_p value for [Nucleus sampling, Holtzman et al.](http://arxiv.org/abs/1904.09751). Keep top tokens until cumulative probability reaches top_p (default: 0.9)
pub top_p: f64,
/// Repetition penalty (mostly useful for CTRL decoders). Values higher than 1 will penalize tokens that have been already generated. (default: 1.0)
@ -431,9 +431,9 @@ pub struct TranslationConfig {
/// Exponential penalty based on the length of the hypotheses generated (default: 1.0)
pub length_penalty: f64,
/// Number of allowed repetitions of n-grams. Values higher than 0 turn on this feature (default: 3)
pub no_repeat_ngram_size: u64,
pub no_repeat_ngram_size: i64,
/// Number of sequences to return for each prompt text (default: 1)
pub num_return_sequences: u64,
pub num_return_sequences: i64,
/// Device to place the model on (default: CUDA/GPU when available)
pub device: Device,
/// Prefix to append translation inputs with

View File

@ -58,23 +58,34 @@ fn electra_masked_lm() -> anyhow::Result<()> {
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
// Forward pass
let (output, all_hidden_states, all_attentions) =
let model_output =
no_grad(|| electra_model.forward_t(Some(input_tensor), None, None, None, None, false));
// Decode output
let index_1 = output.get(0).get(4).argmax(0, false);
let index_2 = output.get(1).get(7).argmax(0, false);
let index_1 = model_output
.prediction_scores
.get(0)
.get(4)
.argmax(0, false);
let index_2 = model_output
.prediction_scores
.get(1)
.get(7)
.argmax(0, false);
let word_1 = tokenizer.vocab().id_to_token(&index_1.int64_value(&[]));
let word_2 = tokenizer.vocab().id_to_token(&index_2.int64_value(&[]));
assert_eq!(output.size(), &[2, 10, config.vocab_size]);
assert_eq!(
config.num_hidden_layers as usize,
all_hidden_states.unwrap().len()
model_output.prediction_scores.size(),
&[2, 10, config.vocab_size]
);
assert_eq!(
config.num_hidden_layers as usize,
all_attentions.unwrap().len()
model_output.all_hidden_states.unwrap().len()
);
assert_eq!(
config.num_hidden_layers as usize,
model_output.all_attentions.unwrap().len()
);
assert_eq!("thing", word_1); // Outputs "person" : "Looks like one [person] is missing"
assert_eq!("sunny", word_2); // Outputs "pear" : "It was a very nice and [sunny] day"
@ -127,16 +138,20 @@ fn electra_discriminator() -> anyhow::Result<()> {
let input_tensor = Tensor::stack(encoded_input.as_slice(), 0).to(device);
// Forward pass
let (output, _, _) =
let model_output =
no_grad(|| electra_model.forward_t(Some(input_tensor), None, None, None, None, false));
// Validate model predictions
let expected_probabilities = vec![
0.0101, 0.0030, 0.0010, 0.0018, 0.9489, 0.0067, 0.0026, 0.0017, 0.0311, 0.0101,
];
let probabilities = output.iter::<f64>().unwrap().collect::<Vec<f64>>();
let probabilities = model_output
.probabilities
.iter::<f64>()
.unwrap()
.collect::<Vec<f64>>();
assert_eq!(output.size(), &[10]);
assert_eq!(model_output.probabilities.size(), &[10]);
for (expected, pred) in probabilities.iter().zip(expected_probabilities) {
assert!((expected - pred).abs() < 1e-4);
}