Implementation of T5ForConditionalGeneration

Guillaume B 2020-07-06 17:43:27 +02:00
parent 857b4bf7a5
commit 7938abc76d
3 changed files with 89 additions and 3 deletions

View File

@@ -13,7 +13,10 @@
extern crate failure;
use rust_bert::resources::{download_resource, RemoteResource, Resource};
use rust_bert::t5::{T5Config, T5ConfigResources, T5Model, T5ModelResources, T5VocabResources};
use rust_bert::t5::{
    T5Config, T5ConfigResources, T5ForConditionalGeneration, T5Model, T5ModelResources,
    T5VocabResources,
};
use rust_bert::Config;
use rust_tokenizers::preprocessing::tokenizer::t5_tokenizer::T5Tokenizer;
use rust_tokenizers::{Tokenizer, TruncationStrategy};
@@ -37,7 +40,7 @@ fn main() -> failure::Fallible<()> {
    let tokenizer: T5Tokenizer = T5Tokenizer::from_file(vocab_path.to_str().unwrap(), false);
    let config = T5Config::from_file(config_path);
    let t5_model = T5Model::new(&vs.root(), &config, false, false);
    let t5_model = T5ForConditionalGeneration::new(&vs.root(), &config, false, false);
    vs.load(weights_path)?;
    // Define input

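The hunk ends before the input is actually defined and fed to the model. As a hedged sketch only (not part of this commit), the forward pass added below in src/t5/t5.rs could be exercised along these lines, assuming `input_tensor` and `decoder_input_tensor` are `[batch, seq_len]` tensors of `i64` token ids produced from `tokenizer` and placed on the same device as the model:

    // Hedged sketch, not from the commit: one forward pass through the new head.
    // `input_tensor` and `decoder_input_tensor` are assumed i64 token-id tensors.
    let (lm_logits, _, _, _, _, _, _) = t5_model.forward_t(
        Some(&input_tensor),         // encoder input ids
        None,                        // encoder attention mask
        None,                        // pre-computed encoder outputs
        Some(&decoder_input_tensor), // decoder input ids
        None,                        // decoder attention mask
        None,                        // input embeddings
        None,                        // decoder input embeddings
        None,                        // cached layer states
        false,                       // train
    );
    // `lm_logits` has shape [batch, decoder_seq_len, vocab_size].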
View File

@@ -3,4 +3,7 @@ mod encoder;
mod layer_norm;
mod t5;
pub use t5::{T5Config, T5ConfigResources, T5Model, T5ModelResources, T5VocabResources};
pub use t5::{
    T5Config, T5ConfigResources, T5ForConditionalGeneration, T5Model, T5ModelResources,
    T5VocabResources,
};

View File

@@ -265,3 +265,83 @@ impl T5Model {
        )
    }
}

pub struct T5ForConditionalGeneration {
    base_model: T5Model,
    model_dim: f64,
}

impl T5ForConditionalGeneration {
    pub fn new<'p, P>(
        p: P,
        config: &T5Config,
        output_attentions: bool,
        output_hidden_states: bool,
    ) -> T5ForConditionalGeneration
    where
        P: Borrow<nn::Path<'p>>,
    {
        let p = p.borrow();
        let base_model = T5Model::new(p, config, output_attentions, output_hidden_states);
        T5ForConditionalGeneration {
            base_model,
            model_dim: config.d_model as f64,
        }
    }

    pub fn forward_t(
        &self,
        input_ids: Option<&Tensor>,
        attention_mask: Option<&Tensor>,
        encoder_outputs: Option<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>)>,
        decoder_input_ids: Option<&Tensor>,
        decoder_attention_mask: Option<&Tensor>,
        input_embeds: Option<Tensor>,
        decoder_input_embeds: Option<Tensor>,
        old_layer_states: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
        train: bool,
    ) -> (
        Tensor,
        Tensor,
        Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
        Option<Vec<Tensor>>,
        Option<Vec<Tensor>>,
        Option<Vec<Tensor>>,
        Option<Vec<Tensor>>,
    ) {
        let (
            decoder_outputs,
            encoder_hidden_states,
            decoder_cache,
            all_decoder_hidden_states,
            all_decoder_attentions,
            all_encoder_hidden_states,
            all_encoder_attentions,
        ) = self.base_model.forward_t(
            input_ids,
            attention_mask,
            encoder_outputs,
            decoder_input_ids,
            decoder_attention_mask,
            input_embeds,
            decoder_input_embeds,
            old_layer_states,
            train,
        );
        let lm_logits = decoder_outputs.linear::<Tensor>(&self.base_model.embeddings.ws, None)
            * (self.model_dim.powf(-0.5));
        (
            lm_logits,
            encoder_hidden_states,
            decoder_cache,
            all_decoder_hidden_states,
            all_decoder_attentions,
            all_encoder_hidden_states,
            all_encoder_attentions,
        )
    }
}
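The new head is thin: `forward_t` delegates to `T5Model`, projects the decoder output back through the shared embedding weights (`embeddings.ws`) and rescales by `d_model^-0.5`, mirroring the rescaling used by the reference T5 implementation when embeddings are tied. A minimal greedy decoding loop can therefore be written against nothing but the signature above. The following is an illustrative sketch, not code from this commit; `model`, `input_ids`, `decoder_start_token_id`, `eos_token_id` and `max_length` are assumed to be provided by the caller, and the encoder is re-run on every step for simplicity (reusing `encoder_outputs` and the returned `decoder_cache` would avoid that):

    // Hedged sketch, not from the commit: greedy decoding built only on forward_t above.
    // Assumes `use tch::Tensor;`, a `model: T5ForConditionalGeneration`, an encoder input
    // `input_ids`, and i64 values `decoder_start_token_id`, `eos_token_id`, `max_length`.
    let mut decoder_input_ids = Tensor::of_slice(&[decoder_start_token_id]).unsqueeze(0);
    for _ in 0..max_length {
        let (lm_logits, _, _, _, _, _, _) = model.forward_t(
            Some(&input_ids),
            None,
            None,
            Some(&decoder_input_ids),
            None,
            None,
            None,
            None,
            false,
        );
        // Pick the highest-scoring token at the last decoder position.
        let last_position = lm_logits.size()[1] - 1;
        let next_token = lm_logits.select(1, last_position).argmax(-1, false);
        let next_token_id = next_token.int64_value(&[0]);
        decoder_input_ids = Tensor::cat(&[decoder_input_ids, next_token.unsqueeze(0)], 1);
        if next_token_id == eos_token_id {
            break;
        }
    }

A full generation loop would also handle batching and cache reuse, but the shape of the call is the same.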