Merge pull request #49 from guillaume-be/non_mutable_generation

Updated generation pipeline to be non-mutable, reworked BART caching mechanism
guillaume-be 2020-06-06 10:54:21 +02:00 committed by GitHub
commit a459fa1b17
21 changed files with 431 additions and 355 deletions
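The central change in this diff is that the per-step decoder cache is no longer held as mutable state inside the model: it is passed into forward_t as a Cache value and returned next to the logits, so the generation pipeline can borrow the model immutably. Below is a minimal sketch of a single decoding step against the new LMHeadModel signature shown in this diff; the helper name decode_step_sketch and the surrounding usage are illustrative only and not part of the PR.

// Illustrative sketch (not part of the diff): one decoding step with the reworked,
// non-mutable API. The cache is threaded through as a value rather than stored on
// the model. `decode_step_sketch` is a hypothetical helper name.
use rust_bert::pipelines::generation::{Cache, LMHeadModel};
use tch::Tensor;

fn decode_step_sketch<M: LMHeadModel>(model: &M, input_ids: Tensor, cache: Cache)
                                      -> (Tensor, Cache) {
    // `forward_t` now takes the previous cache by value and returns the updated
    // cache together with the logits; no `&mut self` is required anymore.
    let (lm_logits, _encoder_hidden_states, new_cache, _hidden_states, _attentions) =
        model.forward_t(&Some(input_ids),
                        cache,        // Cache::None on the first step
                        &None,        // attention mask
                        &None,        // token type ids
                        &None,        // position ids
                        &None,        // input embeds
                        None,         // encoder outputs (decoder-only models)
                        &None,        // decoder input ids
                        false).unwrap();
    // The caller picks the next token from `lm_logits` and feeds `new_cache`
    // (e.g. Cache::GPT2Cache(..) or Cache::BARTCache(..)) back into the next call.
    (lm_logits, new_cache)
}

Beam search follows the same pattern: reorder_cache now receives &mut Cache and reorders the cached LayerState tensors along the beam dimension, instead of mutating state stored inside the attention layers.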

View File

@ -35,7 +35,7 @@ fn main() -> failure::Fallible<()> {
let mut vs = nn::VarStore::new(device);
let tokenizer = RobertaTokenizer::from_file(vocab_path.to_str().unwrap(), merges_path.to_str().unwrap(), false);
let config = BartConfig::from_file(config_path);
let mut bart_model = BartModel::new(&vs.root(), &config, false);
let bart_model = BartModel::new(&vs.root(), &config, false);
vs.load(weights_path)?;
// Define input
@ -85,6 +85,7 @@ about exoplanets like K2-18b."];
None,
None,
None,
None,
false)
});

View File

@ -26,7 +26,7 @@ fn main() -> failure::Fallible<()> {
num_return_sequences: 3,
..Default::default()
};
let mut model = GPT2Generator::new(generate_config)?;
let model = GPT2Generator::new(generate_config)?;
let input_context = "The dog";
let second_input_context = "The cat was";

View File

@ -15,7 +15,7 @@ extern crate failure;
use tch::{Device, nn, Tensor};
use rust_tokenizers::{TruncationStrategy, Tokenizer, Gpt2Tokenizer};
use rust_bert::gpt2::{Gpt2Config, GPT2LMHeadModel, Gpt2ConfigResources, Gpt2VocabResources, Gpt2MergesResources, Gpt2ModelResources};
use rust_bert::pipelines::generation::LMHeadModel;
use rust_bert::pipelines::generation::{LMHeadModel, Cache};
use rust_bert::resources::{Resource, download_resource, RemoteResource};
use rust_bert::Config;
@ -36,7 +36,7 @@ fn main() -> failure::Fallible<()> {
let mut vs = nn::VarStore::new(device);
let tokenizer: Gpt2Tokenizer = Gpt2Tokenizer::from_file(vocab_path.to_str().unwrap(), merges_path.to_str().unwrap(), false);
let config = Gpt2Config::from_file(config_path);
let mut gpt2_model = GPT2LMHeadModel::new(&vs.root(), &config);
let gpt2_model = GPT2LMHeadModel::new(&vs.root(), &config);
vs.load(weights_path)?;
// Define input
@ -58,7 +58,7 @@ fn main() -> failure::Fallible<()> {
// Forward pass
let (output, _, _, _, _) = gpt2_model.forward_t(
&Some(input_tensor),
&None,
Cache::None,
&None,
&None,
&None,

View File

@ -16,7 +16,7 @@ use tch::{Device, nn, Tensor};
use rust_tokenizers::{TruncationStrategy, Tokenizer, OpenAiGptTokenizer};
use rust_bert::gpt2::Gpt2Config;
use rust_bert::openai_gpt::{OpenAIGPTLMHeadModel, OpenAiGptConfigResources, OpenAiGptVocabResources, OpenAiGptMergesResources, OpenAiGptModelResources};
use rust_bert::pipelines::generation::LMHeadModel;
use rust_bert::pipelines::generation::{LMHeadModel, Cache};
use rust_bert::resources::{Resource, download_resource, RemoteResource};
use rust_bert::Config;
@ -37,7 +37,7 @@ fn main() -> failure::Fallible<()> {
let mut vs = nn::VarStore::new(device);
let tokenizer = OpenAiGptTokenizer::from_file(vocab_path.to_str().unwrap(), merges_path.to_str().unwrap(), true);
let config = Gpt2Config::from_file(config_path);
let mut openai_gpt = OpenAIGPTLMHeadModel::new(&vs.root(), &config);
let openai_gpt = OpenAIGPTLMHeadModel::new(&vs.root(), &config);
vs.load(weights_path)?;
// Define input
@ -59,7 +59,7 @@ fn main() -> failure::Fallible<()> {
// Forward pass
let (output, _, _, _, _) = openai_gpt.forward_t(
&Some(input_tensor),
&None,
Cache::None,
&None,
&None,
&None,

View File

@ -16,7 +16,7 @@ use rust_bert::pipelines::summarization::SummarizationModel;
fn main() -> failure::Fallible<()> {
let mut summarization_model = SummarizationModel::new(Default::default())?;
let summarization_model = SummarizationModel::new(Default::default())?;
let input = ["In findings published Tuesday in Cornell University's arXiv by a team of scientists \
from the University of Montreal and a separate report published Wednesday in Nature Astronomy by a team \
@ -41,9 +41,10 @@ telescope — scheduled for launch in 2021 — and the European Space Agency's 2
about exoplanets like K2-18b."];
// Credits: WikiNews, CC BY 2.5 license (https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b)
let output = summarization_model.summarize(&input);
for sentence in output {
let _output = summarization_model.summarize(&input);
for sentence in _output {
println!("{:?}", sentence);
}
};
Ok(())
}

View File

@ -18,9 +18,8 @@ use tch::Device;
fn main() -> failure::Fallible<()> {
let translation_config = TranslationConfig::new(Language::EnglishToGerman, Device::cuda_if_available());
let mut model = TranslationModel::new(translation_config)?;
let model = TranslationModel::new(translation_config)?;
let input_context_1 = "The quick brown fox jumps over the lazy dog";
let input_context_2 = "The dog did not wake up";

View File

@ -20,31 +20,35 @@ use tch::kind::Kind::Float;
/// Stores the cached value of key, value and key padding mask to avoid recalculation (e.g. at each generation step)
pub struct LayerState {
/// Cached keys
pub prev_key: Option<Tensor>,
pub prev_key: Tensor,
/// Cached values
pub prev_value: Option<Tensor>,
pub prev_value: Tensor,
/// Cached keys padding mask
pub prev_key_padding_mask: Option<Tensor>,
}
impl Clone for LayerState {
fn clone(&self) -> Self {
let prev_key_padding_mask = match &self.prev_key_padding_mask {
Some(key_padding_mask) => Some(key_padding_mask.copy()),
None => None
};
LayerState {
prev_key: self.prev_key.copy(),
prev_value: self.prev_value.copy(),
prev_key_padding_mask,
}
}
}
impl LayerState {
pub(crate) fn reorder_cache(&mut self, new_indices: &Tensor) {
if self.prev_key.is_some() {
self.prev_key = Some(self.prev_key.as_ref().unwrap().index_select(0, new_indices));
}
if self.prev_value.is_some() {
self.prev_value = Some(self.prev_value.as_ref().unwrap().index_select(0, new_indices));
}
self.prev_key = self.prev_key.index_select(0, new_indices);
self.prev_value = self.prev_value.index_select(0, new_indices);
if self.prev_key_padding_mask.is_some() {
self.prev_key_padding_mask = Some(self.prev_key_padding_mask.as_ref().unwrap().index_select(0, new_indices));
}
}
pub(crate) fn reset_cache(&mut self) {
self.prev_key = None;
self.prev_value = None;
self.prev_key_padding_mask = None;
}
}
@ -56,11 +60,11 @@ pub struct SelfAttention {
scaling: f64,
encoder_decoder_attention: bool,
output_attentions: bool,
pub(crate) prev_state: Option<LayerState>,
k_proj: nn::Linear,
v_proj: nn::Linear,
q_proj: nn::Linear,
out_proj: nn::Linear,
store_cache: bool,
}
impl SelfAttention {
@ -74,11 +78,6 @@ impl SelfAttention {
let head_dim = embed_dim / num_heads;
let scaling = (head_dim as f64).powf(-0.5);
let dropout = Dropout::new(dropout);
let prev_state = if store_cache {
Some(LayerState { prev_key: None, prev_value: None, prev_key_padding_mask: None })
} else {
None
};
SelfAttention {
num_heads,
@ -87,11 +86,11 @@ impl SelfAttention {
scaling,
encoder_decoder_attention,
output_attentions,
prev_state,
k_proj,
v_proj,
q_proj,
out_proj,
store_cache,
}
}
@ -99,22 +98,17 @@ impl SelfAttention {
x.contiguous().view((dim_0, bs * self.num_heads, self.head_dim)).transpose(0, 1)
}
pub fn forward_t(&mut self, query: &Tensor,
pub fn forward_t(&self, query: &Tensor,
key: Option<&Tensor>,
key_padding_mask: Option<&Tensor>,
attention_mask: Option<&Tensor>,
train: bool) -> (Tensor, Option<Tensor>) {
mut layer_state: Option<LayerState>,
train: bool) -> (Tensor, Option<Tensor>, Option<LayerState>) {
let query_size = query.size();
let (target_sequence_length, bs) = (query_size[0], query_size[1]);
let q: Tensor = self.flatten(query.as_ref().apply(&self.q_proj) * self.scaling, target_sequence_length, bs);
let key = match &self.prev_state {
Some(prev_state) => {
if prev_state.prev_key.is_some() & self.encoder_decoder_attention {
None
} else {
key
}
}
let key = match &layer_state {
Some(_) => { if self.encoder_decoder_attention { None } else { key } }
None => key
};
@ -133,19 +127,8 @@ impl SelfAttention {
)
};
let (k, v, key_padding_mask) = self.use_saved_state(k, v, key_padding_mask, bs);
let (k, v, key_padding_mask) = self.use_saved_state(&layer_state, k, v, key_padding_mask, bs);
self.prev_state = match &self.prev_state {
Some(_) => Some(LayerState {
prev_key: Some(k.view((bs, self.num_heads, -1, self.head_dim))),
prev_value: Some(v.view((bs, self.num_heads, -1, self.head_dim))),
prev_key_padding_mask: match key_padding_mask.as_ref() {
Some(tensor) => Some(tensor.copy()),
None => None
},
}),
None => None
};
let source_sequence_length = k.size()[1];
let attention_weights = q.bmm(&k.transpose(1, 2));
let attention_weights = match attention_mask {
@ -179,35 +162,51 @@ impl SelfAttention {
Some(attention_weights.view((bs, self.num_heads, target_sequence_length, source_sequence_length)))
} else { None };
(output, attention_weights)
if self.store_cache {
if layer_state.is_some() {
layer_state.as_mut().unwrap().prev_key = k.view((bs, self.num_heads, -1, self.head_dim));
layer_state.as_mut().unwrap().prev_value = v.view((bs, self.num_heads, -1, self.head_dim));
layer_state.as_mut().unwrap().prev_key_padding_mask = match key_padding_mask {
Some(tensor) => Some(tensor),
None => None
};
} else {
layer_state = Some(LayerState {
prev_key: k.view((bs, self.num_heads, -1, self.head_dim)),
prev_value: v.view((bs, self.num_heads, -1, self.head_dim)),
prev_key_padding_mask: match key_padding_mask {
Some(tensor) => Some(tensor),
None => None
},
})
};
};
(output, attention_weights, layer_state)
}
fn use_saved_state(&self, k: Option<Tensor>, v: Option<Tensor>, key_padding_mask: Option<&Tensor>, bs: i64)
fn use_saved_state(&self,
layer_state: &Option<LayerState>,
k: Option<Tensor>,
v: Option<Tensor>,
key_padding_mask: Option<&Tensor>,
bs: i64)
-> (Tensor, Tensor, Option<Tensor>) {
match &self.prev_state {
match &layer_state {
Some(prev_state) => {
let k = match &prev_state.prev_key {
Some(prev_key) => {
let prev_key = prev_key.view((bs * self.num_heads, -1, self.head_dim));
if self.encoder_decoder_attention {
prev_key
} else {
Tensor::cat(&[prev_key, k.unwrap()], 1)
}
}
None => k.unwrap()
let prev_key = prev_state.prev_key.view((bs * self.num_heads, -1, self.head_dim));
let prev_value = prev_state.prev_value.view((bs * self.num_heads, -1, self.head_dim));
let k = if self.encoder_decoder_attention {
prev_key
} else {
Tensor::cat(&[prev_key, k.unwrap()], 1)
};
let v = match &prev_state.prev_value {
Some(prev_value) => {
let prev_value = prev_value.view((bs * self.num_heads, -1, self.head_dim));
if self.encoder_decoder_attention {
prev_value
} else {
Tensor::cat(&[prev_value, v.unwrap()], 1)
}
}
None => v.unwrap()
let v = if self.encoder_decoder_attention {
prev_value
} else {
Tensor::cat(&[prev_value, v.unwrap()], 1)
};
let key_padding_mask = self.use_saved_key_padding_mask(key_padding_mask,
&prev_state.prev_key_padding_mask,
bs,

View File

@ -22,7 +22,7 @@ use tch::nn::{embedding, EmbeddingConfig};
use crate::bart::attention::LayerState;
use std::borrow::BorrowMut;
use crate::common::dropout::Dropout;
use crate::pipelines::generation::LMHeadModel;
use crate::pipelines::generation::{Cache, LMHeadModel};
/// # BART Pretrained model weight files
pub struct BartModelResources;
@ -222,8 +222,6 @@ impl BartModel {
BartModel { encoder, decoder, generation_mode, pad_token_id, embeddings }
}
pub(crate) fn get_decoder(&mut self) -> &mut BartDecoder { &mut self.decoder }
/// Forward pass through the model
///
/// # Arguments
@ -261,7 +259,7 @@ impl BartModel {
///# let device = Device::Cpu;
///# let vs = nn::VarStore::new(device);
///# let config = BartConfig::from_file(config_path);
///# let mut bart_model: BartModel = BartModel::new(&vs.root(), &config, false);
///# let bart_model: BartModel = BartModel::new(&vs.root(), &config, false);
/// let (batch_size, source_sequence_length, target_sequence_length) = (64, 128, 56);
/// let input_tensor = Tensor::rand(&[batch_size, source_sequence_length], (Int64, device));
/// let target_tensor = Tensor::rand(&[batch_size, target_sequence_length], (Int64, device));
@ -277,19 +275,21 @@ impl BartModel {
/// Some(&target_tensor),
/// None,
/// Some(&decoder_attention_mask),
/// None,
/// false)
/// });
///
/// ```
///
pub fn forward_t(&mut self,
pub fn forward_t(&self,
input_ids: Option<&Tensor>,
attention_mask: Option<&Tensor>,
decoder_input_ids: Option<&Tensor>,
encoder_outputs: Option<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>)>,
decoder_attention_mask: Option<&Tensor>,
layer_states: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
train: bool) ->
(Tensor, Tensor, (Option<Tensor>, Option<Vec<(&LayerState, &LayerState)>>),
(Tensor, Tensor, Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
Option<Vec<Tensor>>, Option<Vec<Tensor>>,
Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
let (decoder_input_ids, decoder_padding_mask, causal_mask) = if self.generation_mode {
@ -318,20 +318,13 @@ impl BartModel {
decoder_padding_mask.as_ref(),
causal_mask.as_ref(),
&self.embeddings,
layer_states,
train);
(decoder_outputs, encoder_hidden_states, decoder_cache,
(decoder_outputs, encoder_hidden_states, decoder_cache.1,
all_decoder_hidden_states, all_decoder_attentions,
all_encoder_hidden_states, all_encoder_attentions)
}
/// Resets the decoder cached keys and values. Should be run for every new generation using the model.
pub fn reset_cache(&mut self) {
for layer in self.get_decoder().get_layers() {
layer.get_self_attention().prev_state.as_mut().unwrap().reset_cache();
layer.get_encoder_attention().prev_state.as_mut().unwrap().reset_cache();
};
}
}
/// # BART Model for conditional generation
@ -407,14 +400,14 @@ impl BartForConditionalGeneration {
///# let device = Device::Cpu;
///# let vs = nn::VarStore::new(device);
///# let config = BartConfig::from_file(config_path);
///# let mut bart_model: BartForConditionalGeneration = BartForConditionalGeneration::new(&vs.root(), &config, false);
///# let bart_model: BartForConditionalGeneration = BartForConditionalGeneration::new(&vs.root(), &config, false);
/// let (batch_size, source_sequence_length, target_sequence_length) = (64, 128, 56);
/// let input_tensor = Tensor::rand(&[batch_size, source_sequence_length], (Int64, device));
/// let target_tensor = Tensor::rand(&[batch_size, target_sequence_length], (Int64, device));
/// let encoder_attention_mask = Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
/// let decoder_attention_mask = Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
///
/// let (decoder_output, encoder_hidden_states,
/// let (decoder_output, encoder_hidden_states, cache,
/// all_encoder_hidden_states, all_encoder_attentions,
/// all_decoder_hidden_states, all_decoder_attentions) = no_grad(|| {
/// bart_model
@ -423,44 +416,40 @@ impl BartForConditionalGeneration {
/// None,
/// Some(&target_tensor),
/// Some(&decoder_attention_mask),
/// None,
/// false)
/// });
///
/// ```
///
pub fn forward_t(&mut self,
pub fn forward_t(&self,
input_ids: Option<&Tensor>,
attention_mask: Option<&Tensor>,
encoder_outputs: Option<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>)>,
decoder_input_ids: Option<&Tensor>,
decoder_attention_mask: Option<&Tensor>,
old_layer_states: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
train: bool)
-> (Tensor, Tensor,
-> (Tensor, Tensor, Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
Option<Vec<Tensor>>, Option<Vec<Tensor>>,
Option<Vec<Tensor>>, Option<Vec<Tensor>>)
{
let (decoder_outputs, encoder_hidden_states, _,
let (decoder_outputs, encoder_hidden_states, decoder_cache,
all_decoder_hidden_states, all_decoder_attentions,
all_encoder_hidden_states, all_encoder_attentions) =
self.borrow_mut().base_model.forward_t(input_ids, attention_mask, decoder_input_ids, encoder_outputs, decoder_attention_mask, train);
self.base_model.forward_t(input_ids, attention_mask, decoder_input_ids, encoder_outputs, decoder_attention_mask, old_layer_states, train);
let lm_logits = decoder_outputs.linear::<Tensor>(&self.base_model.embeddings.ws, None);
(lm_logits, encoder_hidden_states,
(lm_logits, encoder_hidden_states, decoder_cache,
all_decoder_hidden_states, all_decoder_attentions,
all_encoder_hidden_states, all_encoder_attentions)
}
pub(crate) fn get_base_model(&mut self) -> &mut BartModel { &mut self.base_model }
pub fn encode(&mut self, input_ids: &Tensor, attention_mask: Option<&Tensor>) -> Tensor {
pub fn encode(&self, input_ids: &Tensor, attention_mask: Option<&Tensor>) -> Tensor {
let (encoder_hidden_states, _, _) = self.base_model.encoder.forward_t(input_ids, attention_mask, &self.base_model.embeddings, false);
encoder_hidden_states
}
/// Resets the decoder cached keys and values. Should be run for every new generation using the model.
pub fn reset_cache(&mut self) {
self.get_base_model().reset_cache()
}
}
pub struct BartClassificationHead {
@ -569,14 +558,14 @@ impl BartForSequenceClassification {
///# let device = Device::Cpu;
///# let vs = nn::VarStore::new(device);
///# let config = BartConfig::from_file(config_path);
///# let mut bart_model: BartForConditionalGeneration = BartForConditionalGeneration::new(&vs.root(), &config, false);
///# let bart_model: BartForConditionalGeneration = BartForConditionalGeneration::new(&vs.root(), &config, false);
/// let (batch_size, source_sequence_length, target_sequence_length) = (64, 128, 56);
/// let input_tensor = Tensor::rand(&[batch_size, source_sequence_length], (Int64, device));
/// let target_tensor = Tensor::rand(&[batch_size, target_sequence_length], (Int64, device));
/// let encoder_attention_mask = Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
/// let decoder_attention_mask = Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
///
/// let (decoder_output, encoder_hidden_states,
/// let (decoder_output, encoder_hidden_states, cache,
/// all_encoder_hidden_states, all_encoder_attentions,
/// all_decoder_hidden_states, all_decoder_attentions) = no_grad(|| {
/// bart_model
@ -585,6 +574,7 @@ impl BartForSequenceClassification {
/// None,
/// Some(&target_tensor),
/// Some(&decoder_attention_mask),
/// None,
/// false)
/// });
///
@ -603,7 +593,7 @@ impl BartForSequenceClassification {
let (decoder_outputs, encoder_hidden_states, _,
all_decoder_hidden_states, all_decoder_attentions,
all_encoder_hidden_states, all_encoder_attentions) =
self.borrow_mut().base_model.forward_t(Some(input_ids), attention_mask, decoder_input_ids, encoder_outputs, decoder_attention_mask, train);
self.borrow_mut().base_model.forward_t(Some(input_ids), attention_mask, decoder_input_ids, encoder_outputs, decoder_attention_mask, None, train);
let eos_mask = input_ids.eq(self.eos_token_id);
let sentence_representation = decoder_outputs
@ -617,12 +607,6 @@ impl BartForSequenceClassification {
all_encoder_hidden_states, all_encoder_attentions)
}
pub(crate) fn get_base_model(&mut self) -> &mut BartModel { &mut self.base_model }
/// Resets the decoder cached keys and values. Should be run for every new generation using the model.
pub fn reset_cache(&mut self) {
self.get_base_model().reset_cache()
}
}
impl LMHeadModel for BartForConditionalGeneration {
@ -657,57 +641,65 @@ impl LMHeadModel for BartForConditionalGeneration {
///# use rust_bert::Config;
///# use std::path::Path;
///# use tch::kind::Kind::{Int64, Double};
/// use rust_bert::gpt2::{Gpt2Config, GPT2LMHeadModel};
/// use rust_bert::pipelines::generation::LMHeadModel;
/// use rust_bert::bart::{BartForConditionalGeneration, BartConfig};
///# let config_path = Path::new("path/to/config.json");
///# let vocab_path = Path::new("path/to/vocab.txt");
///# let device = Device::Cpu;
///# let vs = nn::VarStore::new(device);
///# let config = Gpt2Config::from_file(config_path);
///# let mut gpt2_model: GPT2LMHeadModel = GPT2LMHeadModel::new(&vs.root(), &config);
/// let (batch_size, sequence_length, past_sequence_length) = (64, 128, 56);
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let mut past: Vec<Tensor> = Vec::with_capacity(config.n_layer as usize);
/// for _ in 0..config.n_layer as usize {
/// past.push(Tensor::rand(&[2, batch_size, config.n_head, past_sequence_length, config.n_embd / config.n_head], (Double, device)))
/// }
/// let attention_mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let token_type_ids = Tensor::ones(&[batch_size, sequence_length], (Int64, device));
/// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true);
///# let config = BartConfig::from_file(config_path);
///# let bart_model: BartForConditionalGeneration = BartForConditionalGeneration::new(&vs.root(), &config, false);
/// let (batch_size, source_sequence_length, target_sequence_length) = (64, 128, 56);
/// let input_tensor = Tensor::rand(&[batch_size, source_sequence_length], (Int64, device));
/// let target_tensor = Tensor::rand(&[batch_size, target_sequence_length], (Int64, device));
/// let encoder_attention_mask = Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
/// let decoder_attention_mask = Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
///
/// let (output, encoder_hidden_states, _, hidden_states, attentions) = no_grad(|| {
/// gpt2_model
/// .forward_t(&Some(input_tensor),
/// &Some(past),
/// &Some(attention_mask),
/// &Some(token_type_ids),
/// &Some(position_ids),
/// &None,
/// let (decoder_output, encoder_hidden_states, cache,
/// all_encoder_hidden_states, all_encoder_attentions,
/// all_decoder_hidden_states, all_decoder_attentions) = no_grad(|| {
/// bart_model
/// .forward_t(Some(&input_tensor),
/// Some(&encoder_attention_mask),
/// None,
/// &None,
/// false).unwrap()
/// Some(&target_tensor),
/// Some(&decoder_attention_mask),
/// None,
/// false)
/// });
///
/// ```
///
fn forward_t(&mut self,
fn forward_t(&self,
input_ids: &Option<Tensor>,
_layer_past: &Option<Vec<Tensor>>,
cache: Cache,
attention_mask: &Option<Tensor>,
_token_type_ids: &Option<Tensor>,
_position_ids: &Option<Tensor>,
_input_embeds: &Option<Tensor>,
encoder_outputs: Option<&Tensor>,
decoder_input_ids: &Option<Tensor>,
train: bool) -> Result<(Tensor, Option<Tensor>, Option<Vec<Tensor>>, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
let (decoder_output, encoder_hidden_states, _, _, _, _, _) = self.base_model.forward_t(input_ids.as_ref(),
attention_mask.as_ref(),
decoder_input_ids.as_ref(),
Some((encoder_outputs.as_ref().unwrap().copy(), None, None)),
None,
train);
train: bool) -> Result<(Tensor, Option<Tensor>, Cache, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
let (decoder_output, encoder_hidden_states, new_cache, _, _, _, _) = match cache {
Cache::BARTCache(cached_layer_states) => self.base_model.forward_t(input_ids.as_ref(),
attention_mask.as_ref(),
decoder_input_ids.as_ref(),
Some((encoder_outputs.as_ref().unwrap().copy(), None, None)),
None,
cached_layer_states,
train),
Cache::None => self.base_model.forward_t(input_ids.as_ref(),
attention_mask.as_ref(),
decoder_input_ids.as_ref(),
Some((encoder_outputs.as_ref().unwrap().copy(), None, None)),
None,
None,
train),
_ => Err("Cache not compatible with BART Model")?
};
let lm_logits = decoder_output.linear::<Tensor>(&self.base_model.embeddings.ws, None);
Ok((lm_logits, Some(encoder_hidden_states), None, None, None))
Ok((lm_logits, Some(encoder_hidden_states), Cache::BARTCache(new_cache), None, None))
}
}

View File

@ -96,20 +96,18 @@ impl DecoderLayer {
}
}
pub fn get_self_attention(&mut self) -> &mut SelfAttention { &mut self.self_attention }
pub fn get_encoder_attention(&mut self) -> &mut SelfAttention { &mut self.encoder_attention }
pub fn forward_t(&mut self,
pub fn forward_t(&self,
x: &Tensor,
encoder_hidden_states: &Tensor,
encoder_attn_mask: Option<&Tensor>,
causal_mask: Option<&Tensor>,
decoder_padding_mask: Option<&Tensor>,
train: bool) -> (Tensor, Option<Tensor>) {
let (output, attention_weights) = self.self_attention.forward_t(x, Some(x), decoder_padding_mask, causal_mask, train);
layer_states: (Option<LayerState>, Option<LayerState>),
train: bool) -> (Tensor, Option<Tensor>, (Option<LayerState>, Option<LayerState>)) {
let (output, attention_weights, new_self_layer_states) = self.self_attention.forward_t(x, Some(x), decoder_padding_mask, causal_mask, layer_states.0, train);
let output: Tensor = output.apply_t(&self.dropout, train) + x;
let output = output.apply(&self.self_attention_layer_norm);
let (output1, _) = self.encoder_attention.forward_t(&output, Some(encoder_hidden_states), encoder_attn_mask, None, train);
let (output1, _, new_encoder_layer_states) = self.encoder_attention.forward_t(&output, Some(encoder_hidden_states), encoder_attn_mask, None, layer_states.1, train);
let output1: Tensor = output1.apply_t(&self.dropout, train) + output;
let output1 = output1.apply(&self.encoder_attention_layer_norm);
let output2 = (self.activation)(&output1.apply(&self.fc1));
@ -118,7 +116,7 @@ impl DecoderLayer {
.apply(&self.fc2)
.apply_t(&self.dropout, train);
let output2: Tensor = output2 + output1;
(output2.apply(&self.final_layer_norm), attention_weights)
(output2.apply(&self.final_layer_norm), attention_weights, (new_self_layer_states, new_encoder_layer_states))
}
}
@ -138,7 +136,7 @@ impl BartDecoder {
pub fn new(p: nn::Path, config: &BartConfig, generation_mode: bool) -> BartDecoder {
let output_past = match config.output_past {
Some(value) => value,
None => false
None => true
};
let output_attentions = match config.output_attentions {
Some(value) => value,
@ -157,7 +155,7 @@ impl BartDecoder {
None => false
};
let scale_embedding = match config.scale_embedding {
Some(value) => if value {(config.d_model as f64).sqrt()} else { 1.0 },
Some(value) => if value { (config.d_model as f64).sqrt() } else { 1.0 },
None => 1.0
};
@ -203,22 +201,21 @@ impl BartDecoder {
output_hidden_states,
output_past,
generation_mode,
scale_embedding
scale_embedding,
}
}
pub fn get_layers(&mut self) -> &mut Vec<DecoderLayer> { &mut self.layers }
pub fn forward_t(&mut self,
pub fn forward_t(&self,
input_ids: &Tensor,
encoder_hidden_states: &Tensor,
encoder_padding_mask: Option<&Tensor>,
decoder_padding_mask: Option<&Tensor>,
decoder_causal_mask: Option<&Tensor>,
embeddings: &nn::Embedding,
old_layer_states: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
train: bool)
-> (Tensor,
(Option<Tensor>, Option<Vec<(&LayerState, &LayerState)>>),
(Option<Tensor>, Option<Vec<(Option<LayerState>, Option<LayerState>)>>),
Option<Vec<Tensor>>,
Option<Vec<Tensor>>) {
let encoder_padding_mask = match encoder_padding_mask {
@ -227,36 +224,42 @@ impl BartDecoder {
};
let positions = self.embed_positions.forward(input_ids, self.generation_mode);
let (input_ids, positions) = if self.generation_mode {
let x: Tensor = if self.generation_mode {
let end_inputs = input_ids.size()[1];
let end_positions = positions.size()[1];
(input_ids.slice(1, end_inputs - 1, end_inputs, 1),
positions.slice(1, end_positions - 1, end_positions, 1))
input_ids.narrow(1, end_inputs - 1, 1).apply(embeddings) * self.scale_embedding + positions.narrow(1, end_positions - 1, 1)
} else {
(input_ids.copy(), positions)
input_ids.apply(embeddings) * self.scale_embedding + positions
};
let x: Tensor = input_ids.as_ref().apply(embeddings) * self.scale_embedding + positions;
let x = if let Some(layer_norm_embedding) = &self.layer_norm_embedding { x.apply(layer_norm_embedding) } else { x };
let x = x
let mut hidden_state = x
.apply_t(&self.dropout, train)
.transpose(0, 1);
let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states { Some(vec!()) } else { None };
let mut all_attentions: Option<Vec<Tensor>> = if self.output_attentions { Some(vec!()) } else { None };
let mut next_decoder_cache: Option<Vec<(&LayerState, &LayerState)>> = if self.output_past { Some(vec!()) } else { None };
let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states { Some(Vec::with_capacity(self.layers.len())) } else { None };
let mut all_attentions: Option<Vec<Tensor>> = if self.output_attentions { Some(Vec::with_capacity(self.layers.len())) } else { None };
let mut next_decoder_cache: Option<Vec<(Option<LayerState>, Option<LayerState>)>> = if self.output_past {
if old_layer_states.is_some() { old_layer_states } else { Some(vec!((None, None); self.layers.len())) }
} else {
None
};
let encoder_hidden_states = encoder_hidden_states.transpose(0, 1);
let mut hidden_state = x.copy();
let mut attention_weights: Option<Tensor>;
let mut layers = self.layers.iter_mut();
let mut layers = self.layers.iter().enumerate();
loop {
match layers.next() {
Some(layer) => {
Some((layer_idx, layer)) => {
let layer_state = match &next_decoder_cache {
Some(values) => values[layer_idx].to_owned(),
None => (None, None)
};
let temp = layer.forward_t(&hidden_state,
&encoder_hidden_states,
encoder_padding_mask.as_ref(),
decoder_causal_mask,
decoder_padding_mask,
layer_state,
train);
hidden_state = temp.0;
attention_weights = temp.1;
@ -266,9 +269,7 @@ impl BartDecoder {
if let Some(attentions) = all_attentions.borrow_mut() {
attentions.push(attention_weights.as_ref().unwrap().copy());
};
if let Some(cache) = next_decoder_cache.borrow_mut() {
cache.push((layer.self_attention.prev_state.as_ref().unwrap(), layer.encoder_attention.prev_state.as_ref().unwrap()));
};
if let Some(value) = &mut next_decoder_cache { value[layer_idx] = temp.2 };
}
None => break
};

View File

@ -72,8 +72,8 @@ impl EncoderLayer {
EncoderLayer { self_attention, self_attention_layer_norm, dropout, activation_dropout, activation, fc1, fc2, final_layer_norm }
}
pub fn forward_t(&mut self, x: &Tensor, encoder_padding_mask: Option<&Tensor>, train: bool) -> (Tensor, Option<Tensor>) {
let (output, attention_weights) = self.self_attention.forward_t(x, None, encoder_padding_mask, None, train);
pub fn forward_t(&self, x: &Tensor, encoder_padding_mask: Option<&Tensor>, train: bool) -> (Tensor, Option<Tensor>) {
let (output, attention_weights, _) = self.self_attention.forward_t(x, None, encoder_padding_mask, None, None, train);
let output: Tensor = output.apply_t(&self.dropout, train) + x;
let output = output.apply(&self.self_attention_layer_norm);
@ -117,7 +117,7 @@ impl BartEncoder {
None => false
};
let scale_embedding = match config.scale_embedding {
Some(value) => if value {(config.d_model as f64).sqrt()} else { 1.0 },
Some(value) => if value { (config.d_model as f64).sqrt() } else { 1.0 },
None => 1.0
};
@ -165,7 +165,7 @@ impl BartEncoder {
}
}
pub fn forward_t(&mut self,
pub fn forward_t(&self,
input_ids: &Tensor,
attention_mask: Option<&Tensor>,
embeddings: &nn::Embedding,
@ -188,7 +188,7 @@ impl BartEncoder {
let mut hidden_state = x.copy();
let mut attention_weights: Option<Tensor>;
let mut layers = self.layers.iter_mut();
let mut layers = self.layers.iter();
loop {
match layers.next() {

View File

@ -21,7 +21,7 @@ use tch::kind::Kind::Int64;
use std::borrow::BorrowMut;
use crate::common::linear::{LinearNoBias, linear_no_bias};
use crate::Config;
use crate::pipelines::generation::LMHeadModel;
use crate::pipelines::generation::{LMHeadModel, Cache};
/// # GPT2 Pretrained model weight files
pub struct Gpt2ModelResources;
@ -413,7 +413,7 @@ impl LMHeadModel for GPT2LMHeadModel {
///# use std::path::Path;
///# use tch::kind::Kind::{Int64, Double};
/// use rust_bert::gpt2::{Gpt2Config, GPT2LMHeadModel};
/// use rust_bert::pipelines::generation::LMHeadModel;
/// use rust_bert::pipelines::generation::{LMHeadModel, Cache};
///# let config_path = Path::new("path/to/config.json");
///# let vocab_path = Path::new("path/to/vocab.txt");
///# let device = Device::Cpu;
@ -433,7 +433,7 @@ impl LMHeadModel for GPT2LMHeadModel {
/// let (output, _, past, hidden_states, attentions) = no_grad(|| {
/// gpt2_model
/// .forward_t(&Some(input_tensor),
/// &Some(past),
/// Cache::GPT2Cache(Some(past)),
/// &Some(attention_mask),
/// &Some(token_type_ids),
/// &Some(position_ids),
@ -445,28 +445,38 @@ impl LMHeadModel for GPT2LMHeadModel {
///
/// ```
///
fn forward_t(&mut self,
fn forward_t(&self,
input_ids: &Option<Tensor>,
layer_past: &Option<Vec<Tensor>>,
layer_past: Cache,
attention_mask: &Option<Tensor>,
token_type_ids: &Option<Tensor>,
position_ids: &Option<Tensor>,
input_embeds: &Option<Tensor>,
_encoder_outputs: Option<&Tensor>,
_decoder_input_ids: &Option<Tensor>,
train: bool) -> Result<(Tensor, Option<Tensor>, Option<Vec<Tensor>>, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
train: bool) -> Result<(Tensor, Option<Tensor>, Cache, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
let (output,
past,
all_hidden_states,
all_attentions) = self.transformer.forward_t(input_ids,
layer_past,
all_attentions) = match layer_past {
Cache::GPT2Cache(layer_past) => Ok(self.transformer.forward_t(input_ids,
&layer_past,
attention_mask,
token_type_ids,
position_ids,
input_embeds,
train)?),
Cache::None => Ok(self.transformer.forward_t(input_ids,
&None,
attention_mask,
token_type_ids,
position_ids,
input_embeds,
train)?;
train)?),
_ => Err("Cache not compatible with GPT2 model")
}?;
let lm_logits = output.apply(&self.lm_head);
Ok((lm_logits, None, past, all_hidden_states, all_attentions))
Ok((lm_logits, None, Cache::GPT2Cache(past), all_hidden_states, all_attentions))
}
}

View File

@ -11,10 +11,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use crate::bart::{BartModel, BartConfig};
use crate::bart::{BartModel, BartConfig, LayerState};
use tch::{Tensor, nn};
use std::borrow::BorrowMut;
use crate::pipelines::generation::LMHeadModel;
use crate::pipelines::generation::{LMHeadModel, Cache};
use tch::nn::Init;
/// # Marian Pretrained model weight files
@ -198,66 +197,62 @@ impl MarianForConditionalGeneration {
///# use rust_bert::Config;
///# use std::path::Path;
///# use tch::kind::Kind::{Int64, Double};
/// use rust_bert::bart::{BartConfig, BartForConditionalGeneration};
/// use rust_bert::bart::{BartConfig};
/// use rust_bert::marian::MarianForConditionalGeneration;
///# let config_path = Path::new("path/to/config.json");
///# let vocab_path = Path::new("path/to/vocab.txt");
///# let device = Device::Cpu;
///# let vs = nn::VarStore::new(device);
///# let config = BartConfig::from_file(config_path);
///# let mut bart_model: BartForConditionalGeneration = BartForConditionalGeneration::new(&vs.root(), &config, false);
///# let mut marian_model = MarianForConditionalGeneration::new(&vs.root(), &config, false);
/// let (batch_size, source_sequence_length, target_sequence_length) = (64, 128, 56);
/// let input_tensor = Tensor::rand(&[batch_size, source_sequence_length], (Int64, device));
/// let target_tensor = Tensor::rand(&[batch_size, target_sequence_length], (Int64, device));
/// let encoder_attention_mask = Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
/// let decoder_attention_mask = Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
///
/// let (decoder_output, encoder_hidden_states,
/// let (decoder_output, encoder_hidden_states, cache,
/// all_encoder_hidden_states, all_encoder_attentions,
/// all_decoder_hidden_states, all_decoder_attentions) = no_grad(|| {
/// bart_model
/// marian_model
/// .forward_t(Some(&input_tensor),
/// Some(&encoder_attention_mask),
/// None,
/// Some(&target_tensor),
/// Some(&decoder_attention_mask),
/// None,
/// false)
/// });
///
/// ```
///
pub fn forward_t(&mut self,
pub fn forward_t(&self,
input_ids: Option<&Tensor>,
attention_mask: Option<&Tensor>,
encoder_outputs: Option<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>)>,
decoder_input_ids: Option<&Tensor>,
decoder_attention_mask: Option<&Tensor>,
old_layer_states: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
train: bool)
-> (Tensor, Tensor,
-> (Tensor, Tensor, Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
Option<Vec<Tensor>>, Option<Vec<Tensor>>,
Option<Vec<Tensor>>, Option<Vec<Tensor>>)
{
let (decoder_outputs, encoder_hidden_states, _,
let (decoder_outputs, encoder_hidden_states, decoder_cache,
all_decoder_hidden_states, all_decoder_attentions,
all_encoder_hidden_states, all_encoder_attentions) =
self.borrow_mut().base_model.forward_t(input_ids, attention_mask, decoder_input_ids, encoder_outputs, decoder_attention_mask, train);
self.base_model.forward_t(input_ids, attention_mask, decoder_input_ids, encoder_outputs, decoder_attention_mask, old_layer_states, train);
let lm_logits = decoder_outputs.linear::<Tensor>(&self.base_model.embeddings.ws, None);
(lm_logits, encoder_hidden_states,
(lm_logits, encoder_hidden_states, decoder_cache,
all_decoder_hidden_states, all_decoder_attentions,
all_encoder_hidden_states, all_encoder_attentions)
}
pub(crate) fn get_base_model(&mut self) -> &mut BartModel { &mut self.base_model }
pub fn encode(&mut self, input_ids: &Tensor, attention_mask: Option<&Tensor>) -> Tensor {
pub fn encode(&self, input_ids: &Tensor, attention_mask: Option<&Tensor>) -> Tensor {
let (encoder_hidden_states, _, _) = self.base_model.encoder.forward_t(input_ids, attention_mask, &self.base_model.embeddings, false);
encoder_hidden_states
}
/// Resets the decoder cached keys and values. Should be run for every new generation using the model.
pub fn reset_cache(&mut self) {
self.get_base_model().reset_cache()
}
}
impl LMHeadModel for MarianForConditionalGeneration {
@ -292,58 +287,64 @@ impl LMHeadModel for MarianForConditionalGeneration {
///# use rust_bert::Config;
///# use std::path::Path;
///# use tch::kind::Kind::{Int64, Double};
/// use rust_bert::gpt2::{Gpt2Config, GPT2LMHeadModel};
/// use rust_bert::pipelines::generation::LMHeadModel;
/// use rust_bert::bart::{BartConfig};
/// use rust_bert::marian::MarianForConditionalGeneration;
///# let config_path = Path::new("path/to/config.json");
///# let vocab_path = Path::new("path/to/vocab.txt");
///# let device = Device::Cpu;
///# let vs = nn::VarStore::new(device);
///# let config = Gpt2Config::from_file(config_path);
///# let mut gpt2_model: GPT2LMHeadModel = GPT2LMHeadModel::new(&vs.root(), &config);
/// let (batch_size, sequence_length, past_sequence_length) = (64, 128, 56);
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let mut past: Vec<Tensor> = Vec::with_capacity(config.n_layer as usize);
/// for _ in 0..config.n_layer as usize {
/// past.push(Tensor::rand(&[2, batch_size, config.n_head, past_sequence_length, config.n_embd / config.n_head], (Double, device)))
/// }
/// let attention_mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let token_type_ids = Tensor::ones(&[batch_size, sequence_length], (Int64, device));
/// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true);
///# let config = BartConfig::from_file(config_path);
///# let marian_model = MarianForConditionalGeneration::new(&vs.root(), &config, false);
/// let (batch_size, source_sequence_length, target_sequence_length) = (64, 128, 56);
/// let input_tensor = Tensor::rand(&[batch_size, source_sequence_length], (Int64, device));
/// let target_tensor = Tensor::rand(&[batch_size, target_sequence_length], (Int64, device));
/// let encoder_attention_mask = Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
/// let decoder_attention_mask = Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
///
/// let (output, encoder_hidden_states, _, hidden_states, attentions) = no_grad(|| {
/// gpt2_model
/// .forward_t(&Some(input_tensor),
/// &Some(past),
/// &Some(attention_mask),
/// &Some(token_type_ids),
/// &Some(position_ids),
/// &None,
/// let (decoder_output, encoder_hidden_states, cache,
/// all_encoder_hidden_states, all_encoder_attentions,
/// all_decoder_hidden_states, all_decoder_attentions) = no_grad(|| {
/// marian_model
/// .forward_t(Some(&input_tensor),
/// Some(&encoder_attention_mask),
/// None,
/// &None,
/// false).unwrap()
/// Some(&target_tensor),
/// Some(&decoder_attention_mask),
/// None,
/// false)
/// });
///
/// ```
///
fn forward_t(&mut self,
fn forward_t(&self,
input_ids: &Option<Tensor>,
_layer_past: &Option<Vec<Tensor>>,
cache: Cache,
attention_mask: &Option<Tensor>,
_token_type_ids: &Option<Tensor>,
_position_ids: &Option<Tensor>,
_input_embeds: &Option<Tensor>,
encoder_outputs: Option<&Tensor>,
decoder_input_ids: &Option<Tensor>,
train: bool) -> Result<(Tensor, Option<Tensor>, Option<Vec<Tensor>>, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
let (decoder_output, encoder_hidden_states, _, _, _, _, _) = self.base_model.forward_t(input_ids.as_ref(),
attention_mask.as_ref(),
decoder_input_ids.as_ref(),
Some((encoder_outputs.as_ref().unwrap().copy(), None, None)),
None,
train);
train: bool) -> Result<(Tensor, Option<Tensor>, Cache, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
let (decoder_output, encoder_hidden_states, new_cache, _, _, _, _) = match cache {
Cache::BARTCache(cached_layer_states) => self.base_model.forward_t(input_ids.as_ref(),
attention_mask.as_ref(),
decoder_input_ids.as_ref(),
Some((encoder_outputs.as_ref().unwrap().copy(), None, None)),
None,
cached_layer_states,
train),
Cache::None => self.base_model.forward_t(input_ids.as_ref(),
attention_mask.as_ref(),
decoder_input_ids.as_ref(),
Some((encoder_outputs.as_ref().unwrap().copy(), None, None)),
None,
None,
train),
_ => Err("Cache not compatible with Marian Model")?
};
let lm_logits = decoder_output.linear::<Tensor>(&self.base_model.embeddings.ws, None) + &self.final_logits_bias;
Ok((lm_logits, Some(encoder_hidden_states), None, None, None))
Ok((lm_logits, Some(encoder_hidden_states), Cache::BARTCache(new_cache), None, None))
}
}

View File

@ -20,7 +20,7 @@ use std::borrow::BorrowMut;
use crate::common::linear::{LinearNoBias, linear_no_bias};
use crate::openai_gpt::transformer::Block;
use crate::gpt2::Gpt2Config;
use crate::pipelines::generation::LMHeadModel;
use crate::pipelines::generation::{LMHeadModel, Cache};
/// # GPT Pretrained model weight files
pub struct OpenAiGptModelResources;
@ -311,7 +311,7 @@ impl LMHeadModel for OpenAIGPTLMHeadModel {
///# use tch::kind::Kind::{Int64, Double};
/// use rust_bert::gpt2::Gpt2Config;
/// use rust_bert::openai_gpt::OpenAIGPTLMHeadModel;
/// use rust_bert::pipelines::generation::LMHeadModel;
/// use rust_bert::pipelines::generation::{LMHeadModel, Cache};
///# let config_path = Path::new("path/to/config.json");
///# let vocab_path = Path::new("path/to/vocab.txt");
///# let device = Device::Cpu;
@ -327,7 +327,7 @@ impl LMHeadModel for OpenAIGPTLMHeadModel {
/// let (output, _, _, hidden_states, attentions) = no_grad(|| {
/// gpt_model
/// .forward_t(&Some(input_tensor),
/// &None,
/// Cache::None,
/// &Some(attention_mask),
/// &Some(token_type_ids),
/// &Some(position_ids),
@ -339,16 +339,16 @@ impl LMHeadModel for OpenAIGPTLMHeadModel {
///
/// ```
///
fn forward_t(&mut self,
fn forward_t(&self,
input_ids: &Option<Tensor>,
_layer_past: &Option<Vec<Tensor>>,
_layer_past: Cache,
attention_mask: &Option<Tensor>,
token_type_ids: &Option<Tensor>,
position_ids: &Option<Tensor>,
input_embeds: &Option<Tensor>,
_encoder_outputs: Option<&Tensor>,
_decoder_input_ids: &Option<Tensor>,
train: bool) -> Result<(Tensor, Option<Tensor>, Option<Vec<Tensor>>, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
train: bool) -> Result<(Tensor, Option<Tensor>, Cache, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
let (output,
all_hidden_states,
all_attentions) = self.transformer.forward_t(input_ids,
@ -359,6 +359,6 @@ impl LMHeadModel for OpenAIGPTLMHeadModel {
train)?;
let lm_logits = output.apply(&self.lm_head);
Ok((lm_logits, None, None, all_hidden_states, all_attentions))
Ok((lm_logits, None, Cache::None, all_hidden_states, all_attentions))
}
}

View File

@ -67,7 +67,7 @@ use crate::openai_gpt::{OpenAIGPTLMHeadModel, OpenAiGptModelResources, OpenAiGpt
use crate::gpt2::{Gpt2Config, GPT2LMHeadModel, Gpt2ModelResources, Gpt2ConfigResources, Gpt2VocabResources, Gpt2MergesResources};
use crate::Config;
use crate::pipelines::generation::private_generation_utils::PrivateLanguageGenerator;
use crate::bart::{BartConfig, BartForConditionalGeneration, BartModelResources, BartConfigResources, BartVocabResources, BartMergesResources};
use crate::bart::{BartConfig, BartForConditionalGeneration, BartModelResources, BartConfigResources, BartVocabResources, BartMergesResources, LayerState};
use crate::common::resources::{Resource, RemoteResource, download_resource};
use rust_tokenizers::preprocessing::tokenizer::marian_tokenizer::MarianTokenizer;
use rust_tokenizers::preprocessing::vocab::marian_vocab::MarianVocab;
@ -247,7 +247,7 @@ impl OpenAIGenerator {
}
impl PrivateLanguageGenerator<OpenAIGPTLMHeadModel, OpenAiGptVocab, OpenAiGptTokenizer> for OpenAIGenerator {
fn get_model(&mut self) -> &mut OpenAIGPTLMHeadModel { &mut self.model }
fn get_model(&self) -> &OpenAIGPTLMHeadModel { &self.model }
fn get_tokenizer(&self) -> &OpenAiGptTokenizer { &self.tokenizer }
fn get_var_store(&self) -> &nn::VarStore { &self.var_store }
fn get_config(&self) -> &GenerateConfig { &self.generate_config }
@ -327,7 +327,7 @@ impl GPT2Generator {
}
impl PrivateLanguageGenerator<GPT2LMHeadModel, Gpt2Vocab, Gpt2Tokenizer> for GPT2Generator {
fn get_model(&mut self) -> &mut GPT2LMHeadModel { &mut self.model }
fn get_model(&self) -> &GPT2LMHeadModel { &self.model }
fn get_tokenizer(&self) -> &Gpt2Tokenizer { &self.tokenizer }
fn get_var_store(&self) -> &nn::VarStore { &self.var_store }
fn get_config(&self) -> &GenerateConfig { &self.generate_config }
@ -341,13 +341,19 @@ impl PrivateLanguageGenerator<GPT2LMHeadModel, Gpt2Vocab, Gpt2Tokenizer> for GPT
fn prepare_inputs_for_generation<'a>(&self,
input_ids: Tensor,
_encoder_outputs: Option<&'a Tensor>,
past: Option<Vec<Tensor>>,
past: Cache,
_attention_mask: Tensor)
-> (Option<Tensor>, Option<&'a Tensor>, Option<Tensor>, Option<Vec<Tensor>>) {
if past.is_some() {
(Some(input_ids.select(1, -1).unsqueeze(-1)), None, None, past)
} else {
(Some(input_ids), None, None, past)
-> (Option<Tensor>, Option<&'a Tensor>, Option<Tensor>, Cache) {
match past {
Cache::GPT2Cache(past) => {
if past.is_some() {
(Some(input_ids.select(1, -1).unsqueeze(-1)), None, None, Cache::GPT2Cache(past))
} else {
(Some(input_ids), None, None, Cache::GPT2Cache(None))
}
}
Cache::None => (Some(input_ids), None, None, Cache::GPT2Cache(None)),
_ => panic!("Cache type incompatible with GPT2")
}
}
}
@ -473,7 +479,7 @@ impl BartGenerator {
}
impl PrivateLanguageGenerator<BartForConditionalGeneration, RobertaVocab, RobertaTokenizer> for BartGenerator {
fn get_model(&mut self) -> &mut BartForConditionalGeneration { &mut self.model }
fn get_model(&self) -> &BartForConditionalGeneration { &self.model }
fn get_tokenizer(&self) -> &RobertaTokenizer { &self.tokenizer }
fn get_var_store(&self) -> &nn::VarStore { &self.var_store }
fn get_config(&self) -> &GenerateConfig { &self.generate_config }
@ -492,17 +498,23 @@ impl PrivateLanguageGenerator<BartForConditionalGeneration, RobertaVocab, Robert
}
}
fn encode(&mut self, input_ids: &Tensor, attention_mask: Option<&Tensor>) -> Option<Tensor> {
fn encode(&self, input_ids: &Tensor, attention_mask: Option<&Tensor>) -> Option<Tensor> {
Some(self.get_model().encode(input_ids, attention_mask))
}
fn prepare_inputs_for_generation<'a>(&self,
input_ids: Tensor,
encoder_outputs: Option<&'a Tensor>,
_past: Option<Vec<Tensor>>,
past: Cache,
_attention_mask: Tensor)
-> (Option<Tensor>, Option<&'a Tensor>, Option<Tensor>, Option<Vec<Tensor>>) {
(None, encoder_outputs, Some(input_ids), None)
-> (Option<Tensor>, Option<&'a Tensor>, Option<Tensor>, Cache) {
match past {
Cache::BARTCache(past) => {
(None, encoder_outputs, Some(input_ids), Cache::BARTCache(past))
}
Cache::None => (None, encoder_outputs, Some(input_ids), Cache::BARTCache(None)),
_ => panic!("Cache type incompatible with BART")
}
}
fn encode_prompt_text(&self, prompt_text: Vec<&str>, max_len: u64, pad_token_id: Option<i64>) -> Tensor {
@ -536,20 +548,35 @@ impl PrivateLanguageGenerator<BartForConditionalGeneration, RobertaVocab, Robert
Tensor::stack(&token_ids, 0)
}
fn reorder_cache(&mut self, _past: Option<Vec<Tensor>>, encoder_outputs: Option<Tensor>, beam_indices: &Tensor) -> (Option<Vec<Tensor>>, Option<Tensor>) {
fn reorder_cache(&self, past: &mut Cache, encoder_outputs: Option<Tensor>, beam_indices: &Tensor) -> Option<Tensor> {
let encoder_outputs = match encoder_outputs {
Some(value) => Some(value.index_select(0, beam_indices)),
None => None
};
for layer in self.get_model().get_base_model().get_decoder().get_layers() {
layer.get_self_attention().prev_state.as_mut().unwrap().reorder_cache(beam_indices);
layer.get_encoder_attention().prev_state.as_mut().unwrap().reorder_cache(beam_indices);
match past {
Cache::BARTCache(old_cache_option) => {
match old_cache_option {
Some(old_cache) => {
let mut new_past = vec!();
for (self_layer_state, encoder_layer_state) in old_cache.into_iter() {
let new_self_layer_state = match self_layer_state {
Some(self_layer_state) => Some(self_layer_state.reorder_cache(beam_indices)),
None => None
};
let new_encoder_layer_state = match encoder_layer_state {
Some(encoder_layer_state) => Some(encoder_layer_state.reorder_cache(beam_indices)),
None => None
};
new_past.push((new_self_layer_state, new_encoder_layer_state));
};
}
None => { }
}
}
Cache::None => {},
_ => { panic!("Invalid cache for BART model"); }
};
(None, encoder_outputs)
}
fn reset_cache(&mut self) {
self.get_model().reset_cache();
encoder_outputs
}
}
@ -645,7 +672,7 @@ impl MarianGenerator {
}
impl PrivateLanguageGenerator<MarianForConditionalGeneration, MarianVocab, MarianTokenizer> for MarianGenerator {
fn get_model(&mut self) -> &mut MarianForConditionalGeneration { &mut self.model }
fn get_model(&self) -> &MarianForConditionalGeneration { &self.model }
fn get_tokenizer(&self) -> &MarianTokenizer { &self.tokenizer }
fn get_var_store(&self) -> &nn::VarStore { &self.var_store }
fn get_config(&self) -> &GenerateConfig { &self.generate_config }
@ -663,17 +690,23 @@ impl PrivateLanguageGenerator<MarianForConditionalGeneration, MarianVocab, Maria
}
}
fn encode(&mut self, input_ids: &Tensor, attention_mask: Option<&Tensor>) -> Option<Tensor> {
fn encode(&self, input_ids: &Tensor, attention_mask: Option<&Tensor>) -> Option<Tensor> {
Some(self.get_model().encode(input_ids, attention_mask))
}
fn prepare_inputs_for_generation<'a>(&self,
input_ids: Tensor,
encoder_outputs: Option<&'a Tensor>,
_past: Option<Vec<Tensor>>,
past: Cache,
_attention_mask: Tensor)
-> (Option<Tensor>, Option<&'a Tensor>, Option<Tensor>, Option<Vec<Tensor>>) {
(None, encoder_outputs, Some(input_ids), None)
-> (Option<Tensor>, Option<&'a Tensor>, Option<Tensor>, Cache) {
match past {
Cache::BARTCache(past) => {
(None, encoder_outputs, Some(input_ids), Cache::BARTCache(past))
}
Cache::None => (None, encoder_outputs, Some(input_ids), Cache::BARTCache(None)),
_ => panic!("Cache type incompatible with Marian")
}
}
fn encode_prompt_text(&self, prompt_text: Vec<&str>, max_len: u64, pad_token_id: Option<i64>) -> Tensor {
@ -707,25 +740,47 @@ impl PrivateLanguageGenerator<MarianForConditionalGeneration, MarianVocab, Maria
Tensor::stack(&token_ids, 0)
}
fn reorder_cache(&mut self, _past: Option<Vec<Tensor>>, encoder_outputs: Option<Tensor>, beam_indices: &Tensor) -> (Option<Vec<Tensor>>, Option<Tensor>) {
fn reorder_cache(&self, past: &mut Cache, encoder_outputs: Option<Tensor>, beam_indices: &Tensor) -> Option<Tensor> {
let encoder_outputs = match encoder_outputs {
Some(value) => Some(value.index_select(0, beam_indices)),
None => None
};
for layer in self.get_model().get_base_model().get_decoder().get_layers() {
layer.get_self_attention().prev_state.as_mut().unwrap().reorder_cache(beam_indices);
layer.get_encoder_attention().prev_state.as_mut().unwrap().reorder_cache(beam_indices);
match past {
Cache::BARTCache(old_cache_option) => {
match old_cache_option {
Some(old_cache) => {
let mut new_past = vec!();
for (self_layer_state, encoder_layer_state) in old_cache.into_iter() {
let new_self_layer_state = match self_layer_state {
Some(self_layer_state) => Some(self_layer_state.reorder_cache(beam_indices)),
None => None
};
let new_encoder_layer_state = match encoder_layer_state {
Some(encoder_layer_state) => Some(encoder_layer_state.reorder_cache(beam_indices)),
None => None
};
new_past.push((new_self_layer_state, new_encoder_layer_state));
};
}
None => { }
}
}
Cache::None => {},
_ => { panic!("Invalid cache for BART model"); }
};
(None, encoder_outputs)
}
fn reset_cache(&mut self) {
self.get_model().reset_cache();
encoder_outputs
}
}
impl LanguageGenerator<MarianForConditionalGeneration, MarianVocab, MarianTokenizer> for MarianGenerator {}
#[derive(Debug)]
pub enum Cache {
GPT2Cache(Option<Vec<Tensor>>),
BARTCache(Option<Vec<(Option<LayerState>, Option<LayerState>)>>),
None,
}
mod private_generation_utils {
use rust_tokenizers::{Vocab, Tokenizer, TruncationStrategy};
use tch::{nn, Tensor, Device};
@ -733,12 +788,12 @@ mod private_generation_utils {
use std::collections::HashMap;
use tch::kind::Kind::{Int64, Float, Bool};
use std::cmp::{min, max};
use crate::pipelines::generation::{BeamHypotheses, GenerateConfig, LMHeadModel};
use crate::pipelines::generation::{BeamHypotheses, GenerateConfig, LMHeadModel, Cache};
use itertools::Itertools;
use super::ordered_float::OrderedFloat;
pub trait PrivateLanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>> {
fn get_model(&mut self) -> &mut T;
fn get_model(&self) -> &T;
fn get_tokenizer(&self) -> &U;
fn get_var_store(&self) -> &nn::VarStore;
fn get_config(&self) -> &GenerateConfig;
@ -749,16 +804,17 @@ mod private_generation_utils {
fn get_vocab_size(&self) -> i64;
fn get_decoder_start_id(&self) -> Option<i64>;
fn prepare_scores_for_generation(&self, _scores: &mut Tensor, _current_length: i64, _max_length: i64) {}
fn encode(&mut self, _input_ids: &Tensor, _attention_mask: Option<&Tensor>) -> Option<Tensor> { None }
fn encode(&self, _input_ids: &Tensor, _attention_mask: Option<&Tensor>) -> Option<Tensor> { None }
fn prepare_inputs_for_generation<'a>(&self,
input_ids: Tensor,
_encoder_outputs: Option<&'a Tensor>,
past: Option<Vec<Tensor>>,
past: Cache,
_attention_mask: Tensor)
-> (Option<Tensor>, Option<&'a Tensor>, Option<Tensor>, Option<Vec<Tensor>>) {
-> (Option<Tensor>, Option<&'a Tensor>, Option<Tensor>, Cache) {
(Some(input_ids), None, None, past)
}
@ -893,7 +949,7 @@ mod private_generation_utils {
}
}
fn generate_no_beam_search(&mut self, input_ids: Tensor, encoder_outputs: Option<Tensor>,
fn generate_no_beam_search(&self, input_ids: Tensor, encoder_outputs: Option<Tensor>,
cur_len: i64, min_length: i64, max_length: i64, do_sample: bool,
temperature: f64, top_k: i64, top_p: f64, repetition_penalty: f64, no_repeat_ngram_size: i64,
pad_token_id: Option<i64>, eos_token_ids: Option<Vec<i64>>,
@ -902,7 +958,7 @@ mod private_generation_utils {
let mut sentence_lengths: Tensor = Tensor::ones(&[batch_size], (Int64, self.get_var_store().device())) * max_length as i64;
let mut attention_mask = attention_mask.copy();
let mut input_ids = input_ids.copy();
let mut past: Option<Vec<Tensor>> = None;
let mut past: Cache = Cache::None;
let mut outputs: Tensor;
let mut current_length = cur_len;
@ -914,8 +970,9 @@ mod private_generation_utils {
encoder_outputs.as_ref(),
past,
attention_mask.copy());
let temp = self.get_model().forward_t(&prepared_input,
&prepared_past,
prepared_past,
&None,
&None,
&None,
@ -925,6 +982,7 @@ mod private_generation_utils {
false).unwrap();
outputs = temp.0;
past = temp.2;
let mut next_token_logits = outputs.select(1, -1);
// Reduce probability for repeated inputs
if repetition_penalty > 1f64 {
@ -1001,7 +1059,7 @@ mod private_generation_utils {
decoded
}
fn generate_beam_search(&mut self, input_ids: Tensor, encoder_outputs: Option<Tensor>,
fn generate_beam_search(&self, input_ids: Tensor, encoder_outputs: Option<Tensor>,
cur_len: i64, min_length: i64, max_length: i64, do_sample: bool, early_stopping: bool,
temperature: f64, top_k: i64, top_p: f64, repetition_penalty: f64, no_repeat_ngram_size: i64,
pad_token_id: Option<i64>, eos_token_ids: Option<Vec<i64>>,
@ -1019,7 +1077,7 @@ mod private_generation_utils {
let mut beam_scores = beam_scores.view_(&[-1]);
let mut beam_tokens: Tensor;
let mut beam_indices: Tensor;
let mut past: Option<Vec<Tensor>> = None;
let mut past: Cache = Cache::None;
let mut done = vec!(false; batch_size as usize);
let mut attention_mask = attention_mask.copy();
@ -1037,7 +1095,7 @@ mod private_generation_utils {
past,
attention_mask.copy());
let temp = self.get_model().forward_t(&prepared_input,
&prepared_past,
prepared_past,
&None,
&None,
&None,
@ -1046,7 +1104,6 @@ mod private_generation_utils {
&prepared_decoder_input,
false).unwrap();
outputs = temp.0;
encoder_outputs = temp.1;
past = temp.2;
let mut next_token_logits = outputs.select(1, -1);
@ -1157,9 +1214,9 @@ mod private_generation_utils {
input_ids = input_ids.index_select(0, &beam_indices);
input_ids = Tensor::cat(&[input_ids, beam_tokens.unsqueeze(1)], -1);
let temp_past = self.reorder_cache(past, encoder_outputs, &beam_indices);
past = temp_past.0;
encoder_outputs = temp_past.1;
encoder_outputs = self.reorder_cache(&mut past, encoder_outputs, &beam_indices);
// past = temp_past.0;
// encoder_outputs = temp_past.1;
if !self.is_encoder_decoder() {
attention_mask = Tensor::cat(&[attention_mask.as_ref(), Tensor::ones(&[*attention_mask.size().first().unwrap(), 1],
(Int64, attention_mask.device())).as_ref()], -1);
@ -1234,20 +1291,25 @@ mod private_generation_utils {
decoded
}
fn reorder_cache(&mut self, past: Option<Vec<Tensor>>, _encoder_outputs: Option<Tensor>, beam_indices: &Tensor) -> (Option<Vec<Tensor>>, Option<Tensor>) {
fn reorder_cache(&self, past: &mut Cache, _encoder_outputs: Option<Tensor>, beam_indices: &Tensor)
-> Option<Tensor> {
match past {
Some(value) => {
let mut reordered_past = vec!();
for layer_past in value.iter() {
reordered_past.push(layer_past.index_select(1, beam_indices));
Cache::None => { None }
Cache::GPT2Cache(cached_decoder_state) => {
match cached_decoder_state {
Some(value) => {
// let mut reordered_past = vec!();
for layer_past in value.iter_mut() {
*layer_past = layer_past.index_select(1, beam_indices);
}
None
}
None => None
}
(Some(reordered_past), None)
}
None => (None, None)
Cache::BARTCache(_) => { panic!("Not implemented"); }
}
}
fn reset_cache(&mut self) {}
}
}
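The reworked `reorder_cache` above mutates the GPT-2 cache in place, applying `index_select` on dimension 1 (the beam dimension) instead of building and returning a new vector. Below is a self-contained toy of that single operation, with made-up shapes, assuming only the `tch` crate.

    use tch::{Device, Kind, Tensor};

    fn main() {
        // Beams kept after a search step (here: beam 2 first, then 0, then 1).
        let beam_indices = Tensor::of_slice(&[2i64, 0, 1]);
        // Fake GPT-2 style cache: one tensor per layer, beam dimension at index 1.
        let mut layer_past: Vec<Tensor> = (0..2)
            .map(|_| Tensor::rand(&[2, 3, 1, 4, 8], (Kind::Float, Device::Cpu)))
            .collect();
        for layer in layer_past.iter_mut() {
            // Same call as in the diff: re-shuffle each cached layer to follow the surviving beams.
            *layer = layer.index_select(1, &beam_indices);
        }
        assert_eq!(layer_past[0].size(), vec![2, 3, 1, 4, 8]);
    }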
@ -1308,7 +1370,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>: PrivateL
///# ;
///```
///
fn generate(&mut self, prompt_texts: Option<Vec<&str>>, attention_mask: Option<Tensor>)
fn generate(&self, prompt_texts: Option<Vec<&str>>, attention_mask: Option<Tensor>)
-> Vec<String> {
let eos_token_ids = PrivateLanguageGenerator::get_eos_ids(self).clone();
@ -1399,7 +1461,6 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>: PrivateL
(input_ids, attention_mask)
};
self.reset_cache();
let decoded = no_grad(|| {
if num_beams > 1 {
self.generate_beam_search(input_ids, encoder_outputs, cur_len, min_length as i64, max_length as i64, do_sample, early_stopping, temperature, top_k as i64, top_p, repetition_penalty,
@ -1523,7 +1584,7 @@ pub trait LMHeadModel {
///# use std::path::Path;
///# use tch::kind::Kind::{Int64, Double};
/// use rust_bert::gpt2::{Gpt2Config, GPT2LMHeadModel};
/// use rust_bert::pipelines::generation::LMHeadModel;
/// use rust_bert::pipelines::generation::{LMHeadModel, Cache};
///# let config_path = Path::new("path/to/config.json");
///# let vocab_path = Path::new("path/to/vocab.txt");
///# let device = Device::Cpu;
@ -1543,7 +1604,7 @@ pub trait LMHeadModel {
/// let (output, encoder_output, past, hidden_states, attentions) = no_grad(|| {
/// gpt2_model
/// .forward_t(&Some(input_tensor),
/// &Some(past),
/// Cache::GPT2Cache(Some(past)),
/// &Some(attention_mask),
/// &Some(token_type_ids),
/// &Some(position_ids),
@ -1555,14 +1616,14 @@ pub trait LMHeadModel {
///
/// ```
///
fn forward_t(&mut self,
fn forward_t(&self,
input_ids: &Option<Tensor>,
layer_past: &Option<Vec<Tensor>>,
layer_past: Cache,
attention_mask: &Option<Tensor>,
token_type_ids: &Option<Tensor>,
position_ids: &Option<Tensor>,
input_embeds: &Option<Tensor>,
encoder_outputs: Option<&Tensor>,
decoder_input_ids: &Option<Tensor>,
train: bool) -> Result<(Tensor, Option<Tensor>, Option<Vec<Tensor>>, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str>;
}
train: bool) -> Result<(Tensor, Option<Tensor>, Cache, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str>;
}
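To round out the trait change above, here is a hedged, minimal sketch of implementing the reworked `LMHeadModel` for a hypothetical model that keeps no decoder cache. `NoCacheModel` and its placeholder output are invented for illustration; only the signature is taken from the trait as modified in this diff.

    use rust_bert::pipelines::generation::{Cache, LMHeadModel};
    use tch::Tensor;

    struct NoCacheModel; // hypothetical stateless model

    impl LMHeadModel for NoCacheModel {
        fn forward_t(&self,
                     input_ids: &Option<Tensor>,
                     _layer_past: Cache,
                     _attention_mask: &Option<Tensor>,
                     _token_type_ids: &Option<Tensor>,
                     _position_ids: &Option<Tensor>,
                     _input_embeds: &Option<Tensor>,
                     _encoder_outputs: Option<&Tensor>,
                     _decoder_input_ids: &Option<Tensor>,
                     _train: bool) -> Result<(Tensor, Option<Tensor>, Cache, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
            let input = input_ids.as_ref().ok_or("No input ids provided")?;
            // Placeholder output tensor (a real model would return logits here); a model
            // without a decoder cache simply hands Cache::None back to the generation loop.
            Ok((input.shallow_clone(), None, Cache::None, None, None))
        }
    }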

View File

@ -197,7 +197,7 @@ impl SummarizationModel {
///# fn main() -> failure::Fallible<()> {
/// use rust_bert::pipelines::generation::LanguageGenerator;
/// use rust_bert::pipelines::summarization::SummarizationModel;
/// let mut model = SummarizationModel::new(Default::default())?;
/// let model = SummarizationModel::new(Default::default())?;
///
/// let input = ["In findings published Tuesday in Cornell University's arXiv by a team of scientists
///from the University of Montreal and a separate report published Wednesday in Nature Astronomy by a team
@ -227,7 +227,7 @@ impl SummarizationModel {
/// ```
/// (New sample credits: [WikiNews](https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b))
///
pub fn summarize(&mut self, texts: &[&str]) -> Vec<String> {
pub fn summarize(&self, texts: &[&str]) -> Vec<String> {
self.model.generate(Some(texts.to_vec()), None)
}
}
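A hedged note on what the move to `&self` enables: the pipeline can now be driven through a shared borrow, for instance from a helper function. `summarize_batches` below is a hypothetical helper, not part of the crate.

    use rust_bert::pipelines::summarization::SummarizationModel;

    // Each batch is summarized through the same immutably borrowed model.
    fn summarize_batches(model: &SummarizationModel, batches: &[Vec<&str>]) -> Vec<Vec<String>> {
        batches.iter().map(|texts| model.summarize(texts)).collect()
    }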

View File

@ -374,7 +374,7 @@ impl TranslationModel {
/// use tch::Device;
///
/// let translation_config = TranslationConfig::new(Language::EnglishToFrench, Device::cuda_if_available());
/// let mut model = TranslationModel::new(translation_config)?;
/// let model = TranslationModel::new(translation_config)?;
///
/// let input = ["This is a sentence to be translated"];
///
@ -383,7 +383,7 @@ impl TranslationModel {
///# }
/// ```
///
pub fn translate(&mut self, texts: &[&str]) -> Vec<String> {
pub fn translate(&self, texts: &[&str]) -> Vec<String> {
match &self.prefix {
Some(value) => {
let texts: Vec<String> = texts

View File

@ -23,7 +23,7 @@ fn bart_lm_model() -> failure::Fallible<()> {
let mut vs = nn::VarStore::new(device);
let tokenizer: RobertaTokenizer = RobertaTokenizer::from_file(vocab_path.to_str().unwrap(), merges_path.to_str().unwrap(), false);
let config = BartConfig::from_file(config_path);
let mut bart_model = BartModel::new(&vs.root(), &config, false);
let bart_model = BartModel::new(&vs.root(), &config, false);
vs.load(weights_path)?;
// Define input
@ -49,6 +49,7 @@ fn bart_lm_model() -> failure::Fallible<()> {
None,
None,
None,
None,
false);
assert_eq!(output.size(), vec!(1, 6, 1024));
@ -68,7 +69,7 @@ fn bart_summarization_greedy() -> failure::Fallible<()> {
device: Device::Cpu,
..Default::default()
};
let mut model = SummarizationModel::new(summarization_config)?;
let model = SummarizationModel::new(summarization_config)?;
let input = ["In findings published Tuesday in Cornell University's arXiv by a team of scientists \
from the University of Montreal and a separate report published Wednesday in Nature Astronomy by a team \
@ -113,7 +114,7 @@ fn bart_summarization_beam_search() -> failure::Fallible<()> {
device: Device::Cpu,
..Default::default()
};
let mut model = SummarizationModel::new(summarization_config)?;
let model = SummarizationModel::new(summarization_config)?;
let input = ["In findings published Tuesday in Cornell University's arXiv by a team of scientists \
from the University of Montreal and a separate report published Wednesday in Nature Astronomy by a team \

View File

@ -2,7 +2,7 @@ use tch::{Device, nn, Tensor};
use rust_tokenizers::{Gpt2Tokenizer, TruncationStrategy, Tokenizer};
use rust_bert::Config;
use rust_bert::gpt2::{Gpt2Config, GPT2LMHeadModel, Gpt2ConfigResources, Gpt2VocabResources, Gpt2MergesResources, Gpt2ModelResources};
use rust_bert::pipelines::generation::LMHeadModel;
use rust_bert::pipelines::generation::{LMHeadModel, Cache};
use rust_bert::resources::{Resource, download_resource, RemoteResource};
#[test]
@ -22,7 +22,7 @@ fn distilgpt2_lm_model() -> failure::Fallible<()> {
let mut vs = nn::VarStore::new(device);
let tokenizer: Gpt2Tokenizer = Gpt2Tokenizer::from_file(vocab_path.to_str().unwrap(), merges_path.to_str().unwrap(), false);
let config = Gpt2Config::from_file(config_path);
let mut gpt2_model = GPT2LMHeadModel::new(&vs.root(), &config);
let gpt2_model = GPT2LMHeadModel::new(&vs.root(), &config);
vs.load(weights_path)?;
// Define input
@ -44,7 +44,7 @@ fn distilgpt2_lm_model() -> failure::Fallible<()> {
// Forward pass
let (output, _, past, _, _) = gpt2_model.forward_t(
&Some(input_tensor),
&None,
Cache::None,
&None,
&None,
&None,
@ -57,9 +57,14 @@ fn distilgpt2_lm_model() -> failure::Fallible<()> {
let next_word = tokenizer.decode(vec!(next_word_id), true, true);
assert_eq!(output.size(), vec!(1, 11, 50257));
assert!(past.is_some());
assert_eq!(past.as_ref().unwrap().len(), config.n_layer as usize);
assert_eq!(past.as_ref().unwrap()[0].size(), vec!(2, 1, config.n_head, 11, 64));
match past {
Cache::GPT2Cache(past) => {
assert!(past.is_some());
assert_eq!(past.as_ref().unwrap().len(), config.n_layer as usize);
assert_eq!(past.as_ref().unwrap()[0].size(), vec!(2, 1, config.n_head, 11, 64));
}
_ => panic!("Wrong cache returned for GPT2")
}
assert!((output.double_value(&[0, output.size()[1] - 1, next_word_id]) - (-48.7065)).abs() < 1e-4);
assert_eq!(next_word_id, 14104i64);
assert_eq!(next_word, String::from(" twelve"));
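As a hedged illustration of what the returned GPT-2 cache is for (not part of the test above): a second forward pass can feed the cache back together with only the newly predicted token, so attention still covers the cached context. The snippet reuses the test's bindings (`gpt2_model`, `next_word_id`, `output`) and assumes the `Cache` value from the first pass is still available as `past`; since the match above consumes it, one would bind or rebuild it first.

    let next_input = Tensor::of_slice(&[next_word_id]).unsqueeze(0).to(output.device());
    let (next_output, _, _, _, _) = gpt2_model.forward_t(
        &Some(next_input),
        past,            // Cache::GPT2Cache(...) returned by the first pass
        &None,
        &None,
        &None,
        &None,
        None,
        &None,
        false).unwrap();
    // Only the new position produces logits; the cached keys/values supply the earlier context.
    assert_eq!(next_output.size()[1], 1);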

View File

@ -1,7 +1,7 @@
use tch::{Device, nn, Tensor};
use rust_tokenizers::{Gpt2Tokenizer, TruncationStrategy, Tokenizer};
use rust_bert::Config;
use rust_bert::pipelines::generation::{GPT2Generator, LanguageGenerator, GenerateConfig, LMHeadModel};
use rust_bert::pipelines::generation::{GPT2Generator, LanguageGenerator, GenerateConfig, LMHeadModel, Cache};
use rust_bert::gpt2::{Gpt2Config, GPT2LMHeadModel, Gpt2ConfigResources, Gpt2MergesResources, Gpt2VocabResources, Gpt2ModelResources};
use rust_bert::resources::{RemoteResource, Resource, download_resource};
@ -22,7 +22,7 @@ fn gpt2_lm_model() -> failure::Fallible<()> {
let mut vs = nn::VarStore::new(device);
let tokenizer: Gpt2Tokenizer = Gpt2Tokenizer::from_file(vocab_path.to_str().unwrap(), merges_path.to_str().unwrap(), false);
let config = Gpt2Config::from_file(config_path);
let mut gpt2_model = GPT2LMHeadModel::new(&vs.root(), &config);
let gpt2_model = GPT2LMHeadModel::new(&vs.root(), &config);
vs.load(weights_path)?;
// Define input
@ -44,7 +44,7 @@ fn gpt2_lm_model() -> failure::Fallible<()> {
// Forward pass
let (output, _, past, _, _) = gpt2_model.forward_t(
&Some(input_tensor),
&None,
Cache::None,
&None,
&None,
&None,
@ -57,9 +57,14 @@ fn gpt2_lm_model() -> failure::Fallible<()> {
let next_word = tokenizer.decode(vec!(next_word_id), true, true);
assert_eq!(output.size(), vec!(1, 4, 50257));
assert!(past.is_some());
match past {
Cache::GPT2Cache(past) => {
assert!(past.is_some());
assert_eq!(past.as_ref().unwrap().len(), config.n_layer as usize);
assert_eq!(past.as_ref().unwrap()[0].size(), vec!(2, 1, config.n_head, 4, 64));
}
_ => panic!("Wrong cache returned for GPT2")
}
assert!((output.double_value(&[0, output.size()[1] - 1, next_word_id]) - (-69.4948)).abs() < 1e-4);
assert_eq!(next_word_id, 1936i64);
assert_eq!(next_word, String::from(" five"));
@ -89,7 +94,7 @@ fn gpt2_generation_greedy() -> failure::Fallible<()> {
repetition_penalty: 1.1,
..Default::default()
};
let mut model = GPT2Generator::new(generate_config)?;
let model = GPT2Generator::new(generate_config)?;
let input_context = "The cat";
let output = model.generate(Some(vec!(input_context)), None);
@ -121,7 +126,7 @@ fn gpt2_generation_beam_search() -> failure::Fallible<()> {
num_return_sequences: 3,
..Default::default()
};
let mut model = GPT2Generator::new(generate_config)?;
let model = GPT2Generator::new(generate_config)?;
let input_context = "The dog";
let output = model.generate(Some(vec!(input_context)), None);
@ -155,7 +160,7 @@ fn gpt2_generation_beam_search_multiple_prompts_without_padding() -> failure::Fa
num_return_sequences: 3,
..Default::default()
};
let mut model = GPT2Generator::new(generate_config)?;
let model = GPT2Generator::new(generate_config)?;
let input_context_1 = "The dog";
let input_context_2 = "The cat";
@ -193,7 +198,7 @@ fn gpt2_generation_beam_search_multiple_prompts_with_padding() -> failure::Falli
num_return_sequences: 3,
..Default::default()
};
let mut model = GPT2Generator::new(generate_config)?;
let model = GPT2Generator::new(generate_config)?;
let input_context_1 = "The dog";
let input_context_2 = "The cat was";

View File

@ -7,7 +7,7 @@ fn test_translation() -> failure::Fallible<()> {
// Set-up translation model
let translation_config = TranslationConfig::new(Language::EnglishToFrench, Device::Cpu);
let mut model = TranslationModel::new(translation_config)?;
let model = TranslationModel::new(translation_config)?;
let input_context_1 = "The quick brown fox jumps over the lazy dog";
let input_context_2 = "The dog did not wake up";
@ -16,7 +16,7 @@ fn test_translation() -> failure::Fallible<()> {
assert_eq!(output.len(), 2);
assert_eq!(output[0], " Le rapide renard brun saute sur le chien paresseux");
assert_eq!(output[1], " Le chien ne s'est pas réveillé");
assert_eq!(output[1], " Le chien ne s'est pas réveillé.");
Ok(())
}

View File

@ -1,7 +1,7 @@
use tch::{Device, nn, Tensor};
use rust_tokenizers::{TruncationStrategy, Tokenizer, OpenAiGptTokenizer};
use rust_bert::Config;
use rust_bert::pipelines::generation::{OpenAIGenerator, LanguageGenerator, GenerateConfig, LMHeadModel};
use rust_bert::pipelines::generation::{OpenAIGenerator, LanguageGenerator, GenerateConfig, LMHeadModel, Cache};
use rust_bert::gpt2::Gpt2Config;
use rust_bert::openai_gpt::{OpenAIGPTLMHeadModel, OpenAiGptConfigResources, OpenAiGptVocabResources, OpenAiGptMergesResources, OpenAiGptModelResources};
use rust_bert::resources::{RemoteResource, Resource, download_resource};
@ -23,7 +23,7 @@ fn openai_gpt_lm_model() -> failure::Fallible<()> {
let mut vs = nn::VarStore::new(device);
let tokenizer = OpenAiGptTokenizer::from_file(vocab_path.to_str().unwrap(), merges_path.to_str().unwrap(), true);
let config = Gpt2Config::from_file(config_path);
let mut openai_gpt = OpenAIGPTLMHeadModel::new(&vs.root(), &config);
let openai_gpt = OpenAIGPTLMHeadModel::new(&vs.root(), &config);
vs.load(weights_path)?;
// Define input
@ -45,7 +45,7 @@ fn openai_gpt_lm_model() -> failure::Fallible<()> {
// Forward pass
let (output, _, _, _, _) = openai_gpt.forward_t(
&Some(input_tensor),
&None,
Cache::None,
&None,
&None,
&None,
@ -87,7 +87,7 @@ fn openai_gpt_generation_greedy() -> failure::Fallible<()> {
temperature: 1.1,
..Default::default()
};
let mut model = OpenAIGenerator::new(generate_config)?;
let model = OpenAIGenerator::new(generate_config)?;
let input_context = "It was an intense machine dialogue. ";
let output = model.generate(Some(vec!(input_context)), None);
@ -119,7 +119,7 @@ fn openai_gpt_generation_beam_search() -> failure::Fallible<()> {
num_return_sequences: 3,
..Default::default()
};
let mut model = OpenAIGenerator::new(generate_config)?;
let model = OpenAIGenerator::new(generate_config)?;
let input_context = "The dog is";
let output = model.generate(Some(vec!(input_context)), None);
@ -153,7 +153,7 @@ fn openai_gpt_generation_beam_search_multiple_prompts_without_padding() -> failu
num_return_sequences: 3,
..Default::default()
};
let mut model = OpenAIGenerator::new(generate_config)?;
let model = OpenAIGenerator::new(generate_config)?;
let input_context_1 = "The dog is";
let input_context_2 = "The cat";
@ -194,7 +194,7 @@ fn openai_gpt_generation_beam_search_multiple_prompts_with_padding() -> failure:
num_return_sequences: 3,
..Default::default()
};
let mut model = OpenAIGenerator::new(generate_config)?;
let model = OpenAIGenerator::new(generate_config)?;
let input_context_1 = "The dog is";
let input_context_2 = "The cat was in";