From 9081bc3318dc1369be3b91100f5bea0cba5b1760 Mon Sep 17 00:00:00 2001
From: Guillaume B
Date: Thu, 25 Jun 2020 19:32:36 +0200
Subject: [PATCH] Initial single turn conversation

---
 Cargo.toml                                    |   2 +-
 examples/conversation.rs                      |  27 ++
 src/pipelines/conversation.rs                 | 385 +++++++++---------
 src/pipelines/mod.rs                          |   2 +-
 .../download-dependencies_dialogpt-medium.py  |   2 +-
 5 files changed, 215 insertions(+), 203 deletions(-)
 create mode 100644 examples/conversation.rs

diff --git a/Cargo.toml b/Cargo.toml
index 1573794..cd94471 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "rust-bert"
-version = "0.7.8"
+version = "0.7.9"
 authors = ["Guillaume Becquin "]
 edition = "2018"
 description = "Ready-to-use NLP pipelines and transformer-based models (BERT, DistilBERT, GPT2,...)"
diff --git a/examples/conversation.rs b/examples/conversation.rs
new file mode 100644
index 0000000..5ccfe10
--- /dev/null
+++ b/examples/conversation.rs
@@ -0,0 +1,27 @@
+// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
+// Copyright 2019 Guillaume Becquin
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+// http://www.apache.org/licenses/LICENSE-2.0
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+extern crate failure;
+
+use rust_bert::pipelines::conversation::ConversationModel;
+
+fn main() -> failure::Fallible<()> {
+    let conversation_model = ConversationModel::new(Default::default())?;
+
+    let input = ["Hello, how are you? <|endoftext|>"];
+
+    let output = conversation_model.reply(&input);
+
+    println!("{:?}", output);
+
+    Ok(())
+}
diff --git a/src/pipelines/conversation.rs b/src/pipelines/conversation.rs
index f993018..4a4dd58 100644
--- a/src/pipelines/conversation.rs
+++ b/src/pipelines/conversation.rs
@@ -1,200 +1,185 @@
-// // Copyright 2019-present Microsoft
-// // Copyright 2020-present, the HuggingFace Inc. team.
-// // Copyright 2020 Guillaume Becquin
-// // Licensed under the Apache License, Version 2.0 (the "License");
-// // you may not use this file except in compliance with the License.
-// // You may obtain a copy of the License at
-// // http://www.apache.org/licenses/LICENSE-2.0
-// // Unless required by applicable law or agreed to in writing, software
-// // distributed under the License is distributed on an "AS IS" BASIS,
-// // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// // See the License for the specific language governing permissions and
-// // limitations under the License.
-//
-// /// # Disclaimer
-// /// This repository aims to facilitate research in large-scale pre-training for conversational data.
-// /// This toolkit contains only part of the modeling machinery needed to actually produce a model
-// /// weight file in a running dialog. On its own, this model provides only information about the
-// /// weights of various text spans; in order for a researcher to actually use it, they will need
-// /// to bring conversational data of their own and decode the response generation from the pretrained
-// /// system. Neither the author of this repository or Microsoft are responsible for any generation
-// /// from the 3rd party utilization of the pretrained system.
-// ///
-// ///
-// ///
-// ///
-// use crate::bart::{
-//     BartConfigResources, BartMergesResources, BartModelResources, BartVocabResources,
-// };
-// use crate::common::resources::{RemoteResource, Resource};
-// use crate::pipelines::generation::{BartGenerator, GenerateConfig, LanguageGenerator};
-// use tch::Device;
-//
-// /// # Configuration for multi-turn classification
-// /// Contains information regarding the model to load, mirrors the GenerationConfig, with a
-// /// different set of default parameters and sets the device to place the model on.
-// pub struct ConversationConfig {
-//     /// Model weights resource (default: DialoGPT-medium)
-//     pub model_resource: Resource,
-//     /// Config resource (default: DialoGPT-medium)
-//     pub config_resource: Resource,
-//     /// Vocab resource (default: DialoGPT-medium)
-//     pub vocab_resource: Resource,
-//     /// Merges resource (default: DialoGPT-medium)
-//     pub merges_resource: Resource,
-//     /// Minimum sequence length (default: 0)
-//     pub min_length: u64,
-//     /// Maximum sequence length (default: 20)
-//     pub max_length: u64,
-//     /// Sampling flag. If true, will perform top-k and/or nucleus sampling on generated tokens, otherwise greedy (deterministic) decoding (default: true)
-//     pub do_sample: bool,
-//     /// Early stopping flag indicating if the beam search should stop as soon as `num_beam` hypotheses have been generated (default: false)
-//     pub early_stopping: bool,
-//     /// Number of beams for beam search (default: 5)
-//     pub num_beams: u64,
-//     /// Temperature setting. Values higher than 1 will improve originality at the risk of reducing relevance (default: 1.0)
-//     pub temperature: f64,
-//     /// Top_k values for sampling tokens. Value higher than 0 will enable the feature (default: 0)
-//     pub top_k: u64,
-//     /// Top_p value for [Nucleus sampling, Holtzman et al.](http://arxiv.org/abs/1904.09751). Keep top tokens until cumulative probability reaches top_p (default: 0.9)
-//     pub top_p: f64,
-//     /// Repetition penalty (mostly useful for CTRL decoders). Values higher than 1 will penalize tokens that have been already generated. (default: 1.0)
-//     pub repetition_penalty: f64,
-//     /// Exponential penalty based on the length of the hypotheses generated (default: 1.0)
-//     pub length_penalty: f64,
-//     /// Number of allowed repetitions of n-grams. Values higher than 0 turn on this feature (default: 3)
-//     pub no_repeat_ngram_size: u64,
-//     /// Number of sequences to return for each prompt text (default: 1)
-//     pub num_return_sequences: u64,
-//     /// Device to place the model on (default: CUDA/GPU when available)
-//     pub device: Device,
-// }
-//
-// impl Default for ConversationConfig {
-//     fn default() -> ConversationConfig {
-//         ConversationConfig {
-//             model_resource: Resource::Remote(RemoteResource::from_pretrained(
-//                 BartModelResources::BART_CNN,
-//             )),
-//             config_resource: Resource::Remote(RemoteResource::from_pretrained(
-//                 BartConfigResources::BART_CNN,
-//             )),
-//             vocab_resource: Resource::Remote(RemoteResource::from_pretrained(
-//                 BartVocabResources::BART_CNN,
-//             )),
-//             merges_resource: Resource::Remote(RemoteResource::from_pretrained(
-//                 BartMergesResources::BART_CNN,
-//             )),
-//             min_length: 56,
-//             max_length: 142,
-//             do_sample: false,
-//             early_stopping: false,
-//             num_beams: 3,
-//             temperature: 1.0,
-//             top_k: 50,
-//             top_p: 1.0,
-//             repetition_penalty: 1.0,
-//             length_penalty: 1.0,
-//             no_repeat_ngram_size: 3,
-//             num_return_sequences: 1,
-//             device: Device::cuda_if_available(),
-//         }
-//     }
-// }
-//
-// /// # SummarizationModel to perform summarization
-// pub struct SummarizationModel {
-//     model: BartGenerator,
-// }
-//
-// impl SummarizationModel {
-//     /// Build a new `SummarizationModel`
-//     ///
-//     /// # Arguments
-//     ///
-//     /// * `summarization_config` - `SummarizationConfig` object containing the resource references (model, vocabulary, configuration), summarization options and device placement (CPU/GPU)
-//     ///
-//     /// # Example
-//     ///
-//     /// ```no_run
-//     /// # fn main() -> failure::Fallible<()> {
-//     /// use rust_bert::pipelines::summarization::SummarizationModel;
-//     ///
-//     /// let mut summarization_model = SummarizationModel::new(Default::default())?;
-//     /// # Ok(())
-//     /// # }
-//     /// ```
-//     pub fn new(summarization_config: SummarizationConfig) -> failure::Fallible<SummarizationModel> {
-//         let generate_config = GenerateConfig {
-//             model_resource: summarization_config.model_resource,
-//             config_resource: summarization_config.config_resource,
-//             merges_resource: summarization_config.merges_resource,
-//             vocab_resource: summarization_config.vocab_resource,
-//             min_length: summarization_config.min_length,
-//             max_length: summarization_config.max_length,
-//             do_sample: summarization_config.do_sample,
-//             early_stopping: summarization_config.early_stopping,
-//             num_beams: summarization_config.num_beams,
-//             temperature: summarization_config.temperature,
-//             top_k: summarization_config.top_k,
-//             top_p: summarization_config.top_p,
-//             repetition_penalty: summarization_config.repetition_penalty,
-//             length_penalty: summarization_config.length_penalty,
-//             no_repeat_ngram_size: summarization_config.no_repeat_ngram_size,
-//             num_return_sequences: summarization_config.num_return_sequences,
-//             device: summarization_config.device,
-//         };
-//
-//         let model = BartGenerator::new(generate_config)?;
-//
-//         Ok(SummarizationModel { model })
-//     }
-//
-//     /// Summarize texts provided
-//     ///
-//     /// # Arguments
-//     ///
-//     /// * `input` - `&[&str]` Array of texts to summarize.
-//     ///
-//     /// # Returns
-//     /// * `Vec<String>` Summarized texts
-//     ///
-//     /// # Example
-//     ///
-//     /// ```no_run
-//     /// # fn main() -> failure::Fallible<()> {
-//     /// use rust_bert::pipelines::generation::LanguageGenerator;
-//     /// use rust_bert::pipelines::summarization::SummarizationModel;
-//     /// let model = SummarizationModel::new(Default::default())?;
-//     ///
-//     /// let input = ["In findings published Tuesday in Cornell University's arXiv by a team of scientists
-//     /// from the University of Montreal and a separate report published Wednesday in Nature Astronomy by a team
-//     /// from University College London (UCL), the presence of water vapour was confirmed in the atmosphere of K2-18b,
-//     /// a planet circling a star in the constellation Leo. This is the first such discovery in a planet in its star's
-//     /// habitable zone — not too hot and not too cold for liquid water to exist. The Montreal team, led by Björn Benneke,
-//     /// used data from the NASA's Hubble telescope to assess changes in the light coming from K2-18b's star as the planet
-//     /// passed between it and Earth. They found that certain wavelengths of light, which are usually absorbed by water,
-//     /// weakened when the planet was in the way, indicating not only does K2-18b have an atmosphere, but the atmosphere
-//     /// contains water in vapour form. The team from UCL then analyzed the Montreal team's data using their own software
-//     /// and confirmed their conclusion. This was not the first time scientists have found signs of water on an exoplanet,
-//     /// but previous discoveries were made on planets with high temperatures or other pronounced differences from Earth.
-//     /// \"This is the first potentially habitable planet where the temperature is right and where we now know there is water,\"
-//     /// said UCL astronomer Angelos Tsiaras. \"It's the best candidate for habitability right now.\" \"It's a good sign\",
-//     /// said Ryan Cloutier of the Harvard–Smithsonian Center for Astrophysics, who was not one of either study's authors.
-//     /// \"Overall,\" he continued, \"the presence of water in its atmosphere certainly improves the prospect of K2-18b being
-//     /// a potentially habitable planet, but further observations will be required to say for sure. \"
-//     /// K2-18b was first identified in 2015 by the Kepler space telescope. It is about 110 light-years from Earth and larger
-//     /// but less dense. Its star, a red dwarf, is cooler than the Sun, but the planet's orbit is much closer, such that a year
-//     /// on K2-18b lasts 33 Earth days. According to The Guardian, astronomers were optimistic that NASA's James Webb space
-//     /// telescope — scheduled for launch in 2021 — and the European Space Agency's 2028 ARIEL program, could reveal more
-//     /// about exoplanets like K2-18b."];
-//     ///
-//     /// let output = model.summarize(&input);
-//     /// # Ok(())
-//     /// # }
-//     /// ```
-//     /// (New sample credits: [WikiNews](https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b))
-//     pub fn summarize(&self, texts: &[&str]) -> Vec<String> {
-//         self.model.generate(Some(texts.to_vec()), None)
-//     }
-// }
+// Copyright 2019-present Microsoft
+// Copyright 2020-present, the HuggingFace Inc. team.
+// Copyright 2020 Guillaume Becquin
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+// http://www.apache.org/licenses/LICENSE-2.0
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+/// # Disclaimer
+/// This repository aims to facilitate research in large-scale pre-training for conversational data.
+/// This toolkit contains only part of the modeling machinery needed to actually produce a model
+/// weight file in a running dialog. On its own, this model provides only information about the
+/// weights of various text spans; in order for a researcher to actually use it, they will need
+/// to bring conversational data of their own and decode the response generation from the pretrained
+/// system. Neither the author of this repository nor Microsoft is responsible for any generation
+/// from third-party utilization of the pretrained system.
+///
+///
+///
+///
+use crate::common::resources::{RemoteResource, Resource};
+use crate::gpt2::{
+    Gpt2ConfigResources, Gpt2MergesResources, Gpt2ModelResources, Gpt2VocabResources,
+};
+use crate::pipelines::generation::{GPT2Generator, GenerateConfig, LanguageGenerator};
+use tch::Device;
+
+/// # Configuration for multi-turn conversation
+/// Contains information regarding the model to load, mirrors the GenerationConfig, with a
+/// different set of default parameters and sets the device to place the model on.
+pub struct ConversationConfig {
+    /// Model weights resource (default: DialoGPT-medium)
+    pub model_resource: Resource,
+    /// Config resource (default: DialoGPT-medium)
+    pub config_resource: Resource,
+    /// Vocab resource (default: DialoGPT-medium)
+    pub vocab_resource: Resource,
+    /// Merges resource (default: DialoGPT-medium)
+    pub merges_resource: Resource,
+    /// Minimum sequence length (default: 0)
+    pub min_length: u64,
+    /// Maximum sequence length (default: 1000)
+    pub max_length: u64,
+    /// Sampling flag. If true, will perform top-k and/or nucleus sampling on generated tokens, otherwise greedy (deterministic) decoding (default: true)
+    pub do_sample: bool,
+    /// Early stopping flag indicating if the beam search should stop as soon as `num_beams` hypotheses have been generated (default: false)
+    pub early_stopping: bool,
+    /// Number of beams for beam search (default: 1)
+    pub num_beams: u64,
+    /// Temperature setting. Values higher than 1 will improve originality at the risk of reducing relevance (default: 1.0)
+    pub temperature: f64,
+    /// Top_k values for sampling tokens. Value higher than 0 will enable the feature (default: 50)
+    pub top_k: u64,
+    /// Top_p value for [Nucleus sampling, Holtzman et al.](http://arxiv.org/abs/1904.09751). Keep top tokens until cumulative probability reaches top_p (default: 0.9)
+    pub top_p: f64,
+    /// Repetition penalty (mostly useful for CTRL decoders). Values higher than 1 will penalize tokens that have been already generated. (default: 1.0)
+    pub repetition_penalty: f64,
+    /// Exponential penalty based on the length of the hypotheses generated (default: 1.0)
+    pub length_penalty: f64,
+    /// Number of allowed repetitions of n-grams. Values higher than 0 turn on this feature (default: 3)
+    pub no_repeat_ngram_size: u64,
+    /// Number of sequences to return for each prompt text (default: 1)
+    pub num_return_sequences: u64,
+    /// Device to place the model on (default: CUDA/GPU when available)
+    pub device: Device,
+}
+
+impl Default for ConversationConfig {
+    fn default() -> ConversationConfig {
+        ConversationConfig {
+            model_resource: Resource::Remote(RemoteResource::from_pretrained(
+                Gpt2ModelResources::DIALOGPT_MEDIUM,
+            )),
+            config_resource: Resource::Remote(RemoteResource::from_pretrained(
+                Gpt2ConfigResources::DIALOGPT_MEDIUM,
+            )),
+            vocab_resource: Resource::Remote(RemoteResource::from_pretrained(
+                Gpt2VocabResources::DIALOGPT_MEDIUM,
+            )),
+            merges_resource: Resource::Remote(RemoteResource::from_pretrained(
+                Gpt2MergesResources::DIALOGPT_MEDIUM,
+            )),
+            min_length: 0,
+            max_length: 1000,
+            do_sample: true,
+            early_stopping: false,
+            num_beams: 1,
+            temperature: 1.0,
+            top_k: 50,
+            top_p: 0.9,
+            repetition_penalty: 1.0,
+            length_penalty: 1.0,
+            no_repeat_ngram_size: 3,
+            num_return_sequences: 1,
+            device: Device::cuda_if_available(),
+        }
+    }
+}
+
+/// # Conversation model
+pub struct ConversationModel {
+    model: GPT2Generator,
+}
+
+impl ConversationModel {
+    /// Build a new `ConversationModel`
+    ///
+    /// # Arguments
+    ///
+    /// * `conversation_config` - `ConversationConfig` object containing the resource references (model, vocabulary, configuration), conversation options and device placement (CPU/GPU)
+    ///
+    /// # Example
+    ///
+    /// ```no_run
+    /// # fn main() -> failure::Fallible<()> {
+    /// use rust_bert::pipelines::conversation::ConversationModel;
+    ///
+    /// let conversation_model = ConversationModel::new(Default::default())?;
+    /// # Ok(())
+    /// # }
+    /// ```
+    pub fn new(conversation_config: ConversationConfig) -> failure::Fallible<ConversationModel> {
+        let generate_config = GenerateConfig {
+            model_resource: conversation_config.model_resource,
+            config_resource: conversation_config.config_resource,
+            merges_resource: conversation_config.merges_resource,
+            vocab_resource: conversation_config.vocab_resource,
+            min_length: conversation_config.min_length,
+            max_length: conversation_config.max_length,
+            do_sample: conversation_config.do_sample,
+            early_stopping: conversation_config.early_stopping,
+            num_beams: conversation_config.num_beams,
+            temperature: conversation_config.temperature,
+            top_k: conversation_config.top_k,
+            top_p: conversation_config.top_p,
+            repetition_penalty: conversation_config.repetition_penalty,
+            length_penalty: conversation_config.length_penalty,
+            no_repeat_ngram_size: conversation_config.no_repeat_ngram_size,
+            num_return_sequences: conversation_config.num_return_sequences,
+            device: conversation_config.device,
+        };
+
+        let model = GPT2Generator::new(generate_config)?;
+
+        Ok(ConversationModel { model })
+    }
+
+    /// Perform a multi-turn conversation based on user input
+    ///
+    /// # Arguments
+    ///
+    /// * `input` - `&[&str]` Array of user input texts.
+    ///
+    /// # Returns
+    /// * `Vec<String>` Responses from the model for each input
+    ///
+    /// # Example
+    ///
+    /// ```no_run
+    /// # fn main() -> failure::Fallible<()> {
+    /// use rust_bert::pipelines::generation::LanguageGenerator;
+    /// use rust_bert::pipelines::conversation::ConversationModel;
+    /// let model = ConversationModel::new(Default::default())?;
+    ///
+    /// let input = ["Hello, how are you?"];
+    ///
+    /// let output = model.reply(&input);
+    /// # Ok(())
+    /// # }
+    /// ```
+    pub fn reply(&self, texts: &[&str]) -> Vec<String> {
+        // ToDo: add possibility to pass a History object as an input (or create a History) containing a Cache object
+        // ToDo: move encoding step to this method to handle the token addition
+        // ToDo: create a `generate` sub-function that takes input ids & an Option<Cache> as an input
+        // ToDo: update base `generate` function to perform some preparation steps and then delegate to the lower level `generate` taking input ids & cache as input
+        // ToDo: update return of function to return a Vec<String> and a History
+
+        self.model.generate(Some(texts.to_vec()), None)
+    }
+}
diff --git a/src/pipelines/mod.rs b/src/pipelines/mod.rs
index 1ce4121..d934a8a 100644
--- a/src/pipelines/mod.rs
+++ b/src/pipelines/mod.rs
@@ -230,7 +230,7 @@
 //!
 
 pub mod common;
-// pub mod conversation;
+pub mod conversation;
 pub mod generation;
 pub mod ner;
 pub mod question_answering;
diff --git a/utils/download-dependencies_dialogpt-medium.py b/utils/download-dependencies_dialogpt-medium.py
index dd07284..28dbeea 100644
--- a/utils/download-dependencies_dialogpt-medium.py
+++ b/utils/download-dependencies_dialogpt-medium.py
@@ -35,7 +35,7 @@ shutil.copy(temp_weights, model_path)
 weights = torch.load(temp_weights, map_location='cpu')
 nps = {}
 for k, v in weights.items():
-    nps['transformer.' + k] = np.ascontiguousarray(v.cpu().numpy()).astype(np.float32)
+    nps[k] = np.ascontiguousarray(v.cpu().numpy()).astype(np.float32)
     if k == 'wte.weight':
         nps['lm_head.weight'] = np.ascontiguousarray(v.cpu().numpy()).astype(np.float32)
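
For anyone trying the new pipeline, here is a minimal usage sketch (not part of the patch) showing a non-default `ConversationConfig`. The override values are illustrative only; every other field falls back to the DialoGPT-medium defaults defined above.

```rust
extern crate failure;

use rust_bert::pipelines::conversation::{ConversationConfig, ConversationModel};

fn main() -> failure::Fallible<()> {
    // Illustrative overrides: cap the response length and rely on nucleus
    // sampling alone; the remaining fields keep the DialoGPT-medium defaults.
    let config = ConversationConfig {
        max_length: 60,
        top_k: 0,
        top_p: 0.9,
        ..Default::default()
    };
    let model = ConversationModel::new(config)?;

    // `reply` returns one generated response per input prompt.
    let output = model.reply(&["Hello, how are you? <|endoftext|>"]);
    println!("{:?}", output);
    Ok(())
}
```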
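Until the History/Cache plumbing sketched in the ToDos lands, conversation context has to be threaded by hand. Below is a sketch of one way to do it, assuming DialoGPT's convention of separating dialogue turns with `<|endoftext|>` (as hinted at by the example input above); `build_prompt` is a hypothetical helper, not part of the crate, and `reply` simply receives the concatenated prompt as a single input.

```rust
use rust_bert::pipelines::conversation::ConversationModel;

// Hypothetical helper: concatenate past turns with the `<|endoftext|>`
// separator that DialoGPT uses to delimit dialogue turns.
fn build_prompt(history: &[&str], user_input: &str) -> String {
    let mut prompt = String::new();
    for turn in history {
        prompt.push_str(turn);
        prompt.push_str(" <|endoftext|> ");
    }
    prompt.push_str(user_input);
    prompt.push_str(" <|endoftext|>");
    prompt
}

fn main() -> failure::Fallible<()> {
    let model = ConversationModel::new(Default::default())?;

    // Earlier user and model turns form the history for the next turn.
    let history = ["Hello, how are you?", "I'm doing great. How are you?"];
    let prompt = build_prompt(&history, "Do you like science fiction?");
    let output = model.reply(&[prompt.as_str()]);
    println!("{:?}", output);
    Ok(())
}
```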