Merge pull request #57 from guillaume-be/multiturn_conversation

Multiturn conversation
This commit is contained in:
guillaume-be 2020-06-28 11:56:15 +02:00 committed by GitHub
commit d076ec6f77
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 1183 additions and 37 deletions

View File

@ -1,6 +1,6 @@
[package]
name = "rust-bert"
version = "0.7.8"
version = "0.7.9"
authors = ["Guillaume Becquin <guillaume.becquin@gmail.com>"]
edition = "2018"
description = "Ready-to-use NLP pipelines and transformer-based models (BERT, DistilBERT, GPT2,...)"
@ -41,4 +41,5 @@ ordered-float = "1.0.2"
csv = "1.1.3"
reqwest = "0.10.4"
lazy_static = "1.4.0"
uuid = { version = "0.8.1", features = ["v4"] }
tokio = { version = "0.2.21", features = ["full"] }

View File

@ -25,6 +25,11 @@ Translation | | | | | |✅| |✅ | |
## Ready-to-use pipelines
Based on Huggingface's pipelines, ready-to-use end-to-end NLP pipelines are available as part of this crate. The following capabilities are currently available:
**Disclaimer**
The contributors of this repository are not responsible for any output generated by third-party use of the pretrained systems provided herein.
#### 1. Question Answering
Extractive question answering from a given question and context. DistilBERT model finetuned on SQuAD (Stanford Question Answering Dataset)
@ -106,7 +111,30 @@ This is the first such discovery in a planet in its star's habitable zone.
The planet is not too hot and not too cold for liquid water to exist."
```
#### 4. Natural Language Generation
#### 4. Dialogue Model
Conversation model based on Microsoft's [DialoGPT](https://github.com/microsoft/DialoGPT).
This pipeline allows the generation of single or multi-turn conversations between a human and a model.
DialoGPT's project page states that
> The human evaluation results indicate that the response generated from DialoGPT is comparable to human response quality
> under a single-turn conversation Turing test. ([DialoGPT repository](https://github.com/microsoft/DialoGPT))
The model uses a `ConversationManager` to keep track of active conversations and generate responses to them.
```rust
use rust_bert::pipelines::conversation::{ConversationModel, ConversationManager};
let conversation_model = ConversationModel::new(Default::default())?;
let mut conversation_manager = ConversationManager::new();
let conversation_id = conversation_manager.create("Going to the movies tonight - any suggestions?");
let output = conversation_model.generate_responses(&mut conversation_manager);
```
Example output:
```
"The Big Lebowski."
```
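A conversation can be continued by adding a follow-up user input to the registered conversation before generating again. The snippet below is a sketch following the pattern of `examples/conversation.rs`; the exact replies depend on the sampling settings:
```rust
// Add a follow-up question to the conversation created above
let _ = conversation_manager
    .get(&conversation_id)
    .unwrap()
    .add_user_input("Is it an action movie?");

// Generate the next response for all conversations with new user input
let output = conversation_model.generate_responses(&mut conversation_manager);
```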
#### 5. Natural Language Generation
Generate language based on a prompt. GPT2 and GPT are available as base models.
Includes techniques such as beam search, top-k and nucleus sampling, temperature setting and repetition penalty.
Supports batch generation of sentences from several prompts. Sequences will be left-padded with the model's padding token if present, the unknown token otherwise.
@ -133,7 +161,7 @@ Example output:
]
```
#### 5. Sentiment analysis
#### 6. Sentiment analysis
Predicts the binary sentiment for a sentence. DistilBERT model finetuned on SST-2.
```rust
let sentiment_classifier = SentimentModel::new(Default::default())?;
@ -157,7 +185,7 @@ Output:
]
```
#### 6. Named Entity Recognition
#### 7. Named Entity Recognition
Extracts entities (Person, Location, Organization, Miscellaneous) from text. BERT cased large model finetuned on CoNLL03, contributed by the [MDZ Digital Library team at the Bavarian State Library](https://github.com/dbmdz)
```rust
let ner_model = NERModel::new(Default::default())?;

43
examples/conversation.rs Normal file
View File

@ -0,0 +1,43 @@
// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
// Copyright 2019 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
extern crate failure;
use rust_bert::pipelines::conversation::{ConversationManager, ConversationModel};
fn main() -> failure::Fallible<()> {
let conversation_model = ConversationModel::new(Default::default())?;
let mut conversation_manager = ConversationManager::new();
let conversation_1_id =
conversation_manager.create("Going to the movies tonight - any suggestions?");
let _conversation_2_id = conversation_manager.create("What's the last book you have read?");
let output = conversation_model.generate_responses(&mut conversation_manager);
println!("{:?}", output);
let _ = conversation_manager
.get(&conversation_1_id)
.unwrap()
.add_user_input("Is it an action movie?");
let output = conversation_model.generate_responses(&mut conversation_manager);
println!("{:?}", output);
let output = conversation_model.generate_responses(&mut conversation_manager);
println!("{:?}", output);
Ok(())
}

View File

@ -276,6 +276,27 @@ fn download_albert_base_v2() -> failure::Fallible<()> {
Ok(())
}
fn _download_dialogpt() -> failure::Fallible<()> {
// Shared under MIT license by the Microsoft team at https://huggingface.co/microsoft/DialoGPT-medium. Modified with conversion to C-array format.
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
Gpt2ConfigResources::DIALOGPT_MEDIUM,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
Gpt2VocabResources::DIALOGPT_MEDIUM,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
Gpt2MergesResources::DIALOGPT_MEDIUM,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
Gpt2ModelResources::DIALOGPT_MEDIUM,
));
let _ = download_resource(&config_resource)?;
let _ = download_resource(&vocab_resource)?;
let _ = download_resource(&merges_resource)?;
let _ = download_resource(&weights_resource)?;
Ok(())
}
fn main() -> failure::Fallible<()> {
let _ = download_distil_gpt2();
let _ = download_distilbert_sst2();

View File

@ -61,6 +61,11 @@ impl Gpt2ModelResources {
"distilgpt2/model.ot",
"https://cdn.huggingface.co/distilgpt2-rust_model.ot",
);
/// Shared under MIT license by the Microsoft team at https://huggingface.co/microsoft/DialoGPT-medium. Modified with conversion to C-array format.
pub const DIALOGPT_MEDIUM: (&'static str, &'static str) = (
"dialogpt-medium/model.ot",
"https://cdn.huggingface.co/microsoft/DialoGPT-medium/rust_model.ot",
);
}
impl Gpt2ConfigResources {
@ -89,6 +94,11 @@ impl Gpt2ConfigResources {
"distilgpt2/config.json",
"https://cdn.huggingface.co/distilgpt2-config.json",
);
/// Shared under MIT license by the Microsoft team at https://huggingface.co/microsoft/DialoGPT-medium. Modified with conversion to C-array format.
pub const DIALOGPT_MEDIUM: (&'static str, &'static str) = (
"dialogpt-medium/config.json",
"https://cdn.huggingface.co/microsoft/DialoGPT-medium/config.json",
);
}
impl Gpt2VocabResources {
@ -117,6 +127,11 @@ impl Gpt2VocabResources {
"distilgpt2/vocab.txt",
"https://cdn.huggingface.co/distilgpt2-vocab.json",
);
/// Shared under MIT license by the Microsoft team at https://huggingface.co/microsoft/DialoGPT-medium. Modified with conversion to C-array format.
pub const DIALOGPT_MEDIUM: (&'static str, &'static str) = (
"dialogpt-medium/vocab.txt",
"https://cdn.huggingface.co/microsoft/DialoGPT-medium/vocab.json",
);
}
impl Gpt2MergesResources {
@ -145,6 +160,11 @@ impl Gpt2MergesResources {
"distilgpt2/merges.txt",
"https://cdn.huggingface.co/distilgpt2-merges.txt",
);
/// Shared under MIT license by the Microsoft team at https://huggingface.co/microsoft/DialoGPT-medium. Modified with conversion to C-array format.
pub const DIALOGPT_MEDIUM: (&'static str, &'static str) = (
"dialogpt-medium/merges.txt",
"https://cdn.huggingface.co/microsoft/DialoGPT-medium/merges.txt",
);
}
#[allow(non_camel_case_types)]

View File

@ -9,6 +9,7 @@
//! - Ready-to-use NLP pipelines for:
//! - Translation
//! - Summarization
//! - Multi-turn dialogue
//! - Sentiment Analysis
//! - Named Entity Recognition
//! - Question-Answering

View File

@ -0,0 +1,797 @@
// Copyright 2019-present Microsoft
// Copyright 2020-present, the HuggingFace Inc. team.
// Copyright 2020 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! # Multi-turn dialogue
//! Conversation model based on Microsoft's [DialoGPT](https://github.com/microsoft/DialoGPT).
//! This pipeline allows the generation of single or multi-turn conversations between a human and a model.
//! DialoGPT's project page states that
//! > The human evaluation results indicate that the response generated from DialoGPT is comparable to human response quality
//! > under a single-turn conversation Turing test. ([DialoGPT repository](https://github.com/microsoft/DialoGPT))
//!
//!
//! The dependencies will be downloaded to the user's home directory, under ~/.cache/.rustbert/dialogpt-medium
//!
//! ```no_run
//! # fn main() -> failure::Fallible<()> {
//! use rust_bert::pipelines::conversation::{ConversationManager, ConversationModel};
//! let conversation_model = ConversationModel::new(Default::default())?;
//! let mut conversation_manager = ConversationManager::new();
//!
//! let conversation_id =
//! conversation_manager.create("Going to the movies tonight - any suggestions?");
//! let output = conversation_model.generate_responses(&mut conversation_manager);
//! # Ok(())
//! # }
//! ```
//!
//! Example output: \
//! ```no_run
//! # let output =
//! "The Big Lebowski."
//! # ;
//! ```
//!
//! # Disclaimer
//! The authors of this repository are not responsible for any output generated
//! by third-party use of the pretrained system.
//!
use crate::common::resources::{RemoteResource, Resource};
use crate::gpt2::{
Gpt2ConfigResources, Gpt2MergesResources, Gpt2ModelResources, Gpt2VocabResources,
};
use crate::pipelines::generation::private_generation_utils::PrivateLanguageGenerator;
use crate::pipelines::generation::{GPT2Generator, GenerateConfig, LanguageGenerator};
use itertools::Itertools;
use rust_tokenizers::preprocessing::tokenizer::tokenization_utils::truncate_sequences;
use rust_tokenizers::{Tokenizer, TruncationStrategy};
use std::collections::HashMap;
use tch::{Device, Tensor};
use uuid::Uuid;
/// # Configuration for multi-turn conversation
/// Contains information regarding the model to load. It mirrors `GenerateConfig` with a
/// different set of default parameters and sets the device to place the model on.
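///
/// # Example
///
/// A minimal sketch overriding selected fields via struct update syntax (mirroring the
/// configuration used in the crate's integration tests); all other fields keep their defaults:
///
/// ```no_run
/// use rust_bert::pipelines::conversation::ConversationConfig;
/// use tch::Device;
///
/// let config = ConversationConfig {
///     do_sample: false,
///     device: Device::Cpu,
///     ..Default::default()
/// };
/// ```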
pub struct ConversationConfig {
/// Model weights resource (default: DialoGPT-medium)
pub model_resource: Resource,
/// Config resource (default: DialoGPT-medium)
pub config_resource: Resource,
/// Vocab resource (default: DialoGPT-medium)
pub vocab_resource: Resource,
/// Merges resource (default: DialoGPT-medium)
pub merges_resource: Resource,
/// Minimum sequence length (default: 0)
pub min_length: u64,
/// Maximum sequence length (default: 1000)
pub max_length: u64,
/// Sampling flag. If true, will perform top-k and/or nucleus sampling on generated tokens, otherwise greedy (deterministic) decoding (default: true)
pub do_sample: bool,
/// Early stopping flag indicating if the beam search should stop as soon as `num_beams` hypotheses have been generated (default: false)
pub early_stopping: bool,
/// Number of beams for beam search (default: 1)
pub num_beams: u64,
/// Temperature setting. Values higher than 1 will improve originality at the risk of reducing relevance (default: 1.0)
pub temperature: f64,
/// Top_k value for sampling tokens. Values higher than 0 will enable the feature (default: 50)
pub top_k: u64,
/// Top_p value for [Nucleus sampling, Holtzman et al.](http://arxiv.org/abs/1904.09751). Keep top tokens until cumulative probability reaches top_p (default: 0.9)
pub top_p: f64,
/// Repetition penalty (mostly useful for CTRL decoders). Values higher than 1 will penalize tokens that have already been generated. (default: 1.0)
pub repetition_penalty: f64,
/// Exponential penalty based on the length of the hypotheses generated (default: 1.0)
pub length_penalty: f64,
/// Number of allowed repetitions of n-grams. Values higher than 0 turn on this feature (default: 0)
pub no_repeat_ngram_size: u64,
/// Number of sequences to return for each prompt text (default: 1)
pub num_return_sequences: u64,
/// Device to place the model on (default: CUDA/GPU when available)
pub device: Device,
}
impl Default for ConversationConfig {
fn default() -> ConversationConfig {
ConversationConfig {
model_resource: Resource::Remote(RemoteResource::from_pretrained(
Gpt2ModelResources::DIALOGPT_MEDIUM,
)),
config_resource: Resource::Remote(RemoteResource::from_pretrained(
Gpt2ConfigResources::DIALOGPT_MEDIUM,
)),
vocab_resource: Resource::Remote(RemoteResource::from_pretrained(
Gpt2VocabResources::DIALOGPT_MEDIUM,
)),
merges_resource: Resource::Remote(RemoteResource::from_pretrained(
Gpt2MergesResources::DIALOGPT_MEDIUM,
)),
min_length: 0,
max_length: 1000,
do_sample: true,
early_stopping: false,
num_beams: 1,
temperature: 1.0,
top_k: 50,
top_p: 0.9,
repetition_penalty: 1.0,
length_penalty: 1.0,
no_repeat_ngram_size: 0,
num_return_sequences: 1,
device: Device::cuda_if_available(),
}
}
}
#[derive(Debug, Clone)]
/// Data structure keeping track of a conversation in the system. It contains past user inputs and
/// generated answers, a history of the tokens generated, and a placeholder for new user input to be
/// processed by the system when submitted for prediction
pub struct Conversation {
/// Past user inputs that have already been processed
pub past_user_inputs: Vec<String>,
/// Past system generated responses
pub generated_responses: Vec<String>,
/// New user input that needs to be processed
pub new_user_input: Option<String>,
/// History of the tokens passed as input and generated so far, used as context for the next generation turn
pub history: Vec<i64>,
}
impl Conversation {
/// Build a new `Conversation` with an initial user input
///
/// # Arguments
///
/// * `text` - `String` with the initial user input to start a conversation
///
/// # Example
///
/// ```no_run
/// use rust_bert::pipelines::conversation::Conversation;
///
/// let conversation = Conversation::new("Hi there!");
/// ```
pub fn new(text: &str) -> Conversation {
Conversation {
past_user_inputs: vec![],
generated_responses: vec![],
new_user_input: Some(text.to_string()),
history: vec![],
}
}
/// Build a new `Conversation` placeholder without user input
///
/// # Example
///
/// ```no_run
/// use rust_bert::pipelines::conversation::Conversation;
///
/// let conversation = Conversation::new_empty();
/// ```
pub fn new_empty() -> Conversation {
Conversation {
past_user_inputs: vec![],
generated_responses: vec![],
new_user_input: None,
history: vec![],
}
}
/// Adds a new user input to the conversation. This method returns an error if an unprocessed
/// user input already exists
///
/// # Arguments
///
/// * `text` - `&str` with the additional user input to continue a conversation
///
/// # Example
///
/// ```no_run
/// use rust_bert::pipelines::conversation::Conversation;
///
/// let mut conversation = Conversation::new_empty();
/// conversation.add_user_input("Hi there!");
/// ```
pub fn add_user_input(&mut self, text: &str) -> Result<(), &'static str> {
if self.new_user_input.is_some() {
Err("User input already provided for this conversation")
} else {
self.new_user_input = Some(text.to_string());
Ok(())
}
}
/// Adds a new user input to the conversation. If an unprocessed user input already exists,
/// its contents are overwritten by the new value provided.
///
/// # Arguments
///
/// * `text` - `&str` with the additional user input to continue a conversation
///
/// # Returns
///
/// * `Option<String>` containing the overwritten string, if applicable
///
/// # Example
///
/// ```no_run
/// use rust_bert::pipelines::conversation::Conversation;
///
/// let mut conversation = Conversation::new_empty();
/// conversation.add_user_input("This input will not be used");
/// let unused_string = conversation.add_user_input_with_overwrite("Hi there!");
/// ```
pub fn add_user_input_with_overwrite(&mut self, text: &str) -> Option<String> {
self.new_user_input.replace(text.to_string())
}
/// Returns `true` if the conversation contains new user inputs to process
///
/// # Returns
///
/// * `bool` flag indicating if the conversation contains new inputs to process
///
/// # Example
///
/// ```no_run
/// use rust_bert::pipelines::conversation::Conversation;
///
/// let mut conversation = Conversation::new_empty();
/// let false_value = conversation.contains_new_input();
/// conversation.add_user_input("This input will not be used");
/// let true_value = conversation.contains_new_input();
/// ```
pub fn contains_new_input(&self) -> bool {
self.new_user_input.is_some()
}
/// Marks the conversation as processed and moves the user input that was up for
/// processing to the past user inputs.
///
/// # Example
///
/// ```no_run
/// use rust_bert::pipelines::conversation::Conversation;
///
/// let mut conversation = Conversation::new_empty();
/// let false_value = conversation.contains_new_input();
/// conversation.add_user_input("This input will not be used");
/// let true_value = conversation.contains_new_input();
/// conversation.mark_processed();
/// let false_value = conversation.contains_new_input();
/// assert_eq!(conversation.past_user_inputs.len(), 1usize);
/// ```
pub fn mark_processed(&mut self) {
if let Some(input) = self.new_user_input.take() {
self.past_user_inputs.push(input);
}
}
/// Returns the last user input provided (including non-processed inputs).
///
/// # Returns
///
/// * `Option<&str>` representation of the last user input provided
///
/// # Example
///
/// ```no_run
/// use rust_bert::pipelines::conversation::Conversation;
///
/// let mut conversation = Conversation::new_empty();
/// let none_value = conversation.get_last_input();
/// conversation.add_user_input("This input will not be used");
/// let last_provided_input = conversation.get_last_input();
/// assert_eq!(last_provided_input, Some("This input will not be used"));
/// ```
pub fn get_last_input(&self) -> Option<&str> {
if self.new_user_input.is_some() {
Some(self.new_user_input.as_ref().unwrap().as_str())
} else if !self.past_user_inputs.is_empty() {
Some(self.past_user_inputs.last().unwrap().as_str())
} else {
None
}
}
/// Returns the last response generated by the system.
///
/// # Returns
///
/// * `Option<&str>` representation of the last response generated by the system.
///
/// # Example
///
/// ```no_run
/// use rust_bert::pipelines::conversation::Conversation;
///
/// let mut conversation = Conversation::new("Hi There");
/// let none_value = conversation.get_last_response();
/// ```
pub fn get_last_response(&self) -> Option<&str> {
if !self.generated_responses.is_empty() {
Some(self.generated_responses.last().unwrap().as_str())
} else {
None
}
}
}
/// Data structure allowing the management of conversations and main input to the dialogue model.
/// It contains a `HashMap` of conversations with `UUID` keys
#[derive(Debug)]
pub struct ConversationManager {
conversations: HashMap<Uuid, Conversation>,
}
impl ConversationManager {
/// Build a new `ConversationManager`
///
/// # Example
///
/// ```no_run
/// use rust_bert::pipelines::conversation::ConversationManager;
///
/// let conversation_manager = ConversationManager::new();
/// ```
pub fn new() -> ConversationManager {
ConversationManager {
conversations: HashMap::new(),
}
}
/// Returns a list of the active conversations (containing new inputs to be processed by the model)
///
/// # Returns
///
/// * `(Vec<&Uuid>, Vec<&mut Conversation>)` Tuple of vectors with the active `UUID` and `Conversations`
///
/// # Example
///
/// ```no_run
/// use rust_bert::pipelines::conversation::{Conversation, ConversationManager};
///
/// let mut conversation_manager = ConversationManager::new();
///
/// let conversation = Conversation::new("Hi there!");
/// let empty_conversation = Conversation::new_empty();
/// let conversation_id = conversation_manager.add(conversation);
/// let empty_conversation_id = conversation_manager.add(empty_conversation);
///
/// let active_conversations = conversation_manager.get_active_conversations();
/// assert_eq!(active_conversations.0.len(), 1usize);
/// ```
pub fn get_active_conversations(&mut self) -> (Vec<&Uuid>, Vec<&mut Conversation>) {
let mut active_uuid = vec![];
let mut active_conversations = vec![];
for (uuid, conversation) in self.conversations.iter_mut() {
if conversation.new_user_input.is_some() {
active_uuid.push(uuid);
active_conversations.push(conversation)
}
}
(active_uuid, active_conversations)
}
/// Returns a mutable reference to the conversation with the provided UUID
///
/// # Arguments
///
/// * `uuid` - `&Uuid` of the conversation to retrieve
///
/// # Returns
///
/// * `Option<&mut Conversation>` Optional mutable reference to the conversation matching the UUID provided
///
/// # Example
///
/// ```no_run
/// use rust_bert::pipelines::conversation::{Conversation, ConversationManager};
///
/// let mut conversation_manager = ConversationManager::new();
///
/// let conversation = Conversation::new("Hi there!");
/// let conversation_id = conversation_manager.add(conversation);
///
/// let conversation_ref = conversation_manager.get(&conversation_id);
/// ```
pub fn get(&mut self, uuid: &Uuid) -> Option<&mut Conversation> {
self.conversations.get_mut(uuid)
}
/// Returns a HashMap containing references to all conversations stored in the manager
///
/// # Example
///
/// ```no_run
/// use rust_bert::pipelines::conversation::{Conversation, ConversationManager};
///
/// let mut conversation_manager = ConversationManager::new();
///
/// let conversation = Conversation::new("Hi there!");
/// let conversation_id = conversation_manager.add(conversation);
///
/// let all_conversations = conversation_manager.get_all();
/// ```
pub fn get_all(&mut self) -> HashMap<&Uuid, &Conversation> {
let mut output = HashMap::with_capacity(self.conversations.len());
for (uuid, conversation) in self.conversations.iter() {
output.insert(uuid, conversation);
}
output
}
/// Creates a conversation and adds it to the conversation manager
///
/// # Arguments
///
/// * `text` - `&str` string slice with an original user input
///
/// # Returns
///
/// * `Uuid` for the conversation created
///
/// # Example
///
/// ```no_run
/// use rust_bert::pipelines::conversation::{Conversation, ConversationManager};
///
/// let mut conversation_manager = ConversationManager::new();
///
/// let conversation_id = conversation_manager.create("Hi there!");
/// ```
pub fn create(&mut self, text: &str) -> Uuid {
let conversation = Conversation::new(text);
self.add(conversation)
}
/// Creates an empty conversation and adds it to the conversation manager
///
/// # Returns
///
/// * `Uuid` for the conversation created
///
/// # Example
///
/// ```no_run
/// use rust_bert::pipelines::conversation::{Conversation, ConversationManager};
///
/// let mut conversation_manager = ConversationManager::new();
///
/// let conversation_id = conversation_manager.create_empty();
/// ```
pub fn create_empty(&mut self) -> Uuid {
let conversation = Conversation::new_empty();
self.add(conversation)
}
/// Adds an existing conversation to the conversation manager
///
/// # Arguments
///
/// * `conversation` - `Conversation` to be added to the conversation manager
///
/// # Returns
///
/// * `Uuid` for the conversation created
///
/// # Example
///
/// ```no_run
/// use rust_bert::pipelines::conversation::{Conversation, ConversationManager};
///
/// let mut conversation_manager = ConversationManager::new();
///
/// let conversation = Conversation::new("Hi there!");
/// let conversation_id = conversation_manager.add(conversation);
/// ```
pub fn add(&mut self, conversation: Conversation) -> Uuid {
let mut uuid = Uuid::new_v4();
while self.conversations.contains_key(&uuid) {
uuid = Uuid::new_v4();
}
self.conversations.insert(uuid, conversation);
uuid
}
/// Deregister a conversation from the conversation manager
///
/// # Arguments
///
/// * `uuid` - `&Uuid` of the conversation to deregister from the conversation manager
///
/// # Returns
///
/// * `Option<Conversation>` deregistered conversation
///
/// # Example
///
/// ```no_run
/// use rust_bert::pipelines::conversation::{Conversation, ConversationManager};
///
/// let mut conversation_manager = ConversationManager::new();
///
/// let conversation_id = conversation_manager.create("Hi there!");
/// conversation_manager.remove(&conversation_id);
/// ```
pub fn remove(&mut self, uuid: &Uuid) -> Option<Conversation> {
self.conversations.remove(uuid)
}
/// Clears all conversations from the conversation manager, and returns the conversations and
/// their former UUIDs.
///
/// # Returns
///
/// * `HashMap<Uuid, Conversation>` deregistered conversations
///
/// # Example
///
/// ```no_run
/// use rust_bert::pipelines::conversation::{Conversation, ConversationManager};
///
/// let mut conversation_manager = ConversationManager::new();
///
/// let conversation_id = conversation_manager.create("Hi there!");
/// let conversations = conversation_manager.clear();
/// ```
pub fn clear(&mut self) -> HashMap<Uuid, Conversation> {
std::mem::replace(&mut self.conversations, HashMap::new())
}
}
/// # Conversation model
/// Processes a ConversationManager and generates system responses for active conversations.
pub struct ConversationModel {
model: GPT2Generator,
eos_token_id: i64,
}
impl ConversationModel {
/// Build a new `ConversationModel`
///
/// # Arguments
///
/// * `conversation_config` - `ConversationConfig` object containing the resource references (model, vocabulary, configuration), conversation options and device placement (CPU/GPU)
///
/// # Example
///
/// ```no_run
/// # fn main() -> failure::Fallible<()> {
/// use rust_bert::pipelines::conversation::ConversationModel;
///
/// let conversation_model = ConversationModel::new(Default::default())?;
/// # Ok(())
/// # }
/// ```
pub fn new(conversation_config: ConversationConfig) -> failure::Fallible<ConversationModel> {
let generate_config = GenerateConfig {
model_resource: conversation_config.model_resource,
config_resource: conversation_config.config_resource,
merges_resource: conversation_config.merges_resource,
vocab_resource: conversation_config.vocab_resource,
min_length: conversation_config.min_length,
max_length: conversation_config.max_length,
do_sample: conversation_config.do_sample,
early_stopping: conversation_config.early_stopping,
num_beams: conversation_config.num_beams,
temperature: conversation_config.temperature,
top_k: conversation_config.top_k,
top_p: conversation_config.top_p,
repetition_penalty: conversation_config.repetition_penalty,
length_penalty: conversation_config.length_penalty,
no_repeat_ngram_size: conversation_config.no_repeat_ngram_size,
num_return_sequences: conversation_config.num_return_sequences,
device: conversation_config.device,
};
let model = GPT2Generator::new(generate_config)?;
let eos_token_id = *model.get_eos_ids().as_ref().unwrap().first().unwrap();
Ok(ConversationModel {
model,
eos_token_id,
})
}
/// Generates a response for each active conversation based on its new user input
///
/// # Arguments
///
/// * `conversation_manager` - `&mut ConversationManager` Conversation manager keeping track of active conversations
///
/// # Returns
/// * `HashMap<&Uuid, &str>` Responses from the model for each active conversation, referenced by Uuid
///
/// # Example
///
/// ```no_run
/// # fn main() -> failure::Fallible<()> {
/// use rust_bert::pipelines::conversation::{ConversationManager, ConversationModel};
/// use rust_bert::pipelines::generation::LanguageGenerator;
/// let model = ConversationModel::new(Default::default())?;
///
/// let mut conversation_manager = ConversationManager::new();
/// conversation_manager.create("Hello, how are you?");
///
/// let output = model.generate_responses(&mut conversation_manager);
/// # Ok(())
/// # }
/// ```
pub fn generate_responses<'a>(
&self,
conversation_manager: &'a mut ConversationManager,
) -> HashMap<&'a Uuid, &'a str> {
let (active_uuid, active_conversations) = conversation_manager.get_active_conversations();
if !active_uuid.is_empty() {
let texts = active_conversations
.iter()
.map(|c| c.new_user_input.as_ref().unwrap().as_str())
.collect_vec();
let history = active_conversations
.iter()
.map(|c| &c.history)
.collect_vec();
let prompt_ids = self.encode_prompts(texts.as_slice());
let input_tensor = self.concat_input_history(prompt_ids, history);
let input_length = *input_tensor.size().last().unwrap() as usize;
let mut generated = self.model.generate_from_ids_and_past(input_tensor, None);
self.clean_padding_indices(&mut generated);
let mut output = HashMap::with_capacity(active_uuid.len());
for ((conversation, generated_sequence), uuid) in active_conversations
.into_iter()
.zip(generated.into_iter())
.zip(active_uuid.into_iter())
{
conversation
.generated_responses
.push(self.model.get_tokenizer().decode(
generated_sequence[input_length..].to_vec(),
true,
true,
));
conversation.history = generated_sequence;
conversation.mark_processed();
output.insert(uuid, conversation.get_last_response().unwrap());
}
output
} else {
HashMap::new()
}
}
fn clean_padding_indices(&self, model_output: &mut Vec<Vec<i64>>) {
// When inputs are sent as a batch, this removes the padding indices from the history of shorter outputs
let pad_token = match self.model.get_pad_id() {
Some(value) => *value,
None => self.eos_token_id,
};
for sequence_history in model_output {
let index = sequence_history
.iter()
.rev()
.position(|&r| r != pad_token)
.unwrap();
sequence_history.drain(sequence_history.len() - index + 1..);
}
}
fn concat_input_history(&self, inputs: Vec<Vec<i64>>, history: Vec<&Vec<i64>>) -> Tensor {
// Concatenates the history token indices with new user input
let max_len = self.model.get_config().max_length;
let pad_token = match self.model.get_pad_id() {
Some(value) => *value,
None => self.eos_token_id,
};
assert_eq!(
inputs.len(),
history.len(),
"Length of inputs shoudl equal length of history"
);
let mut concatenated_inputs = Vec::with_capacity(inputs.len());
for (input, history) in inputs.iter().zip(history.iter()) {
let mut concatenated_element = Vec::with_capacity(input.len() + history.len());
concatenated_element.extend_from_slice(history);
concatenated_element.extend_from_slice(input);
concatenated_inputs.push(concatenated_element);
}
let num_truncated_tokens = concatenated_inputs
.iter()
.map(|token_ids| {
if token_ids.len() > max_len as usize {
token_ids.len() - max_len as usize
} else {
0
}
})
.collect::<Vec<usize>>();
let concatenated_inputs = concatenated_inputs
.into_iter()
.zip(num_truncated_tokens)
.map(|(tokens, num_truncated_tokens)| {
truncate_sequences(
tokens,
None,
vec![],
None,
vec![],
None,
vec![],
None,
num_truncated_tokens,
&TruncationStrategy::LongestFirst,
0,
)
.unwrap()
.0
})
.collect::<Vec<Vec<i64>>>();
let max_len = concatenated_inputs
.iter()
.map(|input| input.len())
.max()
.unwrap();
let concatenated_inputs = concatenated_inputs
.into_iter()
.map(|input| {
let mut temp = vec![pad_token; max_len - input.len()];
temp.extend(input);
temp
})
.map(|tokens| Tensor::of_slice(&tokens).to(self.model.get_var_store().device()))
.collect::<Vec<Tensor>>();
Tensor::stack(&concatenated_inputs, 0)
}
fn encode_prompts(&self, texts: &[&str]) -> Vec<Vec<i64>> {
// Encode the user prompt into token ids
let tokens = self.model.get_tokenizer().tokenize_list(texts.to_vec());
tokens
.into_iter()
.map(|prompt_tokens| {
self.model
.get_tokenizer()
.convert_tokens_to_ids(&prompt_tokens)
})
.map(|mut tokens| {
tokens.push(self.eos_token_id);
tokens
})
.collect::<Vec<Vec<i64>>>()
}
}

View File

@ -1071,7 +1071,7 @@ pub enum Cache {
None,
}
mod private_generation_utils {
pub(crate) mod private_generation_utils {
use super::ordered_float::OrderedFloat;
use crate::pipelines::generation::{BeamHypotheses, Cache, GenerateConfig, LMHeadModel};
use itertools::Itertools;
@ -1485,7 +1485,12 @@ mod private_generation_utils {
i64::from(sentence_lengths.get(hypothesis_index)),
(Int64, input_ids.device()),
),
&input_ids.get(hypothesis_index),
&input_ids.get(hypothesis_index).slice(
0,
0,
i64::from(sentence_lengths.get(hypothesis_index)),
1,
),
);
}
decoded
@ -1949,24 +1954,12 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
let eos_token_ids = PrivateLanguageGenerator::get_eos_ids(self).clone();
let config = PrivateLanguageGenerator::get_config(self);
let do_sample = config.do_sample;
let num_return_sequences = config.num_return_sequences;
let num_beams = config.num_beams;
let min_length = config.min_length;
let max_length = config.max_length;
let encoding_max_len = if self.is_encoder_decoder() {
1024u64
} else {
max_length
};
let early_stopping = config.early_stopping;
let temperature = config.temperature;
let top_k = config.top_k;
let top_p = config.top_p;
let repetition_penalty = config.repetition_penalty;
let length_penalty = config.length_penalty;
let no_repeat_ngram_size = config.no_repeat_ngram_size;
let pad_token_id = match self.get_pad_id() {
Some(value) => Some(*value),
None => match &eos_token_ids {
@ -1986,6 +1979,42 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
),
},
};
let generated = self.generate_from_ids_and_past(input_ids, attention_mask);
let mut output = Vec::with_capacity(generated.len());
for generated_sequence in generated {
output.push(self.get_tokenizer().decode(generated_sequence, true, true));
}
output
}
fn generate_from_ids_and_past(
&self,
input_ids: Tensor,
attention_mask: Option<Tensor>,
) -> Vec<Vec<i64>> {
let eos_token_ids = PrivateLanguageGenerator::get_eos_ids(self).clone();
let config = PrivateLanguageGenerator::get_config(self);
let do_sample = config.do_sample;
let num_return_sequences = config.num_return_sequences;
let num_beams = config.num_beams;
let min_length = config.min_length;
let max_length = config.max_length;
let early_stopping = config.early_stopping;
let temperature = config.temperature;
let top_k = config.top_k;
let top_p = config.top_p;
let repetition_penalty = config.repetition_penalty;
let length_penalty = config.length_penalty;
let no_repeat_ngram_size = config.no_repeat_ngram_size;
let pad_token_id = match self.get_pad_id() {
Some(value) => Some(*value),
None => match &eos_token_ids {
Some(eos_ids) => Some(eos_ids[0]),
None => None,
},
};
let cur_len = if !self.is_encoder_decoder() {
*input_ids.size().last().unwrap()
@ -2055,7 +2084,6 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
);
(input_ids, attention_mask)
};
let decoded = no_grad(|| {
if num_beams > 1 {
self.generate_beam_search(
@ -2099,24 +2127,18 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
)
}
});
let num_sequences = *decoded.size().first().unwrap();
let mut output = Vec::with_capacity(num_sequences as usize);
let mut output_ids = Vec::with_capacity(num_sequences as usize);
for sequence_index in 0..num_sequences {
output.push(
self.get_tokenizer().decode(
decoded
.as_ref()
.get(sequence_index)
.iter::<i64>()
.unwrap()
.collect::<Vec<i64>>(),
true,
true,
),
);
let sequence_output_ids = decoded
.as_ref()
.get(sequence_index)
.iter::<i64>()
.unwrap()
.collect::<Vec<i64>>();
output_ids.push(sequence_output_ids.clone());
}
output
output_ids
}
}

View File

@ -2,6 +2,9 @@
//!
//! Based on Huggingface's pipelines, ready-to-use end-to-end NLP pipelines are available as part of this crate. The following capabilities are currently available:
//!
//! **Disclaimer**
//! The contributors of this repository are not responsible for any output generated by third-party use of the pretrained systems provided herein.
//!
//! #### 1. Question Answering
//! Extractive question answering from a given question and context. DistilBERT model finetuned on SQuAD (Stanford Question Answering Dataset)
//!
@ -114,7 +117,35 @@
//! ```
//!
//!
//! #### 4. Natural Language Generation
//! #### 4. Dialogue Model
//! Conversation model based on Microsoft's [DialoGPT](https://github.com/microsoft/DialoGPT).
//! This pipeline allows the generation of single or multi-turn conversations between a human and a model.
//! DialoGPT's project page states that
//! > The human evaluation results indicate that the response generated from DialoGPT is comparable to human response quality
//! > under a single-turn conversation Turing test. ([DialoGPT repository](https://github.com/microsoft/DialoGPT))
//!
//! The model uses a `ConversationManager` to keep track of active conversations and generate responses to them.
//!
//! ```no_run
//! # fn main() -> failure::Fallible<()> {
//! use rust_bert::pipelines::conversation::{ConversationManager, ConversationModel};
//! let conversation_model = ConversationModel::new(Default::default())?;
//! let mut conversation_manager = ConversationManager::new();
//!
//! let conversation_id =
//! conversation_manager.create("Going to the movies tonight - any suggestions?");
//! let output = conversation_model.generate_responses(&mut conversation_manager);
//! # Ok(())
//! # }
//! ```
//! Example output: \
//! ```no_run
//! # let output =
//! "The Big Lebowski."
//! # ;
//! ```
//!
//! #### 5. Natural Language Generation
//! Generate language based on a prompt. GPT2 and GPT are available as base models.
//! Includes techniques such as beam search, top-k and nucleus sampling, temperature setting and repetition penalty.
//! Supports batch generation of sentences from several prompts. Sequences will be left-padded with the model's padding token if present, the unknown token otherwise.
@ -145,7 +176,7 @@
//! # ;
//! ```
//!
//! #### 5. Sentiment analysis
//! #### 6. Sentiment analysis
//! Predicts the binary sentiment for a sentence. DistilBERT model finetuned on SST-2.
//! ```no_run
//! use rust_bert::pipelines::sentiment::SentimentModel;
@ -184,7 +215,7 @@
//! # ;
//! ```
//!
//! #### 6. Named Entity Recognition
//! #### 7. Named Entity Recognition
//! Extracts entities (Person, Location, Organization, Miscellaneous) from text. BERT cased large model finetuned on CoNLL03, contributed by the [MDZ Digital Library team at the Bavarian State Library](https://github.com/dbmdz)
//! ```no_run
//! use rust_bert::pipelines::ner::NERModel;
@ -230,6 +261,7 @@
//!
pub mod common;
pub mod conversation;
pub mod generation;
pub mod ner;
pub mod question_answering;

View File

@ -2,6 +2,9 @@ use rust_bert::gpt2::{
GPT2LMHeadModel, Gpt2Config, Gpt2ConfigResources, Gpt2MergesResources, Gpt2ModelResources,
Gpt2VocabResources,
};
use rust_bert::pipelines::conversation::{
ConversationConfig, ConversationManager, ConversationModel,
};
use rust_bert::pipelines::generation::{
Cache, GPT2Generator, GenerateConfig, LMHeadModel, LanguageGenerator,
};
@ -299,3 +302,128 @@ fn gpt2_generation_beam_search_multiple_prompts_with_padding() -> failure::Falli
Ok(())
}
#[test]
#[cfg_attr(not(feature = "all-tests"), ignore)]
fn dialogpt_single_multi_turn_conversation() -> failure::Fallible<()> {
// Set-up conversation model
let conversation_config = ConversationConfig {
do_sample: false,
device: Device::Cpu,
..Default::default()
};
let conversation_model = ConversationModel::new(conversation_config)?;
// Set-up conversation manager and add a conversation
let mut conversation_manager = ConversationManager::new();
let conversation_id =
conversation_manager.create("Going to the movies tonight - any suggestions?");
// Turn 1
let output = conversation_model.generate_responses(&mut conversation_manager);
assert_eq!(output.len(), 1);
assert_eq!(output.get(&conversation_id).unwrap(), &"The Big Lebowski");
// Turn 2
let _ = conversation_manager
.get(&conversation_id)
.unwrap()
.add_user_input("Is it an action movie?");
let output = conversation_model.generate_responses(&mut conversation_manager);
assert_eq!(output.len(), 1);
assert_eq!(output.get(&conversation_id).unwrap(), &"It\'s a comedy.");
// Turn 3 (no new user input)
let output = conversation_model.generate_responses(&mut conversation_manager);
assert_eq!(output.len(), 0);
Ok(())
}
#[test]
#[cfg_attr(not(feature = "all-tests"), ignore)]
fn dialogpt_multiple_multi_turn_conversation() -> failure::Fallible<()> {
// Set-up conversation model
let conversation_config = ConversationConfig {
do_sample: false,
device: Device::Cpu,
..Default::default()
};
let conversation_model = ConversationModel::new(conversation_config)?;
// Set-up conversation manager and add a conversation
let mut conversation_manager = ConversationManager::new();
let conversation_1_id =
conversation_manager.create("Going to the movies tonight - any suggestions?");
let conversation_2_id = conversation_manager.create("What's the last book you have read?");
// Turn 1
let output = conversation_model.generate_responses(&mut conversation_manager);
assert_eq!(output.len(), 2);
assert_eq!(output.get(&conversation_1_id).unwrap(), &"The Big Lebowski");
assert_eq!(
output.get(&conversation_2_id).unwrap(),
&"The Last Question"
);
// Turn 2
let _ = conversation_manager
.get(&conversation_1_id)
.unwrap()
.add_user_input("Is it an action movie?");
let output = conversation_model.generate_responses(&mut conversation_manager);
assert_eq!(output.len(), 1);
assert_eq!(output.get(&conversation_1_id).unwrap(), &"It\'s a comedy.");
// Turn 3 (no new user input)
let output = conversation_model.generate_responses(&mut conversation_manager);
assert_eq!(output.len(), 0);
Ok(())
}
#[test]
#[cfg_attr(not(feature = "all-tests"), ignore)]
fn dialogpt_multiple_multi_turn_conversation_with_conversation_deletion() -> failure::Fallible<()> {
// Set-up conversation model
let conversation_config = ConversationConfig {
do_sample: false,
device: Device::Cpu,
..Default::default()
};
let conversation_model = ConversationModel::new(conversation_config)?;
// Set-up conversation manager and add a conversation
let mut conversation_manager = ConversationManager::new();
let conversation_1_id =
conversation_manager.create("Going to the movies tonight - any suggestions?");
let conversation_2_id = conversation_manager.create("What's the last book you have read?");
// Turn 1
let output = conversation_model.generate_responses(&mut conversation_manager);
assert_eq!(output.len(), 2);
assert_eq!(output.get(&conversation_1_id).unwrap(), &"The Big Lebowski");
assert_eq!(
output.get(&conversation_2_id).unwrap(),
&"The Last Question"
);
// Turn 2
let _ = conversation_manager.remove(&conversation_1_id);
let _ = conversation_manager
.get(&conversation_2_id)
.unwrap()
.add_user_input("Why do you recommend it?");
let output = conversation_model.generate_responses(&mut conversation_manager);
assert_eq!(output.len(), 1);
assert_eq!(
output.get(&conversation_2_id).unwrap(),
&"It's a good book."
);
// Turn 3 (no new user input)
let output = conversation_model.generate_responses(&mut conversation_manager);
assert_eq!(output.len(), 0);
Ok(())
}

View File

@ -0,0 +1,53 @@
from transformers.file_utils import get_from_cache, S3_BUCKET_PREFIX
from pathlib import Path
import shutil
import os
import numpy as np
import torch
import subprocess
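# Download the DialoGPT-medium files (config, vocab, merges, weights) from the HuggingFace
# model hub cache and convert the weights to the rust-bert (libtorch) format.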
ROOT_PATH = S3_BUCKET_PREFIX + '/' + 'microsoft/DialoGPT-medium'
config_path = ROOT_PATH + '/config.json'
vocab_path = ROOT_PATH + '/vocab.json'
merges_path = ROOT_PATH + '/merges.txt'
weights_path = ROOT_PATH + '/pytorch_model.bin'
target_path = Path.home() / 'rustbert' / 'dialogpt-medium'
temp_config = get_from_cache(config_path)
temp_vocab = get_from_cache(vocab_path)
temp_merges = get_from_cache(merges_path)
temp_weights = get_from_cache(weights_path)
os.makedirs(str(target_path), exist_ok=True)
config_path = str(target_path / 'config.json')
vocab_path = str(target_path / 'vocab.json')
merges_path = str(target_path / 'merges.txt')
model_path = str(target_path / 'model.bin')
shutil.copy(temp_config, config_path)
shutil.copy(temp_vocab, vocab_path)
shutil.copy(temp_merges, merges_path)
shutil.copy(temp_weights, model_path)
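# Load the PyTorch checkpoint and convert each tensor to a contiguous float32 NumPy array.
# The token embedding matrix is also exported as lm_head.weight, since GPT-2 ties the
# language model head to the input embeddings.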
weights = torch.load(temp_weights, map_location='cpu')
nps = {}
for k, v in weights.items():
nps[k] = np.ascontiguousarray(v.cpu().numpy()).astype(np.float32)
if k == 'wte.weight':
nps['lm_head.weight'] = np.ascontiguousarray(v.cpu().numpy()).astype(np.float32)
np.savez(target_path / 'model.npz', **nps)
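# Convert the NumPy archive to the .ot format expected by rust-bert using the crate's
# convert-tensor binary, then remove the intermediate files.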
source = str(target_path / 'model.npz')
target = str(target_path / 'model.ot')
toml_location = (Path(__file__).resolve() / '..' / '..' / 'Cargo.toml').resolve()
subprocess.call(
['cargo', 'run', '--bin=convert-tensor', '--manifest-path=%s' % toml_location, '--', source, target])
os.remove(str(target_path / 'model.bin'))
os.remove(str(target_path / 'model.npz'))