Mirror of https://github.com/guillaume-be/rust-bert.git, synced 2024-09-11 12:55:34 +03:00
Updated Bart & Marian (clippy warnings)
parent daa6dba2d2
commit 7e9d6d7e39
clippy.toml (new file, +1)
@@ -0,0 +1 @@
+too-many-arguments-threshold = 10
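The new `clippy.toml` raises clippy's `too_many_arguments` threshold from its default of 7 to 10, so the wide `forward_t` signatures touched below pass the lint without per-function `#[allow]` attributes. A minimal, self-contained sketch of the effect (the function is hypothetical, not part of the crate):

```rust
// With `too-many-arguments-threshold = 10` in clippy.toml, an 8-argument
// signature like this no longer triggers clippy::too_many_arguments, which
// fires above 7 arguments by default. Hypothetical illustration only.
fn forward_sketch(a: i64, b: i64, c: i64, d: i64, e: i64, f: i64, g: i64, h: i64) -> i64 {
    a + b + c + d + e + f + g + h
}

fn main() {
    println!("{}", forward_sketch(1, 2, 3, 4, 5, 6, 7, 8));
}
```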
@@ -73,12 +73,15 @@ fn main() -> anyhow::Result<()> {
     let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);

     // Forward pass
-    let (decoder_output, encoder_output, _, _, _, _, _) =
+    let model_output =
         no_grad(|| bart_model.forward_t(Some(&input_tensor), None, None, None, None, None, false));

     // Print masked tokens
-    println!("{:?}", encoder_output);
-    println!("{:?}", decoder_output);
-    println!("{:?}", decoder_output.double_value(&[0, 0, 0]));
+    println!("{:?}", model_output.encoder_hidden_state);
+    println!("{:?}", model_output.decoder_hidden_state);
+    println!(
+        "{:?}",
+        model_output.decoder_hidden_state.double_value(&[0, 0, 0])
+    );
     Ok(())
 }
@@ -70,7 +70,7 @@ fn main() -> anyhow::Result<()> {
     let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);

     // Forward pass
-    let (output, _, _, _, _) = gpt2_model
+    let model_output = gpt2_model
         .forward_t(
             &Some(input_tensor),
             Cache::None,
@@ -84,7 +84,12 @@ fn main() -> anyhow::Result<()> {
         )
         .unwrap();

-    let next_word_id = output.get(0).get(-1).argmax(-1, true).int64_value(&[0]);
+    let next_word_id = model_output
+        .lm_logits
+        .get(0)
+        .get(-1)
+        .argmax(-1, true)
+        .int64_value(&[0]);
     let next_word = tokenizer.decode(vec![next_word_id], true, true);
     println!("Provided input: {}", input[0]);
     println!("Next word: {}", next_word);
@@ -75,7 +75,7 @@ fn main() -> anyhow::Result<()> {
     let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);

     // Forward pass
-    let (output, _, _, _, _) = openai_gpt
+    let model_output = openai_gpt
         .forward_t(
             &Some(input_tensor),
             Cache::None,
@@ -89,7 +89,12 @@ fn main() -> anyhow::Result<()> {
         )
         .unwrap();

-    let next_word_id = output.get(0).get(-1).argmax(-1, true).int64_value(&[0]);
+    let next_word_id = model_output
+        .lm_logits
+        .get(0)
+        .get(-1)
+        .argmax(-1, true)
+        .int64_value(&[0]);
     let next_word = tokenizer.decode(vec![next_word_id], true, true);
     println!("Provided input: {}", input[0]);
     println!("Next word: {}", next_word);
@@ -13,9 +13,9 @@

 use crate::bart::attention::LayerState;
 use crate::bart::decoder::BartDecoder;
-use crate::bart::encoder::BartEncoder;
+use crate::bart::encoder::{BartEncoder, BartEncoderOutput};
 use crate::common::dropout::Dropout;
-use crate::pipelines::generation::{Cache, LMHeadModel};
+use crate::pipelines::generation::{Cache, LMHeadModel, LMModelOutput};
 use crate::Config;
 use serde::{Deserialize, Serialize};
 use std::borrow::Borrow;
@@ -346,15 +346,7 @@ impl BartModel {
     /// let decoder_attention_mask =
     ///     Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
     ///
-    /// let (
-    ///     decoder_output,
-    ///     encoder_hidden_states,
-    ///     decoder_cache,
-    ///     all_encoder_hidden_states,
-    ///     all_encoder_attentions,
-    ///     all_decoder_hidden_states,
-    ///     all_decoder_attentions,
-    /// ) = no_grad(|| {
+    /// let model_output = no_grad(|| {
     ///     bart_model.forward_t(
     ///         Some(&input_tensor),
     ///         Some(&encoder_attention_mask),
@@ -371,19 +363,11 @@ impl BartModel {
         input_ids: Option<&Tensor>,
         attention_mask: Option<&Tensor>,
         decoder_input_ids: Option<&Tensor>,
-        encoder_outputs: Option<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>)>,
+        encoder_output: Option<BartEncoderOutput>,
         decoder_attention_mask: Option<&Tensor>,
         layer_states: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
         train: bool,
-    ) -> (
-        Tensor,
-        Tensor,
-        Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
-        Option<Vec<Tensor>>,
-        Option<Vec<Tensor>>,
-        Option<Vec<Tensor>>,
-        Option<Vec<Tensor>>,
-    ) {
+    ) -> BartModelOutput {
         let (decoder_input_ids, decoder_padding_mask, causal_mask) = if self.generation_mode {
             (decoder_input_ids.unwrap().copy(), None, None)
         } else {
@@ -398,43 +382,37 @@ impl BartModel {
                 decoder_attention_mask,
             )
         };
-        let (encoder_hidden_states, all_encoder_hidden_states, all_encoder_attentions) =
-            match encoder_outputs {
-                Some(value) => value,
-                None => {
-                    assert!(
-                        input_ids.is_some(),
-                        "input_ids must be provided when encoder output is not pre-computed"
-                    );
-                    self.encoder.forward_t(
-                        input_ids.unwrap(),
-                        attention_mask,
-                        &self.embeddings,
-                        train,
-                    )
-                }
-            };
+        let encoder_output = match encoder_output {
+            Some(value) => value,
+            None => {
+                assert!(
+                    input_ids.is_some(),
+                    "input_ids must be provided when encoder output is not pre-computed"
+                );
+                self.encoder
+                    .forward_t(input_ids.unwrap(), attention_mask, &self.embeddings, train)
+            }
+        };

-        let (decoder_outputs, decoder_cache, all_decoder_hidden_states, all_decoder_attentions) =
-            self.decoder.forward_t(
-                &decoder_input_ids,
-                &encoder_hidden_states,
-                attention_mask,
-                decoder_padding_mask.as_ref(),
-                causal_mask.as_ref(),
-                &self.embeddings,
-                layer_states,
-                train,
-            );
-        (
-            decoder_outputs,
-            encoder_hidden_states,
-            decoder_cache.1,
-            all_decoder_hidden_states,
-            all_decoder_attentions,
-            all_encoder_hidden_states,
-            all_encoder_attentions,
-        )
+        let decoder_output = self.decoder.forward_t(
+            &decoder_input_ids,
+            &encoder_output.hidden_state,
+            attention_mask,
+            decoder_padding_mask.as_ref(),
+            causal_mask.as_ref(),
+            &self.embeddings,
+            layer_states,
+            train,
+        );
+        BartModelOutput {
+            decoder_hidden_state: decoder_output.hidden_state,
+            encoder_hidden_state: encoder_output.hidden_state,
+            cache: decoder_output.next_decoder_cache,
+            all_decoder_hidden_states: decoder_output.all_hidden_states,
+            all_decoder_attentions: decoder_output.all_attentions,
+            all_encoder_hidden_states: encoder_output.all_hidden_states,
+            all_encoder_attentions: encoder_output.all_attentions,
+        }
     }
 }

@@ -525,9 +503,7 @@ impl BartForConditionalGeneration {
     /// let encoder_attention_mask = Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
     /// let decoder_attention_mask = Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
     ///
-    /// let (decoder_output, encoder_hidden_states, cache,
-    ///     all_encoder_hidden_states, all_encoder_attentions,
-    ///     all_decoder_hidden_states, all_decoder_attentions) = no_grad(|| {
+    /// let model_output = no_grad(|| {
     ///     bart_model
     ///         .forward_t(Some(&input_tensor),
     ///             Some(&encoder_attention_mask),
@@ -542,58 +518,41 @@ impl BartForConditionalGeneration {
         &self,
         input_ids: Option<&Tensor>,
         attention_mask: Option<&Tensor>,
-        encoder_outputs: Option<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>)>,
+        encoder_output: Option<BartEncoderOutput>,
         decoder_input_ids: Option<&Tensor>,
         decoder_attention_mask: Option<&Tensor>,
         old_layer_states: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
         train: bool,
-    ) -> (
-        Tensor,
-        Tensor,
-        Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
-        Option<Vec<Tensor>>,
-        Option<Vec<Tensor>>,
-        Option<Vec<Tensor>>,
-        Option<Vec<Tensor>>,
-    ) {
-        let (
-            decoder_outputs,
-            encoder_hidden_states,
-            decoder_cache,
-            all_decoder_hidden_states,
-            all_decoder_attentions,
-            all_encoder_hidden_states,
-            all_encoder_attentions,
-        ) = self.base_model.forward_t(
+    ) -> BartModelOutput {
+        let base_model_output = self.base_model.forward_t(
             input_ids,
             attention_mask,
             decoder_input_ids,
-            encoder_outputs,
+            encoder_output,
             decoder_attention_mask,
             old_layer_states,
             train,
         );

-        let lm_logits = decoder_outputs.linear::<Tensor>(&self.base_model.embeddings.ws, None);
-        (
-            lm_logits,
-            encoder_hidden_states,
-            decoder_cache,
-            all_decoder_hidden_states,
-            all_decoder_attentions,
-            all_encoder_hidden_states,
-            all_encoder_attentions,
-        )
+        let lm_logits = base_model_output
+            .decoder_hidden_state
+            .linear::<Tensor>(&self.base_model.embeddings.ws, None);
+        BartModelOutput {
+            decoder_hidden_state: lm_logits,
+            ..base_model_output
+        }
     }

     pub fn encode(&self, input_ids: &Tensor, attention_mask: Option<&Tensor>) -> Tensor {
-        let (encoder_hidden_states, _, _) = self.base_model.encoder.forward_t(
-            input_ids,
-            attention_mask,
-            &self.base_model.embeddings,
-            false,
-        );
-        encoder_hidden_states
+        self.base_model
+            .encoder
+            .forward_t(
+                input_ids,
+                attention_mask,
+                &self.base_model.embeddings,
+                false,
+            )
+            .hidden_state
     }
 }

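The rewritten `BartForConditionalGeneration::forward_t` above builds its result with Rust's struct update syntax (`..base_model_output`), overriding only `decoder_hidden_state` and moving every other field across unchanged. A self-contained toy version of that pattern (illustrative types, not the crate's):

```rust
#[derive(Debug)]
struct Output {
    primary: i64,
    extra: Option<i64>,
}

fn main() {
    let base = Output { primary: 1, extra: Some(2) };
    // `..base` fills the remaining fields from `base`, mirroring
    // `BartModelOutput { decoder_hidden_state: lm_logits, ..base_model_output }`.
    let updated = Output { primary: 10, ..base };
    println!("{:?}", updated); // Output { primary: 10, extra: Some(2) }
}
```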
@@ -740,9 +699,7 @@ impl BartForSequenceClassification {
     /// let encoder_attention_mask = Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
     /// let decoder_attention_mask = Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
     ///
-    /// let (decoder_output, encoder_hidden_states, cache,
-    ///     all_encoder_hidden_states, all_encoder_attentions,
-    ///     all_decoder_hidden_states, all_decoder_attentions) = no_grad(|| {
+    /// let model_output = no_grad(|| {
     ///     bart_model
     ///         .forward_t(Some(&input_tensor),
     ///             Some(&encoder_attention_mask),
@@ -757,60 +714,51 @@ impl BartForSequenceClassification {
         &self,
         input_ids: &Tensor,
         attention_mask: Option<&Tensor>,
-        encoder_outputs: Option<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>)>,
+        encoder_output: Option<BartEncoderOutput>,
         decoder_input_ids: Option<&Tensor>,
         decoder_attention_mask: Option<&Tensor>,
         train: bool,
-    ) -> (
-        Tensor,
-        Tensor,
-        Option<Vec<Tensor>>,
-        Option<Vec<Tensor>>,
-        Option<Vec<Tensor>>,
-        Option<Vec<Tensor>>,
-    ) {
-        let (
-            decoder_outputs,
-            encoder_hidden_states,
-            _,
-            all_decoder_hidden_states,
-            all_decoder_attentions,
-            all_encoder_hidden_states,
-            all_encoder_attentions,
-        ) = self.base_model.forward_t(
+    ) -> BartModelOutput {
+        let base_model_output = self.base_model.forward_t(
             Some(input_ids),
             attention_mask,
             decoder_input_ids,
-            encoder_outputs,
+            encoder_output,
             decoder_attention_mask,
             None,
             train,
         );
         let eos_mask = input_ids.eq(self.eos_token_id);
         let reshape = eos_mask.sum1(&[1], true, Int64);
-        let sentence_representation = decoder_outputs
+        let sentence_representation = base_model_output
+            .decoder_hidden_state
             .permute(&[2, 0, 1])
             .masked_select(&eos_mask)
             .view((-1, reshape.size()[0] * reshape.int64_value(&[0, 0])))
             .transpose(0, 1)
             .view((
-                decoder_outputs.size()[0],
+                base_model_output.decoder_hidden_state.size()[0],
                 -1,
-                *decoder_outputs.size().last().unwrap(),
+                *base_model_output
+                    .decoder_hidden_state
+                    .size()
+                    .last()
+                    .unwrap(),
             ))
             .select(1, -1);

         let logits = self
             .classification_head
             .forward_t(&sentence_representation, train);
-        (
-            logits,
-            encoder_hidden_states,
-            all_decoder_hidden_states,
-            all_decoder_attentions,
-            all_encoder_hidden_states,
-            all_encoder_attentions,
-        )
+        BartModelOutput {
+            decoder_hidden_state: logits,
+            encoder_hidden_state: base_model_output.encoder_hidden_state,
+            cache: None,
+            all_decoder_hidden_states: base_model_output.all_decoder_hidden_states,
+            all_decoder_attentions: base_model_output.all_decoder_attentions,
+            all_encoder_hidden_states: base_model_output.all_encoder_hidden_states,
+            all_encoder_attentions: base_model_output.all_encoder_attentions,
+        }
     }
 }

@@ -860,9 +808,7 @@ impl LMHeadModel for BartForConditionalGeneration {
     /// let encoder_attention_mask = Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
     /// let decoder_attention_mask = Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
     ///
-    /// let (decoder_output, encoder_hidden_states, cache,
-    ///     all_encoder_hidden_states, all_encoder_attentions,
-    ///     all_decoder_hidden_states, all_decoder_attentions) = no_grad(|| {
+    /// let model_output = no_grad(|| {
     ///     bart_model
     ///         .forward_t(Some(&input_tensor),
     ///             Some(&encoder_attention_mask),
@@ -884,22 +830,17 @@ impl LMHeadModel for BartForConditionalGeneration {
         encoder_outputs: Option<&Tensor>,
         decoder_input_ids: &Option<Tensor>,
         train: bool,
-    ) -> Result<
-        (
-            Tensor,
-            Option<Tensor>,
-            Cache,
-            Option<Vec<Tensor>>,
-            Option<Vec<Tensor>>,
-        ),
-        &'static str,
-    > {
-        let (decoder_output, encoder_hidden_states, new_cache, _, _, _, _) = match cache {
+    ) -> Result<LMModelOutput, &'static str> {
+        let base_model_output = match cache {
             Cache::BARTCache(cached_layer_states) => self.base_model.forward_t(
                 input_ids.as_ref(),
                 attention_mask.as_ref(),
                 decoder_input_ids.as_ref(),
-                Some((encoder_outputs.as_ref().unwrap().copy(), None, None)),
+                Some(BartEncoderOutput {
+                    hidden_state: encoder_outputs.as_ref().unwrap().copy(),
+                    all_hidden_states: None,
+                    all_attentions: None,
+                }),
                 None,
                 cached_layer_states,
                 train,
@@ -909,21 +850,37 @@ impl LMHeadModel for BartForConditionalGeneration {
                 input_ids.as_ref(),
                 attention_mask.as_ref(),
                 decoder_input_ids.as_ref(),
-                Some((encoder_outputs.as_ref().unwrap().copy(), None, None)),
+                Some(BartEncoderOutput {
+                    hidden_state: encoder_outputs.as_ref().unwrap().copy(),
+                    all_hidden_states: None,
+                    all_attentions: None,
+                }),
                 None,
                 None,
                 train,
             ),
-            _ => Err("Cache not compatible with BART Model")?,
+            _ => return Err("Cache not compatible with BART Model"),
         };

-        let lm_logits = decoder_output.linear::<Tensor>(&self.base_model.embeddings.ws, None);
-        Ok((
+        let lm_logits = base_model_output
+            .decoder_hidden_state
+            .linear::<Tensor>(&self.base_model.embeddings.ws, None);
+        Ok(LMModelOutput {
             lm_logits,
-            Some(encoder_hidden_states),
-            Cache::BARTCache(new_cache),
-            None,
-            None,
-        ))
+            encoder_hidden_state: Some(base_model_output.encoder_hidden_state),
+            cache: Cache::BARTCache(base_model_output.cache),
+            all_hidden_states: None,
+            all_attentions: None,
+        })
     }
 }
+
+pub struct BartModelOutput {
+    pub decoder_hidden_state: Tensor,
+    pub encoder_hidden_state: Tensor,
+    pub cache: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
+    pub all_decoder_hidden_states: Option<Vec<Tensor>>,
+    pub all_decoder_attentions: Option<Vec<Tensor>>,
+    pub all_encoder_hidden_states: Option<Vec<Tensor>>,
+    pub all_encoder_attentions: Option<Vec<Tensor>>,
+}
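With `BartModelOutput` defined above, callers read named fields instead of destructuring a seven-element tuple. A sketch of updated calling code, assuming `bart_model` and `input_tensor` are set up as in the example program changed earlier in this diff:

```rust
// Sketch only: assumes `bart_model: BartModel` and `input_tensor: Tensor`
// already exist, as in the examples hunk above.
let model_output =
    no_grad(|| bart_model.forward_t(Some(&input_tensor), None, None, None, None, None, false));
// Named access replaces positional slots and `_` placeholders:
println!("{:?}", model_output.decoder_hidden_state.size());
println!("{:?}", model_output.encoder_hidden_state.size());
```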
@@ -288,15 +288,7 @@ impl BartDecoder {
         embeddings: &nn::Embedding,
         old_layer_states: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
         train: bool,
-    ) -> (
-        Tensor,
-        (
-            Option<Tensor>,
-            Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
-        ),
-        Option<Vec<Tensor>>,
-        Option<Vec<Tensor>>,
-    ) {
+    ) -> BartDecoderOutput {
         let encoder_padding_mask = match encoder_padding_mask {
             Some(mask) => Some(mask.eq(0).to_kind(Bool)),
             None => None,
@@ -342,45 +334,48 @@ impl BartDecoder {
         };
         let encoder_hidden_states = encoder_hidden_states.transpose(0, 1);
         let mut attention_weights: Option<Tensor>;
-        let mut layers = self.layers.iter().enumerate();

-        loop {
-            match layers.next() {
-                Some((layer_idx, layer)) => {
-                    let layer_state = match &next_decoder_cache {
-                        Some(values) => values[layer_idx].to_owned(),
-                        None => (None, None),
-                    };
-                    let temp = layer.forward_t(
-                        &hidden_state,
-                        &encoder_hidden_states,
-                        encoder_padding_mask.as_ref(),
-                        decoder_causal_mask,
-                        decoder_padding_mask,
-                        layer_state,
-                        train,
-                    );
-                    hidden_state = temp.0;
-                    attention_weights = temp.1;
-                    if let Some(hidden_states) = all_hidden_states.borrow_mut() {
-                        hidden_states.push(hidden_state.as_ref().copy().transpose(0, 1));
-                    };
-                    if let Some(attentions) = all_attentions.borrow_mut() {
-                        attentions.push(attention_weights.as_ref().unwrap().copy());
-                    };
-                    if let Some(value) = &mut next_decoder_cache {
-                        value[layer_idx] = temp.2
-                    };
-                }
-                None => break,
+        for (layer_idx, layer) in self.layers.iter().enumerate() {
+            let layer_state = match &next_decoder_cache {
+                Some(values) => values[layer_idx].to_owned(),
+                None => (None, None),
+            };
+            let temp = layer.forward_t(
+                &hidden_state,
+                &encoder_hidden_states,
+                encoder_padding_mask.as_ref(),
+                decoder_causal_mask,
+                decoder_padding_mask,
+                layer_state,
+                train,
+            );
+            hidden_state = temp.0;
+            attention_weights = temp.1;
+            if let Some(hidden_states) = all_hidden_states.borrow_mut() {
+                hidden_states.push(hidden_state.as_ref().copy().transpose(0, 1));
+            };
+            if let Some(attentions) = all_attentions.borrow_mut() {
+                attentions.push(attention_weights.as_ref().unwrap().copy());
+            };
+            if let Some(value) = &mut next_decoder_cache {
+                value[layer_idx] = temp.2
+            };
         }

-        (
-            hidden_state.transpose(0, 1),
-            (encoder_padding_mask, next_decoder_cache),
+        BartDecoderOutput {
+            hidden_state: hidden_state.transpose(0, 1),
+            encoder_padding_mask,
+            next_decoder_cache,
             all_hidden_states,
             all_attentions,
-        )
+        }
     }
 }

+pub struct BartDecoderOutput {
+    pub hidden_state: Tensor,
+    pub encoder_padding_mask: Option<Tensor>,
+    pub next_decoder_cache: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
+    pub all_hidden_states: Option<Vec<Tensor>>,
+    pub all_attentions: Option<Vec<Tensor>>,
+}
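The decoder above and the encoder below also swap a manual `loop { match layers.next() { ... } }` for a plain `for` loop, the form clippy's `while_let_loop` lint pushes toward. A standalone toy version of the same rewrite:

```rust
fn main() {
    let layers = vec!["self-attention", "cross-attention", "feed-forward"];

    // Before: manually driving the iterator, flagged by clippy (while_let_loop).
    let mut iter = layers.iter().enumerate();
    loop {
        match iter.next() {
            Some((idx, layer)) => println!("{}: {}", idx, layer),
            None => break,
        }
    }

    // After: the idiomatic form now used by BartDecoder and BartEncoder.
    for (idx, layer) in layers.iter().enumerate() {
        println!("{}: {}", idx, layer);
    }
}
```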
@@ -232,7 +232,7 @@ impl BartEncoder {
         attention_mask: Option<&Tensor>,
         embeddings: &nn::Embedding,
         train: bool,
-    ) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
+    ) -> BartEncoderOutput {
         let attention_mask = match attention_mask {
             Some(mask) => Some(mask.eq(0).to_kind(Bool)),
             None => None,
@@ -260,33 +260,34 @@ impl BartEncoder {

         let mut hidden_state = x.copy();
         let mut attention_weights: Option<Tensor>;
-        let mut layers = self.layers.iter();

-        loop {
-            match layers.next() {
-                Some(layer) => {
-                    if let Some(hidden_states) = all_hidden_states.borrow_mut() {
-                        hidden_states.push(hidden_state.as_ref().copy().transpose(0, 1));
-                    };
+        for layer in &self.layers {
+            if let Some(hidden_states) = all_hidden_states.borrow_mut() {
+                hidden_states.push(hidden_state.as_ref().copy().transpose(0, 1));
+            };

-                    let temp = layer.forward_t(&hidden_state, attention_mask.as_ref(), train);
-                    hidden_state = temp.0;
-                    attention_weights = temp.1;
-                    if let Some(attentions) = all_attentions.borrow_mut() {
-                        attentions.push(attention_weights.as_ref().unwrap().copy());
-                    };
-                }
-                None => break,
+            let temp = layer.forward_t(&hidden_state, attention_mask.as_ref(), train);
+            hidden_state = temp.0;
+            attention_weights = temp.1;
+            if let Some(attentions) = all_attentions.borrow_mut() {
+                attentions.push(attention_weights.as_ref().unwrap().copy());
+            };
         }

         if let Some(hidden_states) = all_hidden_states.borrow_mut() {
             hidden_states.push(hidden_state.as_ref().copy().transpose(0, 1));
         };

-        (
-            hidden_state.transpose(0, 1),
+        BartEncoderOutput {
+            hidden_state: hidden_state.transpose(0, 1),
             all_hidden_states,
             all_attentions,
-        )
+        }
     }
 }

+pub struct BartEncoderOutput {
+    pub hidden_state: Tensor,
+    pub all_hidden_states: Option<Vec<Tensor>>,
+    pub all_attentions: Option<Vec<Tensor>>,
+}
@@ -69,3 +69,5 @@ pub use bart_model::{
     BartForSequenceClassification, BartMergesResources, BartModel, BartModelResources,
     BartVocabResources,
 };
+
+pub(crate) use encoder::BartEncoderOutput;
@@ -86,7 +86,7 @@ impl BertSelfAttention {
     fn flatten(&self, x: Tensor, bs: i64, dim_per_head: i64) -> Tensor {
         x.transpose(1, 2)
             .contiguous()
-            .view((bs, -1, &self.num_attention_heads * dim_per_head))
+            .view((bs, -1, self.num_attention_heads * dim_per_head))
     }

     pub fn forward_t(
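The one-character change above drops a needless borrow: `&self.num_attention_heads * dim_per_head` takes a reference to a `Copy` integer only for the multiplication to dereference it again, which clippy reports (the `op_ref` lint). A toy reproduction:

```rust
struct Attention {
    num_attention_heads: i64,
}

fn main() {
    let attention = Attention { num_attention_heads: 12 };
    let dim_per_head: i64 = 64;
    // clippy::op_ref: the borrow adds nothing, since i64 is Copy.
    let flagged = &attention.num_attention_heads * dim_per_head;
    // The committed form:
    let clean = attention.num_attention_heads * dim_per_head;
    assert_eq!(flagged, clean);
}
```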
@@ -15,7 +15,7 @@
 use crate::common::dropout::Dropout;
 use crate::common::linear::{linear_no_bias, LinearNoBias};
 use crate::gpt2::transformer::Block;
-use crate::pipelines::generation::{Cache, LMHeadModel};
+use crate::pipelines::generation::{Cache, LMHeadModel, LMModelOutput};
 use crate::Config;
 use serde::{Deserialize, Serialize};
 use std::borrow::{Borrow, BorrowMut};
@@ -608,7 +608,7 @@ impl LMHeadModel for GPT2LMHeadModel {
     /// let position_ids = Tensor::arange(sequence_length, (Int64, device))
     ///     .expand(&[batch_size, sequence_length], true);
     ///
-    /// let (output, _, past, hidden_states, attentions) = no_grad(|| {
+    /// let model_output = no_grad(|| {
     ///     gpt2_model
     ///         .forward_t(
     ///             &Some(input_tensor),
@@ -635,16 +635,7 @@ impl LMHeadModel for GPT2LMHeadModel {
         _encoder_outputs: Option<&Tensor>,
         _decoder_input_ids: &Option<Tensor>,
         train: bool,
-    ) -> Result<
-        (
-            Tensor,
-            Option<Tensor>,
-            Cache,
-            Option<Vec<Tensor>>,
-            Option<Vec<Tensor>>,
-        ),
-        &'static str,
-    > {
+    ) -> Result<LMModelOutput, &'static str> {
         let (output, past, all_hidden_states, all_attentions) = match layer_past {
             Cache::GPT2Cache(layer_past) => Ok(self.transformer.forward_t(
                 input_ids,
@@ -668,12 +659,12 @@ impl LMHeadModel for GPT2LMHeadModel {
         }?;

         let lm_logits = output.apply(&self.lm_head);
-        Ok((
+        Ok(LMModelOutput {
             lm_logits,
-            None,
-            Cache::GPT2Cache(past),
+            encoder_hidden_state: None,
+            cache: Cache::GPT2Cache(past),
             all_hidden_states,
             all_attentions,
-        ))
+        })
     }
 }

@@ -11,8 +11,8 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

-use crate::bart::{BartConfig, BartModel, LayerState};
-use crate::pipelines::generation::{Cache, LMHeadModel};
+use crate::bart::{BartConfig, BartEncoderOutput, BartModel, LayerState};
+use crate::pipelines::generation::{Cache, LMHeadModel, LMModelOutput};
 use std::borrow::Borrow;
 use tch::nn::Init;
 use tch::{nn, Tensor};
@@ -349,7 +349,7 @@ impl MarianForConditionalGeneration {
         &self,
         input_ids: Option<&Tensor>,
         attention_mask: Option<&Tensor>,
-        encoder_outputs: Option<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>)>,
+        encoder_outputs: Option<BartEncoderOutput>,
         decoder_input_ids: Option<&Tensor>,
         decoder_attention_mask: Option<&Tensor>,
         old_layer_states: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
@@ -363,15 +363,7 @@ impl MarianForConditionalGeneration {
         Option<Vec<Tensor>>,
         Option<Vec<Tensor>>,
     ) {
-        let (
-            decoder_outputs,
-            encoder_hidden_states,
-            decoder_cache,
-            all_decoder_hidden_states,
-            all_decoder_attentions,
-            all_encoder_hidden_states,
-            all_encoder_attentions,
-        ) = self.base_model.forward_t(
+        let base_model_output = self.base_model.forward_t(
             input_ids,
             attention_mask,
             decoder_input_ids,
@@ -381,25 +373,31 @@ impl MarianForConditionalGeneration {
             train,
         );

-        let lm_logits = decoder_outputs.linear::<Tensor>(&self.base_model.embeddings.ws, None);
+        let lm_logits = base_model_output
+            .decoder_hidden_state
+            .linear::<Tensor>(&self.base_model.embeddings.ws, None);
         (
             lm_logits,
-            encoder_hidden_states,
-            decoder_cache,
-            all_decoder_hidden_states,
-            all_decoder_attentions,
-            all_encoder_hidden_states,
-            all_encoder_attentions,
+            base_model_output.encoder_hidden_state,
+            base_model_output.cache,
+            base_model_output.all_decoder_hidden_states,
+            base_model_output.all_decoder_attentions,
+            base_model_output.all_encoder_hidden_states,
+            base_model_output.all_encoder_attentions,
         )
     }

     pub fn encode(&self, input_ids: &Tensor, attention_mask: Option<&Tensor>) -> Tensor {
-        let (encoder_hidden_states, _, _) = self.base_model.encoder.forward_t(
-            input_ids,
-            attention_mask,
-            &self.base_model.embeddings,
-            false,
-        );
+        let encoder_hidden_states = self
+            .base_model
+            .encoder
+            .forward_t(
+                input_ids,
+                attention_mask,
+                &self.base_model.embeddings,
+                false,
+            )
+            .hidden_state;
         encoder_hidden_states
     }
 }
@@ -483,22 +481,17 @@ impl LMHeadModel for MarianForConditionalGeneration {
         encoder_outputs: Option<&Tensor>,
         decoder_input_ids: &Option<Tensor>,
         train: bool,
-    ) -> Result<
-        (
-            Tensor,
-            Option<Tensor>,
-            Cache,
-            Option<Vec<Tensor>>,
-            Option<Vec<Tensor>>,
-        ),
-        &'static str,
-    > {
-        let (decoder_output, encoder_hidden_states, new_cache, _, _, _, _) = match cache {
+    ) -> Result<LMModelOutput, &'static str> {
+        let base_model_output = match cache {
             Cache::BARTCache(cached_layer_states) => self.base_model.forward_t(
                 input_ids.as_ref(),
                 attention_mask.as_ref(),
                 decoder_input_ids.as_ref(),
-                Some((encoder_outputs.as_ref().unwrap().copy(), None, None)),
+                Some(BartEncoderOutput {
+                    hidden_state: encoder_outputs.as_ref().unwrap().copy(),
+                    all_hidden_states: None,
+                    all_attentions: None,
+                }),
                 None,
                 cached_layer_states,
                 train,
@@ -507,7 +500,11 @@ impl LMHeadModel for MarianForConditionalGeneration {
                 input_ids.as_ref(),
                 attention_mask.as_ref(),
                 decoder_input_ids.as_ref(),
-                Some((encoder_outputs.as_ref().unwrap().copy(), None, None)),
+                Some(BartEncoderOutput {
+                    hidden_state: encoder_outputs.as_ref().unwrap().copy(),
+                    all_hidden_states: None,
+                    all_attentions: None,
+                }),
                 None,
                 None,
                 train,
@@ -515,14 +512,16 @@ impl LMHeadModel for MarianForConditionalGeneration {
             _ => Err("Cache not compatible with Marian Model")?,
         };

-        let lm_logits = decoder_output.linear::<Tensor>(&self.base_model.embeddings.ws, None)
+        let lm_logits = base_model_output
+            .decoder_hidden_state
+            .linear::<Tensor>(&self.base_model.embeddings.ws, None)
             + &self.final_logits_bias;
-        Ok((
+        Ok(LMModelOutput {
             lm_logits,
-            Some(encoder_hidden_states),
-            Cache::BARTCache(new_cache),
-            None,
-            None,
-        ))
+            encoder_hidden_state: Some(base_model_output.encoder_hidden_state),
+            cache: Cache::BARTCache(base_model_output.cache),
+            all_hidden_states: None,
+            all_attentions: None,
+        })
     }
 }

@@ -16,7 +16,7 @@ use crate::common::dropout::Dropout;
 use crate::common::linear::{linear_no_bias, LinearNoBias};
 use crate::gpt2::Gpt2Config;
 use crate::openai_gpt::transformer::Block;
-use crate::pipelines::generation::{Cache, LMHeadModel};
+use crate::pipelines::generation::{Cache, LMHeadModel, LMModelOutput};
 use std::borrow::{Borrow, BorrowMut};
 use tch::kind::Kind::Int64;
 use tch::nn::embedding;
@@ -388,7 +388,7 @@ impl LMHeadModel for OpenAIGPTLMHeadModel {
     /// let token_type_ids = Tensor::ones(&[batch_size, sequence_length], (Int64, device));
     /// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true);
     ///
-    /// let (output, _, _, hidden_states, attentions) = no_grad(|| {
+    /// let model_output = no_grad(|| {
     ///     gpt_model
     ///         .forward_t(&Some(input_tensor),
     ///             Cache::None,
@@ -412,16 +412,7 @@ impl LMHeadModel for OpenAIGPTLMHeadModel {
         _encoder_outputs: Option<&Tensor>,
         _decoder_input_ids: &Option<Tensor>,
         train: bool,
-    ) -> Result<
-        (
-            Tensor,
-            Option<Tensor>,
-            Cache,
-            Option<Vec<Tensor>>,
-            Option<Vec<Tensor>>,
-        ),
-        &'static str,
-    > {
+    ) -> Result<LMModelOutput, &'static str> {
         let (output, all_hidden_states, all_attentions) = self.transformer.forward_t(
             input_ids,
             attention_mask,
@@ -432,12 +423,12 @@ impl LMHeadModel for OpenAIGPTLMHeadModel {
         )?;

         let lm_logits = output.apply(&self.lm_head);
-        Ok((
+        Ok(LMModelOutput {
             lm_logits,
-            None,
-            Cache::None,
+            encoder_hidden_state: None,
+            cache: Cache::None,
             all_hidden_states,
             all_attentions,
-        ))
+        })
     }
 }

@@ -1648,8 +1648,8 @@ pub(crate) mod private_generation_utils {
                     false,
                 )
                 .unwrap();
-            outputs = temp.0;
-            past = temp.2;
+            outputs = temp.lm_logits;
+            past = temp.cache;

             let mut next_token_logits = outputs.select(1, -1);
             // Reduce probability for repeated inputs
@@ -1854,8 +1854,8 @@ pub(crate) mod private_generation_utils {
                     false,
                 )
                 .unwrap();
-            outputs = temp.0;
-            past = temp.2;
+            outputs = temp.lm_logits;
+            past = temp.cache;

             let mut next_token_logits = outputs.select(1, -1);

@@ -2575,7 +2575,7 @@ pub trait LMHeadModel {
     /// let position_ids = Tensor::arange(sequence_length, (Int64, device))
     ///     .expand(&[batch_size, sequence_length], true);
     ///
-    /// let (output, encoder_output, past, hidden_states, attentions) = no_grad(|| {
+    /// let model_output = no_grad(|| {
     ///     gpt2_model
     ///         .forward_t(
     ///             &Some(input_tensor),
@@ -2602,14 +2602,13 @@ pub trait LMHeadModel {
         encoder_outputs: Option<&Tensor>,
         decoder_input_ids: &Option<Tensor>,
         train: bool,
-    ) -> Result<
-        (
-            Tensor,
-            Option<Tensor>,
-            Cache,
-            Option<Vec<Tensor>>,
-            Option<Vec<Tensor>>,
-        ),
-        &'static str,
-    >;
+    ) -> Result<LMModelOutput, &'static str>;
 }
+
+pub struct LMModelOutput {
+    pub lm_logits: Tensor,
+    pub encoder_hidden_state: Option<Tensor>,
+    pub cache: Cache,
+    pub all_hidden_states: Option<Vec<Tensor>>,
+    pub all_attentions: Option<Vec<Tensor>>,
+}
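`LMModelOutput` is the trait-level version of the commit's central refactoring: one named struct instead of a five-element tuple inside a `Result`. A self-contained before/after sketch of the pattern (toy types, not the crate's API):

```rust
// Before: a wide tuple return forces positional destructuring at every
// call site and trips clippy's type-complexity warnings.
fn forward_tuple() -> Result<(Vec<f32>, Option<Vec<f32>>, Option<Vec<f32>>), &'static str> {
    Ok((vec![0.1, 0.9], None, None))
}

// After: a named output struct, mirroring LMModelOutput.
struct ModelOutput {
    lm_logits: Vec<f32>,
    all_hidden_states: Option<Vec<f32>>,
    all_attentions: Option<Vec<f32>>,
}

fn forward_struct() -> Result<ModelOutput, &'static str> {
    Ok(ModelOutput {
        lm_logits: vec![0.1, 0.9],
        all_hidden_states: None,
        all_attentions: None,
    })
}

fn main() -> Result<(), &'static str> {
    let (logits, _, _) = forward_tuple()?; // placeholders for unused slots
    let output = forward_struct()?; // fields are self-describing
    assert_eq!(logits, output.lm_logits);
    Ok(())
}
```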
@@ -301,7 +301,7 @@ impl SequenceClassificationOption {
                     None,
                     train,
                 )
-                .0
+                .decoder_hidden_state
             }
             Self::Bert(ref model) => {
                 model
@@ -332,7 +332,7 @@ impl ZeroShotClassificationOption {
                     None,
                     train,
                 )
-                .0
+                .decoder_hidden_state
             }
             Self::Bert(ref model) => {
                 model
@@ -9,7 +9,7 @@
 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 // See the License for the specific language governing permissions and
 // limitations under the License.
-use crate::pipelines::generation::{Cache, LMHeadModel};
+use crate::pipelines::generation::{Cache, LMHeadModel, LMModelOutput};
 use crate::t5::attention::LayerState;
 use crate::t5::encoder::T5Stack;
 use crate::Config;
@@ -684,16 +684,7 @@ impl LMHeadModel for T5ForConditionalGeneration {
         encoder_outputs: Option<&Tensor>,
         decoder_input_ids: &Option<Tensor>,
         train: bool,
-    ) -> Result<
-        (
-            Tensor,
-            Option<Tensor>,
-            Cache,
-            Option<Vec<Tensor>>,
-            Option<Vec<Tensor>>,
-        ),
-        &'static str,
-    > {
+    ) -> Result<LMModelOutput, &'static str> {
         let (decoder_output, encoder_hidden_states, new_cache, _, _, _, _) = match cache {
             Cache::T5Cache(cached_layer_states) => self.base_model.forward_t(
                 input_ids.as_ref(),
@@ -723,12 +714,12 @@ impl LMHeadModel for T5ForConditionalGeneration {
         let lm_logits = decoder_output.linear::<Tensor>(&self.base_model.embeddings.ws, None)
             * (self.model_dim.powf(-0.5));

-        Ok((
+        Ok(LMModelOutput {
             lm_logits,
-            Some(encoder_hidden_states),
-            Cache::T5Cache(new_cache),
-            None,
-            None,
-        ))
+            encoder_hidden_state: Some(encoder_hidden_states),
+            cache: Cache::T5Cache(new_cache),
+            all_hidden_states: None,
+            all_attentions: None,
+        })
     }
 }

@@ -62,12 +62,12 @@ fn bart_lm_model() -> anyhow::Result<()> {
     let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);

     // Forward pass
-    let (output, encoder_outputs, _, _, _, _, _) =
+    let model_output =
         bart_model.forward_t(Some(&input_tensor), None, None, None, None, None, false);

-    assert_eq!(output.size(), vec!(1, 6, 1024));
-    assert_eq!(encoder_outputs.size(), vec!(1, 6, 1024));
-    assert!((output.double_value(&[0, 0, 0]) - 0.7877).abs() < 1e-4);
+    assert_eq!(model_output.decoder_hidden_state.size(), vec!(1, 6, 1024));
+    assert_eq!(model_output.encoder_hidden_state.size(), vec!(1, 6, 1024));
+    assert!((model_output.decoder_hidden_state.double_value(&[0, 0, 0]) - 0.7877).abs() < 1e-4);
     Ok(())
 }

@@ -61,7 +61,7 @@ fn distilgpt2_lm_model() -> anyhow::Result<()> {
     let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);

     // Forward pass
-    let (output, _, past, _, _) = gpt2_model
+    let model_output = gpt2_model
         .forward_t(
             &Some(input_tensor),
             Cache::None,
@@ -75,11 +75,16 @@ fn distilgpt2_lm_model() -> anyhow::Result<()> {
         )
         .unwrap();

-    let next_word_id = output.get(0).get(-1).argmax(-1, true).int64_value(&[0]);
+    let next_word_id = model_output
+        .lm_logits
+        .get(0)
+        .get(-1)
+        .argmax(-1, true)
+        .int64_value(&[0]);
     let next_word = tokenizer.decode(vec![next_word_id], true, true);

-    assert_eq!(output.size(), vec!(1, 11, 50257));
-    match past {
+    assert_eq!(model_output.lm_logits.size(), vec!(1, 11, 50257));
+    match model_output.cache {
         Cache::GPT2Cache(past) => {
             assert!(past.is_some());
             assert_eq!(past.as_ref().unwrap().len(), config.n_layer as usize);
@@ -91,7 +96,13 @@ fn distilgpt2_lm_model() -> anyhow::Result<()> {
         _ => panic!("Wrong cache returned for GPT2"),
     }
     assert!(
-        (output.double_value(&[0, output.size()[1] - 1, next_word_id]) - (-48.7065)).abs() < 1e-4
+        (model_output.lm_logits.double_value(&[
+            0,
+            model_output.lm_logits.size()[1] - 1,
+            next_word_id
+        ]) - (-48.7065))
+        .abs()
+            < 1e-4
     );
     assert_eq!(next_word_id, 14104i64);
     assert_eq!(next_word, String::from(" twelve"));
@@ -62,7 +62,7 @@ fn gpt2_lm_model() -> anyhow::Result<()> {
     let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);

     // Forward pass
-    let (output, _, past, _, _) = gpt2_model
+    let model_output = gpt2_model
         .forward_t(
             &Some(input_tensor),
             Cache::None,
@@ -76,11 +76,16 @@ fn gpt2_lm_model() -> anyhow::Result<()> {
         )
         .unwrap();

-    let next_word_id = output.get(0).get(-1).argmax(-1, true).int64_value(&[0]);
+    let next_word_id = model_output
+        .lm_logits
+        .get(0)
+        .get(-1)
+        .argmax(-1, true)
+        .int64_value(&[0]);
     let next_word = tokenizer.decode(vec![next_word_id], true, true);

-    assert_eq!(output.size(), vec!(1, 4, 50257));
-    match past {
+    assert_eq!(model_output.lm_logits.size(), vec!(1, 4, 50257));
+    match model_output.cache {
         Cache::GPT2Cache(past) => {
             assert!(past.is_some());
             assert_eq!(past.as_ref().unwrap().len(), config.n_layer as usize);
@@ -92,7 +97,13 @@ fn gpt2_lm_model() -> anyhow::Result<()> {
         _ => panic!("Wrong cache returned for GPT2"),
     }
    assert!(
-        (output.double_value(&[0, output.size()[1] - 1, next_word_id]) - (-69.4948)).abs() < 1e-4
+        (model_output.lm_logits.double_value(&[
+            0,
+            model_output.lm_logits.size()[1] - 1,
+            next_word_id
+        ]) - (-69.4948))
+        .abs()
+            < 1e-4
     );
     assert_eq!(next_word_id, 1936i64);
     assert_eq!(next_word, String::from(" five"));
@@ -64,7 +64,7 @@ fn openai_gpt_lm_model() -> anyhow::Result<()> {
     let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);

     // Forward pass
-    let (output, _, _, _, _) = openai_gpt
+    let model_output = openai_gpt
         .forward_t(
             &Some(input_tensor),
             Cache::None,
@@ -78,12 +78,23 @@ fn openai_gpt_lm_model() -> anyhow::Result<()> {
         )
         .unwrap();

-    let next_word_id = output.get(0).get(-1).argmax(-1, true).int64_value(&[0]);
+    let next_word_id = model_output
+        .lm_logits
+        .get(0)
+        .get(-1)
+        .argmax(-1, true)
+        .int64_value(&[0]);
     let next_word = tokenizer.decode(vec![next_word_id], true, true);

-    assert_eq!(output.size(), vec!(1, 6, 40478));
+    assert_eq!(model_output.lm_logits.size(), vec!(1, 6, 40478));
     assert!(
-        (output.double_value(&[0, output.size()[1] - 1, next_word_id]) - (9.1056)).abs() < 1e-4
+        (model_output.lm_logits.double_value(&[
+            0,
+            model_output.lm_logits.size()[1] - 1,
+            next_word_id
+        ]) - (9.1056))
+        .abs()
+            < 1e-4
     );
     assert_eq!(next_word_id, 580i64);
     assert_eq!(next_word, String::from("be"));