Initial single-turn conversation

This commit is contained in:
Guillaume B 2020-06-25 19:32:36 +02:00
parent 30528ca973
commit 9081bc3318
5 changed files with 215 additions and 203 deletions

Cargo.toml

@@ -1,6 +1,6 @@
[package]
name = "rust-bert"
version = "0.7.8"
version = "0.7.9"
authors = ["Guillaume Becquin <guillaume.becquin@gmail.com>"]
edition = "2018"
description = "Ready-to-use NLP pipelines and transformer-based models (BERT, DistilBERT, GPT2,...)"

examples/conversation.rs Normal file (27 lines added)

@@ -0,0 +1,27 @@
// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
// Copyright 2019 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
extern crate failure;
use rust_bert::pipelines::conversation::ConversationModel;
fn main() -> failure::Fallible<()> {
let conversation_model = ConversationModel::new(Default::default())?;
let input = ["Hello, how are you? <|endoftext|>"];
let output = conversation_model.reply(&input);
println!("{:?}", output);
Ok(())
}
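Assuming the pretrained weights download successfully, the new example can be run with `cargo run --example conversation`. Since `do_sample` defaults to `true` in `ConversationConfig`, the generated reply is non-deterministic; the printed output is a `Vec<String>` with one response per input.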

src/pipelines/conversation.rs

@@ -1,200 +1,185 @@
// // Copyright 2019-present Microsoft
// // Copyright 2020-present, the HuggingFace Inc. team.
// // Copyright 2020 Guillaume Becquin
// // Licensed under the Apache License, Version 2.0 (the "License");
// // you may not use this file except in compliance with the License.
// // You may obtain a copy of the License at
// // http://www.apache.org/licenses/LICENSE-2.0
// // Unless required by applicable law or agreed to in writing, software
// // distributed under the License is distributed on an "AS IS" BASIS,
// // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// // See the License for the specific language governing permissions and
// // limitations under the License.
//
// /// # Disclaimer
// /// This repository aims to facilitate research in large-scale pre-training for conversational data.
// /// This toolkit contains only part of the modeling machinery needed to actually produce a model
// /// weight file in a running dialog. On its own, this model provides only information about the
// /// weights of various text spans; in order for a researcher to actually use it, they will need
// /// to bring conversational data of their own and decode the response generation from the pretrained
// /// system. Neither the author of this repository or Microsoft are responsible for any generation
// /// from the 3rd party utilization of the pretrained system.
// ///
// ///
// ///
// ///
// use crate::bart::{
// BartConfigResources, BartMergesResources, BartModelResources, BartVocabResources,
// };
// use crate::common::resources::{RemoteResource, Resource};
// use crate::pipelines::generation::{BartGenerator, GenerateConfig, LanguageGenerator};
// use tch::Device;
//
// /// # Configuration for multi-turn classification
// /// Contains information regarding the model to load, mirrors the GenerationConfig, with a
// /// different set of default parameters and sets the device to place the model on.
// pub struct ConversationConfig {
// /// Model weights resource (default: DialoGPT-medium)
// pub model_resource: Resource,
// /// Config resource (default: DialoGPT-medium)
// pub config_resource: Resource,
// /// Vocab resource (default: DialoGPT-medium)
// pub vocab_resource: Resource,
// /// Merges resource (default: DialoGPT-medium)
// pub merges_resource: Resource,
// /// Minimum sequence length (default: 0)
// pub min_length: u64,
// /// Maximum sequence length (default: 20)
// pub max_length: u64,
// /// Sampling flag. If true, will perform top-k and/or nucleus sampling on generated tokens, otherwise greedy (deterministic) decoding (default: true)
// pub do_sample: bool,
// /// Early stopping flag indicating if the beam search should stop as soon as `num_beam` hypotheses have been generated (default: false)
// pub early_stopping: bool,
// /// Number of beams for beam search (default: 5)
// pub num_beams: u64,
// /// Temperature setting. Values higher than 1 will improve originality at the risk of reducing relevance (default: 1.0)
// pub temperature: f64,
// /// Top_k values for sampling tokens. Value higher than 0 will enable the feature (default: 0)
// pub top_k: u64,
// /// Top_p value for [Nucleus sampling, Holtzman et al.](http://arxiv.org/abs/1904.09751). Keep top tokens until cumulative probability reaches top_p (default: 0.9)
// pub top_p: f64,
// /// Repetition penalty (mostly useful for CTRL decoders). Values higher than 1 will penalize tokens that have been already generated. (default: 1.0)
// pub repetition_penalty: f64,
// /// Exponential penalty based on the length of the hypotheses generated (default: 1.0)
// pub length_penalty: f64,
// /// Number of allowed repetitions of n-grams. Values higher than 0 turn on this feature (default: 3)
// pub no_repeat_ngram_size: u64,
// /// Number of sequences to return for each prompt text (default: 1)
// pub num_return_sequences: u64,
// /// Device to place the model on (default: CUDA/GPU when available)
// pub device: Device,
// }
//
// impl Default for ConversationConfig {
// fn default() -> ConversationConfig {
// ConversationConfig {
// model_resource: Resource::Remote(RemoteResource::from_pretrained(
// BartModelResources::BART_CNN,
// )),
// config_resource: Resource::Remote(RemoteResource::from_pretrained(
// BartConfigResources::BART_CNN,
// )),
// vocab_resource: Resource::Remote(RemoteResource::from_pretrained(
// BartVocabResources::BART_CNN,
// )),
// merges_resource: Resource::Remote(RemoteResource::from_pretrained(
// BartMergesResources::BART_CNN,
// )),
// min_length: 56,
// max_length: 142,
// do_sample: false,
// early_stopping: false,
// num_beams: 3,
// temperature: 1.0,
// top_k: 50,
// top_p: 1.0,
// repetition_penalty: 1.0,
// length_penalty: 1.0,
// no_repeat_ngram_size: 3,
// num_return_sequences: 1,
// device: Device::cuda_if_available(),
// }
// }
// }
//
// /// # SummarizationModel to perform summarization
// pub struct SummarizationModel {
// model: BartGenerator,
// }
//
// impl SummarizationModel {
// /// Build a new `SummarizationModel`
// ///
// /// # Arguments
// ///
// /// * `summarization_config` - `SummarizationConfig` object containing the resource references (model, vocabulary, configuration), summarization options and device placement (CPU/GPU)
// ///
// /// # Example
// ///
// /// ```no_run
// /// # fn main() -> failure::Fallible<()> {
// /// use rust_bert::pipelines::summarization::SummarizationModel;
// ///
// /// let mut summarization_model = SummarizationModel::new(Default::default())?;
// /// # Ok(())
// /// # }
// /// ```
// pub fn new(summarization_config: SummarizationConfig) -> failure::Fallible<SummarizationModel> {
// let generate_config = GenerateConfig {
// model_resource: summarization_config.model_resource,
// config_resource: summarization_config.config_resource,
// merges_resource: summarization_config.merges_resource,
// vocab_resource: summarization_config.vocab_resource,
// min_length: summarization_config.min_length,
// max_length: summarization_config.max_length,
// do_sample: summarization_config.do_sample,
// early_stopping: summarization_config.early_stopping,
// num_beams: summarization_config.num_beams,
// temperature: summarization_config.temperature,
// top_k: summarization_config.top_k,
// top_p: summarization_config.top_p,
// repetition_penalty: summarization_config.repetition_penalty,
// length_penalty: summarization_config.length_penalty,
// no_repeat_ngram_size: summarization_config.no_repeat_ngram_size,
// num_return_sequences: summarization_config.num_return_sequences,
// device: summarization_config.device,
// };
//
// let model = BartGenerator::new(generate_config)?;
//
// Ok(SummarizationModel { model })
// }
//
// /// Summarize texts provided
// ///
// /// # Arguments
// ///
// /// * `input` - `&[&str]` Array of texts to summarize.
// ///
// /// # Returns
// /// * `Vec<String>` Summarized texts
// ///
// /// # Example
// ///
// /// ```no_run
// /// # fn main() -> failure::Fallible<()> {
// /// use rust_bert::pipelines::generation::LanguageGenerator;
// /// use rust_bert::pipelines::summarization::SummarizationModel;
// /// let model = SummarizationModel::new(Default::default())?;
// ///
// /// let input = ["In findings published Tuesday in Cornell University's arXiv by a team of scientists
// /// from the University of Montreal and a separate report published Wednesday in Nature Astronomy by a team
// /// from University College London (UCL), the presence of water vapour was confirmed in the atmosphere of K2-18b,
// /// a planet circling a star in the constellation Leo. This is the first such discovery in a planet in its star's
// /// habitable zone — not too hot and not too cold for liquid water to exist. The Montreal team, led by Björn Benneke,
// /// used data from the NASA's Hubble telescope to assess changes in the light coming from K2-18b's star as the planet
// /// passed between it and Earth. They found that certain wavelengths of light, which are usually absorbed by water,
// /// weakened when the planet was in the way, indicating not only does K2-18b have an atmosphere, but the atmosphere
// /// contains water in vapour form. The team from UCL then analyzed the Montreal team's data using their own software
// /// and confirmed their conclusion. This was not the first time scientists have found signs of water on an exoplanet,
// /// but previous discoveries were made on planets with high temperatures or other pronounced differences from Earth.
// /// \"This is the first potentially habitable planet where the temperature is right and where we now know there is water,\"
// /// said UCL astronomer Angelos Tsiaras. \"It's the best candidate for habitability right now.\" \"It's a good sign\",
// /// said Ryan Cloutier of the HarvardSmithsonian Center for Astrophysics, who was not one of either study's authors.
// /// \"Overall,\" he continued, \"the presence of water in its atmosphere certainly improves the prospect of K2-18b being
// /// a potentially habitable planet, but further observations will be required to say for sure. \"
// /// K2-18b was first identified in 2015 by the Kepler space telescope. It is about 110 light-years from Earth and larger
// /// but less dense. Its star, a red dwarf, is cooler than the Sun, but the planet's orbit is much closer, such that a year
// /// on K2-18b lasts 33 Earth days. According to The Guardian, astronomers were optimistic that NASA's James Webb space
// /// telescope — scheduled for launch in 2021 — and the European Space Agency's 2028 ARIEL program, could reveal more
// /// about exoplanets like K2-18b."];
// ///
// /// let output = model.summarize(&input);
// /// # Ok(())
// /// # }
// /// ```
// /// (New sample credits: [WikiNews](https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b))
// pub fn summarize(&self, texts: &[&str]) -> Vec<String> {
// self.model.generate(Some(texts.to_vec()), None)
// }
// }
// Copyright 2019-present Microsoft
// Copyright 2020-present, the HuggingFace Inc. team.
// Copyright 2020 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/// # Disclaimer
/// This repository aims to facilitate research in large-scale pre-training for conversational data.
/// This toolkit contains only part of the modeling machinery needed to actually produce a model
/// weight file in a running dialog. On its own, this model provides only information about the
/// weights of various text spans; in order for a researcher to actually use it, they will need
/// to bring conversational data of their own and decode the response generation from the pretrained
/// system. Neither the author of this repository nor Microsoft is responsible for any generation
/// from third-party utilization of the pretrained system.
///
use crate::common::resources::{RemoteResource, Resource};
use crate::gpt2::{
Gpt2ConfigResources, Gpt2MergesResources, Gpt2ModelResources, Gpt2VocabResources,
};
use crate::pipelines::generation::{GPT2Generator, GenerateConfig, LanguageGenerator};
use tch::Device;
/// # Configuration for multi-turn conversation
/// Contains information regarding the model to load. Mirrors `GenerateConfig` with a
/// different set of default parameters, and sets the device to place the model on.
pub struct ConversationConfig {
/// Model weights resource (default: DialoGPT-medium)
pub model_resource: Resource,
/// Config resource (default: DialoGPT-medium)
pub config_resource: Resource,
/// Vocab resource (default: DialoGPT-medium)
pub vocab_resource: Resource,
/// Merges resource (default: DialoGPT-medium)
pub merges_resource: Resource,
/// Minimum sequence length (default: 0)
pub min_length: u64,
/// Maximum sequence length (default: 1000)
pub max_length: u64,
/// Sampling flag. If true, will perform top-k and/or nucleus sampling on generated tokens, otherwise greedy (deterministic) decoding (default: true)
pub do_sample: bool,
/// Early stopping flag indicating if the beam search should stop as soon as `num_beams` hypotheses have been generated (default: false)
pub early_stopping: bool,
/// Number of beams for beam search (default: 1)
pub num_beams: u64,
/// Temperature setting. Values higher than 1 will improve originality at the risk of reducing relevance (default: 1.0)
pub temperature: f64,
/// Top_k value for sampling tokens. Values higher than 0 enable the feature (default: 50)
pub top_k: u64,
/// Top_p value for [Nucleus sampling, Holtzman et al.](http://arxiv.org/abs/1904.09751). Keep top tokens until cumulative probability reaches top_p (default: 0.9)
pub top_p: f64,
/// Repetition penalty (mostly useful for CTRL decoders). Values higher than 1 will penalize tokens that have been already generated. (default: 1.0)
pub repetition_penalty: f64,
/// Exponential penalty based on the length of the hypotheses generated (default: 1.0)
pub length_penalty: f64,
/// Number of allowed repetitions of n-grams. Values higher than 0 turn on this feature (default: 3)
pub no_repeat_ngram_size: u64,
/// Number of sequences to return for each prompt text (default: 1)
pub num_return_sequences: u64,
/// Device to place the model on (default: CUDA/GPU when available)
pub device: Device,
}
impl Default for ConversationConfig {
fn default() -> ConversationConfig {
ConversationConfig {
model_resource: Resource::Remote(RemoteResource::from_pretrained(
Gpt2ModelResources::DIALOGPT_MEDIUM,
)),
config_resource: Resource::Remote(RemoteResource::from_pretrained(
Gpt2ConfigResources::DIALOGPT_MEDIUM,
)),
vocab_resource: Resource::Remote(RemoteResource::from_pretrained(
Gpt2VocabResources::DIALOGPT_MEDIUM,
)),
merges_resource: Resource::Remote(RemoteResource::from_pretrained(
Gpt2MergesResources::DIALOGPT_MEDIUM,
)),
min_length: 0,
max_length: 1000,
do_sample: true,
early_stopping: false,
num_beams: 1,
temperature: 1.0,
top_k: 50,
top_p: 0.9,
repetition_penalty: 1.0,
length_penalty: 1.0,
no_repeat_ngram_size: 3,
num_return_sequences: 1,
device: Device::cuda_if_available(),
}
}
}
/// # Conversation model
pub struct ConversationModel {
model: GPT2Generator,
}
impl ConversationModel {
/// Build a new `ConversationModel`
///
/// # Arguments
///
/// * `conversation_config` - `ConversationConfig` object containing the resource references (model, vocabulary, configuration), conversation options and device placement (CPU/GPU)
///
/// # Example
///
/// ```no_run
/// # fn main() -> failure::Fallible<()> {
/// use rust_bert::pipelines::conversation::ConversationModel;
///
/// let conversation_model = ConversationModel::new(Default::default())?;
/// # Ok(())
/// # }
/// ```
pub fn new(conversation_config: ConversationConfig) -> failure::Fallible<ConversationModel> {
let generate_config = GenerateConfig {
model_resource: conversation_config.model_resource,
config_resource: conversation_config.config_resource,
merges_resource: conversation_config.merges_resource,
vocab_resource: conversation_config.vocab_resource,
min_length: conversation_config.min_length,
max_length: conversation_config.max_length,
do_sample: conversation_config.do_sample,
early_stopping: conversation_config.early_stopping,
num_beams: conversation_config.num_beams,
temperature: conversation_config.temperature,
top_k: conversation_config.top_k,
top_p: conversation_config.top_p,
repetition_penalty: conversation_config.repetition_penalty,
length_penalty: conversation_config.length_penalty,
no_repeat_ngram_size: conversation_config.no_repeat_ngram_size,
num_return_sequences: conversation_config.num_return_sequences,
device: conversation_config.device,
};
let model = GPT2Generator::new(generate_config)?;
Ok(ConversationModel { model })
}
/// Perform a multi-turn conversation based on user input
///
/// # Arguments
///
/// * `input` - `&[&str]` Array of user input texts.
///
/// # Returns
/// * `Vec<String>` Responses from the model for each input
///
/// # Example
///
/// ```no_run
/// # fn main() -> failure::Fallible<()> {
/// use rust_bert::pipelines::generation::LanguageGenerator;
/// use rust_bert::pipelines::conversation::ConversationModel;
/// let model = ConversationModel::new(Default::default())?;
///
/// let input = ["Hello, how are you?"];
///
/// let output = model.reply(&input);
/// # Ok(())
/// # }
/// ```
pub fn reply(&self, texts: &[&str]) -> Vec<String> {
// ToDo: add possibility to pass a History object as an input (or create a History) containing a Cache object
// ToDo: move encoding step to this method to handle the <eos> token addition
// ToDo: create a `generate` sub-function that takes input ids & an Option<Cache> as an input
// ToDo: update base `generate` function to perform some preparation steps and then delegate to the lower level `generate` taking input ids & cache as input
// ToDo: update return of function to return a Vec<String> and a History
self.model.generate(Some(texts.to_vec()), None)
}
}
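The ToDo list above outlines the planned multi-turn design: a History object carrying past turns plus a Cache, with encoding moved into `reply`. DialoGPT is trained on dialogue turns joined by its `<|endoftext|>` separator, so the encoding step will presumably flatten the history along these lines (a hypothetical sketch, not part of this commit):

// Hypothetical helper illustrating the planned History flattening:
// past turns are joined with the <|endoftext|> separator DialoGPT expects.
fn build_context(history: &[&str], new_input: &str) -> String {
    let mut context = String::new();
    for turn in history {
        context.push_str(turn);
        context.push_str(" <|endoftext|> ");
    }
    context.push_str(new_input);
    context.push_str(" <|endoftext|>");
    context
}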

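Separately, because `ConversationConfig` exposes public fields and implements `Default`, individual generation parameters can be overridden with struct update syntax. A minimal sketch (not part of the diff; the overridden values are illustrative):

use rust_bert::pipelines::conversation::{ConversationConfig, ConversationModel};
use tch::Device;

fn main() -> failure::Fallible<()> {
    // Override a few fields; the rest falls back to the DialoGPT-medium defaults.
    let config = ConversationConfig {
        max_length: 100,     // shorter outputs than the default 1000
        do_sample: false,    // greedy decoding (num_beams defaults to 1)
        device: Device::Cpu, // force CPU even when CUDA is available
        ..Default::default()
    };
    let model = ConversationModel::new(config)?;
    let output = model.reply(&["Hello, how are you? <|endoftext|>"]);
    println!("{:?}", output);
    Ok(())
}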
src/pipelines/mod.rs

@@ -230,7 +230,7 @@
//!
pub mod common;
// pub mod conversation;
pub mod conversation;
pub mod generation;
pub mod ner;
pub mod question_answering;

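With the module publicly exported, downstream code can now import the pipeline directly via `use rust_bert::pipelines::conversation::ConversationModel;`, as the new example above does.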

@@ -35,7 +35,7 @@ shutil.copy(temp_weights, model_path)
weights = torch.load(temp_weights, map_location='cpu')
nps = {}
for k, v in weights.items():
nps['transformer.' + k] = np.ascontiguousarray(v.cpu().numpy()).astype(np.float32)
nps[k] = np.ascontiguousarray(v.cpu().numpy()).astype(np.float32)
if k == 'wte.weight':
nps['lm_head.weight'] = np.ascontiguousarray(v.cpu().numpy()).astype(np.float32)
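The one-line change above stops prepending `transformer.` to the checkpoint keys, keeping the tensor names from the DialoGPT weight file as-is; the existing `wte.weight` copy still materializes the tied language-model head weight (`lm_head.weight` shares the input embedding matrix in GPT-2-style models).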