Fix and validation of ProphetNet summarization

This commit is contained in:
Guillaume B 2021-01-18 12:41:58 +01:00
parent 6cfb2f1d54
commit 9c0edeebf1
6 changed files with 268 additions and 12 deletions

View File

@ -77,7 +77,6 @@ fn main() -> anyhow::Result<()> {
None,
None,
None,
None,
false,
)
.unwrap()

View File

@ -0,0 +1,77 @@
// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
// Copyright 2019 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
extern crate anyhow;
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::summarization::{SummarizationConfig, SummarizationModel};
use rust_bert::prophetnet::{
ProphetNetConfigResources, ProphetNetModelResources, ProphetNetVocabResources,
};
use rust_bert::resources::{RemoteResource, Resource};
use tch::Device;
/// Example entry point: summarizes a news article with the ProphetNet model
/// fine-tuned on CNN/DailyMail and prints every generated summary to stdout.
///
/// Returns an error if the summarization model cannot be constructed.
fn main() -> anyhow::Result<()> {
    // Text to summarize.
    // Credits: WikiNews, CC BY 2.5 license (https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b)
    let articles = ["In findings published Tuesday in Cornell University's arXiv by a team of scientists \
from the University of Montreal and a separate report published Wednesday in Nature Astronomy by a team \
from University College London (UCL), the presence of water vapour was confirmed in the atmosphere of K2-18b, \
a planet circling a star in the constellation Leo. This is the first such discovery in a planet in its star's \
habitable zone not too hot and not too cold for liquid water to exist. The Montreal team, led by Björn Benneke, \
used data from the NASA's Hubble telescope to assess changes in the light coming from K2-18b's star as the planet \
passed between it and Earth. They found that certain wavelengths of light, which are usually absorbed by water, \
weakened when the planet was in the way, indicating not only does K2-18b have an atmosphere, but the atmosphere \
contains water in vapour form. The team from UCL then analyzed the Montreal team's data using their own software \
and confirmed their conclusion. This was not the first time scientists have found signs of water on an exoplanet, \
but previous discoveries were made on planets with high temperatures or other pronounced differences from Earth. \
\"This is the first potentially habitable planet where the temperature is right and where we now know there is water,\" \
said UCL astronomer Angelos Tsiaras. \"It's the best candidate for habitability right now.\" \"It's a good sign\", \
said Ryan Cloutier of the HarvardSmithsonian Center for Astrophysics, who was not one of either study's authors. \
\"Overall,\" he continued, \"the presence of water in its atmosphere certainly improves the prospect of K2-18b being \
a potentially habitable planet, but further observations will be required to say for sure. \" \
K2-18b was first identified in 2015 by the Kepler space telescope. It is about 110 light-years from Earth and larger \
but less dense. Its star, a red dwarf, is cooler than the Sun, but the planet's orbit is much closer, such that a year \
on K2-18b lasts 33 Earth days. According to The Guardian, astronomers were optimistic that NASA's James Webb space \
telescope scheduled for launch in 2021 and the European Space Agency's 2028 ARIEL program, could reveal more \
about exoplanets like K2-18b."];

    // Remote handles for the three pretrained artefacts of the CNN/DailyMail checkpoint.
    let model_weights = Resource::Remote(RemoteResource::from_pretrained(
        ProphetNetModelResources::PROPHETNET_LARGE_CNN_DM,
    ));
    let model_config = Resource::Remote(RemoteResource::from_pretrained(
        ProphetNetConfigResources::PROPHETNET_LARGE_CNN_DM,
    ));
    let vocabulary = Resource::Remote(RemoteResource::from_pretrained(
        ProphetNetVocabResources::PROPHETNET_LARGE_CNN_DM,
    ));

    // Use the GPU when one is available.
    let device = Device::cuda_if_available();

    // Every field that is not listed keeps the library default.
    let summarization_config = SummarizationConfig {
        model_type: ModelType::ProphetNet,
        model_resource: model_weights,
        config_resource: model_config,
        // The vocabulary resource is also passed as the merges resource.
        merges_resource: vocabulary.clone(),
        vocab_resource: vocabulary,
        num_beams: 4,
        length_penalty: 1.2,
        no_repeat_ngram_size: 3,
        device,
        ..Default::default()
    };
    let summarizer = SummarizationModel::new(summarization_config)?;

    // Generate the summaries and print them one per line.
    let summaries = summarizer.summarize(&articles);
    for summary in summaries {
        println!("{}", summary);
    }
    Ok(())
}

View File

@ -101,11 +101,12 @@ use crate::xlnet::{LayerState as XLNetLayerState, XLNetConfig, XLNetLMHeadModel}
use crate::Config;
use itertools::Itertools;
use rust_tokenizers::tokenizer::{
Gpt2Tokenizer, MarianTokenizer, OpenAiGptTokenizer, ReformerTokenizer, RobertaTokenizer,
T5Tokenizer, Tokenizer, TruncationStrategy, XLNetTokenizer,
Gpt2Tokenizer, MarianTokenizer, OpenAiGptTokenizer, ProphetNetTokenizer, ReformerTokenizer,
RobertaTokenizer, T5Tokenizer, Tokenizer, TruncationStrategy, XLNetTokenizer,
};
use rust_tokenizers::vocab::{
Gpt2Vocab, MarianVocab, OpenAiGptVocab, ReformerVocab, RobertaVocab, T5Vocab, Vocab, XLNetVocab,
Gpt2Vocab, MarianVocab, OpenAiGptVocab, ProphetNetVocab, ReformerVocab, RobertaVocab, T5Vocab,
Vocab, XLNetVocab,
};
use tch::kind::Kind::Int64;
use tch::{nn, no_grad, Device, Kind, Tensor};
@ -1876,12 +1877,12 @@ impl ProphetNetConditionalGenerator {
let model = ProphetNetForConditionalGeneration::new(&var_store.root(), &config)?;
var_store.load(weights_path)?;
let bos_token_id = Some(config.decoder_start_token_id);
let bos_token_id = Some(config.bos_token_id);
let eos_token_ids = Some(vec![config.eos_token_id]);
let pad_token_id = Some(config.pad_token_id);
let vocab_size = config.vocab_size;
let is_encoder_decoder = true;
let decoder_start_id = None;
let decoder_start_id = Some(config.decoder_start_token_id);
Ok(ProphetNetConditionalGenerator {
model,
@ -1898,6 +1899,170 @@ impl ProphetNetConditionalGenerator {
}
}
impl
PrivateLanguageGenerator<
ProphetNetForConditionalGeneration,
ProphetNetVocab,
ProphetNetTokenizer,
> for ProphetNetConditionalGenerator
{
fn get_model(&self) -> &ProphetNetForConditionalGeneration {
&self.model
}
fn get_tokenizer(&self) -> &TokenizerOption {
&self.tokenizer
}
fn get_var_store(&self) -> &nn::VarStore {
&self.var_store
}
fn get_config(&self) -> &GenerateConfig {
&self.generate_config
}
fn get_bos_id(&self) -> &Option<i64> {
&self.bos_token_id
}
fn get_eos_ids(&self) -> &Option<Vec<i64>> {
&self.eos_token_ids
}
fn get_pad_id(&self) -> &Option<i64> {
&self.pad_token_id
}
fn is_encoder_decoder(&self) -> bool {
self.is_encoder_decoder
}
fn get_vocab_size(&self) -> i64 {
self.vocab_size
}
fn get_decoder_start_id(&self) -> Option<i64> {
self.decoder_start_id
}
fn encode(&self, input_ids: &Tensor, attention_mask: Option<&Tensor>) -> Option<Tensor> {
Some(
self.get_model()
.encode(Some(input_ids), attention_mask, None)
.unwrap(),
)
}
fn prepare_inputs_for_generation<'a>(
&self,
input_ids: Tensor,
encoder_outputs: Option<&'a Tensor>,
past: Cache,
attention_mask: Tensor,
) -> (
Option<Tensor>,
Option<Tensor>,
Option<&'a Tensor>,
Option<Tensor>,
Cache,
) {
match past {
Cache::ProphetNetCache(past) => (
None,
Some(attention_mask),
encoder_outputs,
Some(input_ids.narrow(1, -1, 1)),
Cache::ProphetNetCache(past),
),
Cache::None => (
None,
Some(attention_mask),
encoder_outputs,
Some(input_ids),
Cache::ProphetNetCache(None),
),
_ => panic!("Cache type incompatible with ProphetNet"),
}
}
fn encode_prompt_text<'a, S>(
&self,
prompt_text: S,
max_len: i64,
pad_token_id: Option<i64>,
) -> Tensor
where
S: AsRef<[&'a str]>,
{
let tokens = self.get_tokenizer().encode_list(
prompt_text.as_ref(),
max_len as usize,
&TruncationStrategy::LongestFirst,
0,
);
let token_ids = tokens
.into_iter()
.map(|tokenized_input| tokenized_input.token_ids)
.collect::<Vec<Vec<i64>>>();
let max_len = token_ids.iter().map(|input| input.len()).max().unwrap();
let pad_token = match pad_token_id {
Some(value) => value,
None => self
.get_tokenizer()
.convert_tokens_to_ids(&[RobertaVocab::unknown_value()])[0],
};
let token_ids = token_ids
.into_iter()
.map(|mut input| {
let temp = vec![pad_token; max_len - input.len()];
input.extend(temp);
input
})
.map(|tokens| Tensor::of_slice(&tokens).to(self.get_var_store().device()))
.collect::<Vec<Tensor>>();
Tensor::stack(&token_ids, 0)
}
fn reorder_cache(
&self,
past: &mut Cache,
encoder_outputs: Option<Tensor>,
beam_indices: &Tensor,
) -> Option<Tensor> {
let encoder_outputs = match encoder_outputs {
Some(value) => Some(value.index_select(0, beam_indices)),
None => None,
};
match past {
Cache::ProphetNetCache(old_cache_option) => match old_cache_option {
Some(old_cache) => {
for (self_layer_state, encoder_layer_state) in old_cache.iter_mut() {
if self_layer_state.is_some() {
self_layer_state
.as_mut()
.unwrap()
.reorder_cache(beam_indices)
};
if encoder_layer_state.is_some() {
encoder_layer_state
.as_mut()
.unwrap()
.reorder_cache(beam_indices)
};
}
}
None => {}
},
Cache::None => {}
_ => {
panic!("Invalid cache for ProphetNet model");
}
};
encoder_outputs
}
}
// Public generation interface for the ProphetNet generator. The impl body is
// intentionally empty: no trait item is overridden here, so every
// `LanguageGenerator` method used on `ProphetNetConditionalGenerator` is the
// one provided by the trait itself.
impl LanguageGenerator<ProphetNetForConditionalGeneration, ProphetNetVocab, ProphetNetTokenizer>
for ProphetNetConditionalGenerator
{
}
#[derive(Debug)]
pub enum Cache {
GPT2Cache(Option<Vec<Tensor>>),

View File

@ -69,7 +69,7 @@ use crate::common::error::RustBertError;
use crate::common::resources::{RemoteResource, Resource};
use crate::pipelines::common::ModelType;
use crate::pipelines::generation_utils::{
BartGenerator, GenerateConfig, LanguageGenerator, T5Generator,
BartGenerator, GenerateConfig, LanguageGenerator, ProphetNetConditionalGenerator, T5Generator,
};
use itertools::Itertools;
use tch::{Device, Tensor};
@ -208,6 +208,8 @@ pub enum SummarizationOption {
Bart(BartGenerator),
/// Summarizer based on T5 model
T5(T5Generator),
/// Summarizer based on ProphetNet model
ProphetNet(ProphetNetConditionalGenerator),
}
impl SummarizationOption {
@ -217,8 +219,11 @@ impl SummarizationOption {
config.into(),
)?)),
ModelType::T5 => Ok(SummarizationOption::T5(T5Generator::new(config.into())?)),
ModelType::ProphetNet => Ok(SummarizationOption::ProphetNet(
ProphetNetConditionalGenerator::new(config.into())?,
)),
_ => Err(RustBertError::InvalidConfigurationError(format!(
"QuestionAnswering not implemented for {:?}!",
"Summarization not implemented for {:?}!",
config.model_type
))),
}
@ -229,6 +234,7 @@ impl SummarizationOption {
match *self {
Self::Bart(_) => ModelType::Bart,
Self::T5(_) => ModelType::T5,
Self::ProphetNet(_) => ModelType::ProphetNet,
}
}
@ -244,6 +250,9 @@ impl SummarizationOption {
match *self {
Self::Bart(ref model) => model.generate(prompt_texts, attention_mask, None, None, None),
Self::T5(ref model) => model.generate(prompt_texts, attention_mask, None, None, None),
Self::ProphetNet(ref model) => {
model.generate(prompt_texts, attention_mask, None, None, None)
}
}
}
}

View File

@ -642,9 +642,15 @@ impl ProphetNetNgramAttention {
.repeat(&[1, self.num_attention_heads, 1])
.view([-1, *main_relative_position_buckets.size().last().unwrap()]);
let mut new_shape = attention_weights
.size()
.into_iter()
.take(2)
.collect::<Vec<i64>>();
new_shape.push(-1);
rel_pos_embeddings
.gather(1, &main_relative_position_buckets, false)
.view([self.num_attention_heads, sequence_length, -1])
.view(new_shape.as_slice())
}
fn get_predict_relative_pos_embeddings(
@ -742,7 +748,7 @@ pub(crate) fn compute_relative_buckets(
let max_exact = num_buckets / 2;
let is_small = inverse_relative_positions.lt(max_exact);
let max_exact_f64 = max_exact as f64;
let val_if_large = (inverse_relative_positions.totype(Kind::Float) / max_exact_f64).log()
let val_if_large = (inverse_relative_positions.totype(Kind::Float) / max_exact_f64).log2()
/ (max_distance as f64 / max_exact_f64).log2()
* (num_buckets as f64 - max_exact_f64)
+ max_exact_f64;

View File

@ -49,7 +49,7 @@ impl ProphetNetConfigResources {
);
/// Shared under MIT license by the Microsoft team at https://github.com/microsoft/ProphetNet. Modified with conversion to C-array format.
pub const PROPHETNET_LARGE_CNN_DM: (&'static str, &'static str) = (
"prophetnet-large-uncased/config",
"prophetnet-large-uncased-cnndm/config",
"https://huggingface.co/microsoft/prophetnet-large-uncased-cnndm/resolve/main/config.json",
);
}
@ -62,7 +62,7 @@ impl ProphetNetVocabResources {
);
/// Shared under MIT license by the Microsoft team at https://github.com/microsoft/ProphetNet. Modified with conversion to C-array format.
pub const PROPHETNET_LARGE_CNN_DM: (&'static str, &'static str) = (
"prophetnet-large-uncased/vocab",
"prophetnet-large-uncased-cnndm/vocab",
"https://huggingface.co/microsoft/prophetnet-large-uncased-cnndm/resolve/main/prophetnet.tokenizer",
);
}