mirror of
https://github.com/guillaume-be/rust-bert.git
synced 2024-10-26 22:19:05 +03:00
Merge pull request #57 from guillaume-be/multiturn_conversation
Multiturn conversation
This commit is contained in:
commit
d076ec6f77
@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "rust-bert"
|
||||
version = "0.7.8"
|
||||
version = "0.7.9"
|
||||
authors = ["Guillaume Becquin <guillaume.becquin@gmail.com>"]
|
||||
edition = "2018"
|
||||
description = "Ready-to-use NLP pipelines and transformer-based models (BERT, DistilBERT, GPT2,...)"
|
||||
@ -41,4 +41,5 @@ ordered-float = "1.0.2"
|
||||
csv = "1.1.3"
|
||||
reqwest = "0.10.4"
|
||||
lazy_static = "1.4.0"
|
||||
uuid = {version="0.8.1", features =["v4"]}
|
||||
tokio = { version = "0.2.21", features = ["full"] }
|
34
README.md
34
README.md
@ -25,6 +25,11 @@ Translation | | | | | |✅| |✅ | |
|
||||
## Ready-to-use pipelines
|
||||
|
||||
Based on Huggingface's pipelines, ready to use end-to-end NLP pipelines are available as part of this crate. The following capabilities are currently available:
|
||||
|
||||
**Disclaimer**
|
||||
The contributors of this repository are not responsible for any generation from the 3rd party utilization of the pretrained systems proposed herein.
|
||||
|
||||
|
||||
#### 1. Question Answering
|
||||
Extractive question answering from a given question and context. DistilBERT model finetuned on SQuAD (Stanford Question Answering Dataset)
|
||||
|
||||
@ -106,7 +111,30 @@ This is the first such discovery in a planet in its star's habitable zone.
|
||||
The planet is not too hot and not too cold for liquid water to exist."
|
||||
```
|
||||
|
||||
#### 4. Natural Language Generation
|
||||
#### 4. Dialogue Model
|
||||
Conversation model based on Microsoft's [DialoGPT](https://github.com/microsoft/DialoGPT).
|
||||
This pipeline allows the generation of single or multi-turn conversations between a human and a model.
|
||||
The DialoGPT's page states that
|
||||
> The human evaluation results indicate that the response generated from DialoGPT is comparable to human response quality
|
||||
> under a single-turn conversation Turing test. ([DialoGPT repository](https://github.com/microsoft/DialoGPT))
|
||||
|
||||
The model uses a `ConversationManager` to keep track of active conversations and generate responses to them.
|
||||
|
||||
```rust
|
||||
use rust_bert::pipelines::conversation::{ConversationModel, ConversationManager};
|
||||
|
||||
let conversation_model = ConversationModel::new(Default::default())?;
|
||||
let mut conversation_manager = ConversationManager::new();
|
||||
|
||||
let conversation_id = conversation_manager.create("Going to the movies tonight - any suggestions?");
|
||||
let output = conversation_model.generate_responses(&mut conversation_manager);
|
||||
```
|
||||
Example output:
|
||||
```
|
||||
"The Big Lebowski."
|
||||
```
|
||||
|
||||
#### 5. Natural Language Generation
|
||||
Generate language based on a prompt. GPT2 and GPT available as base models.
|
||||
Include techniques such as beam search, top-k and nucleus sampling, temperature setting and repetition penalty.
|
||||
Supports batch generation of sentences from several prompts. Sequences will be left-padded with the model's padding token if present, the unknown token otherwise.
|
||||
@ -133,7 +161,7 @@ Example output:
|
||||
]
|
||||
```
|
||||
|
||||
#### 5. Sentiment analysis
|
||||
#### 6. Sentiment analysis
|
||||
Predicts the binary sentiment for a sentence. DistilBERT model finetuned on SST-2.
|
||||
```rust
|
||||
let sentiment_classifier = SentimentModel::new(Default::default())?;
|
||||
@ -157,7 +185,7 @@ Output:
|
||||
]
|
||||
```
|
||||
|
||||
#### 6. Named Entity Recognition
|
||||
#### 7. Named Entity Recognition
|
||||
Extracts entities (Person, Location, Organization, Miscellaneous) from text. BERT cased large model finetuned on CoNNL03, contributed by the [MDZ Digital Library team at the Bavarian State Library](https://github.com/dbmdz)
|
||||
```rust
|
||||
let ner_model = NERModel::new(Default::default())?;
|
||||
|
43
examples/conversation.rs
Normal file
43
examples/conversation.rs
Normal file
@ -0,0 +1,43 @@
|
||||
// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
|
||||
// Copyright 2019 Guillaume Becquin
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
extern crate failure;
|
||||
|
||||
use rust_bert::pipelines::conversation::{ConversationManager, ConversationModel};
|
||||
|
||||
fn main() -> failure::Fallible<()> {
|
||||
let conversation_model = ConversationModel::new(Default::default())?;
|
||||
let mut conversation_manager = ConversationManager::new();
|
||||
|
||||
let conversation_1_id =
|
||||
conversation_manager.create("Going to the movies tonight - any suggestions?");
|
||||
let _conversation_2_id = conversation_manager.create("What's the last book you have read?");
|
||||
|
||||
let output = conversation_model.generate_responses(&mut conversation_manager);
|
||||
|
||||
println!("{:?}", output);
|
||||
|
||||
let _ = conversation_manager
|
||||
.get(&conversation_1_id)
|
||||
.unwrap()
|
||||
.add_user_input("Is it an action movie?");
|
||||
|
||||
let output = conversation_model.generate_responses(&mut conversation_manager);
|
||||
|
||||
println!("{:?}", output);
|
||||
|
||||
let output = conversation_model.generate_responses(&mut conversation_manager);
|
||||
|
||||
println!("{:?}", output);
|
||||
|
||||
Ok(())
|
||||
}
|
@ -276,6 +276,27 @@ fn download_albert_base_v2() -> failure::Fallible<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn _download_dialogpt() -> failure::Fallible<()> {
|
||||
// Shared under MIT license by the Microsoft team at https://huggingface.co/microsoft/DialoGPT-medium. Modified with conversion to C-array format.
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
Gpt2ConfigResources::DIALOGPT_MEDIUM,
|
||||
));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
Gpt2VocabResources::DIALOGPT_MEDIUM,
|
||||
));
|
||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
Gpt2MergesResources::DIALOGPT_MEDIUM,
|
||||
));
|
||||
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
Gpt2ModelResources::DIALOGPT_MEDIUM,
|
||||
));
|
||||
let _ = download_resource(&config_resource)?;
|
||||
let _ = download_resource(&vocab_resource)?;
|
||||
let _ = download_resource(&merges_resource)?;
|
||||
let _ = download_resource(&weights_resource)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn main() -> failure::Fallible<()> {
|
||||
let _ = download_distil_gpt2();
|
||||
let _ = download_distilbert_sst2();
|
||||
|
@ -61,6 +61,11 @@ impl Gpt2ModelResources {
|
||||
"distilgpt2/model.ot",
|
||||
"https://cdn.huggingface.co/distilgpt2-rust_model.ot",
|
||||
);
|
||||
/// Shared under MIT license by the Microsoft team at https://huggingface.co/microsoft/DialoGPT-medium. Modified with conversion to C-array format.
|
||||
pub const DIALOGPT_MEDIUM: (&'static str, &'static str) = (
|
||||
"dialogpt-medium/model.ot",
|
||||
"https://cdn.huggingface.co/microsoft/DialoGPT-medium/rust_model.ot",
|
||||
);
|
||||
}
|
||||
|
||||
impl Gpt2ConfigResources {
|
||||
@ -89,6 +94,11 @@ impl Gpt2ConfigResources {
|
||||
"distilgpt2/config.json",
|
||||
"https://cdn.huggingface.co/distilgpt2-config.json",
|
||||
);
|
||||
/// Shared under MIT license by the Microsoft team at https://huggingface.co/microsoft/DialoGPT-medium. Modified with conversion to C-array format.
|
||||
pub const DIALOGPT_MEDIUM: (&'static str, &'static str) = (
|
||||
"dialogpt-medium/config.json",
|
||||
"https://cdn.huggingface.co/microsoft/DialoGPT-medium/config.json",
|
||||
);
|
||||
}
|
||||
|
||||
impl Gpt2VocabResources {
|
||||
@ -117,6 +127,11 @@ impl Gpt2VocabResources {
|
||||
"distilgpt2/vocab.txt",
|
||||
"https://cdn.huggingface.co/distilgpt2-vocab.json",
|
||||
);
|
||||
/// Shared under MIT license by the Microsoft team at https://huggingface.co/microsoft/DialoGPT-medium. Modified with conversion to C-array format.
|
||||
pub const DIALOGPT_MEDIUM: (&'static str, &'static str) = (
|
||||
"dialogpt-medium/vocab.txt",
|
||||
"https://cdn.huggingface.co/microsoft/DialoGPT-medium/vocab.json",
|
||||
);
|
||||
}
|
||||
|
||||
impl Gpt2MergesResources {
|
||||
@ -145,6 +160,11 @@ impl Gpt2MergesResources {
|
||||
"distilgpt2/merges.txt",
|
||||
"https://cdn.huggingface.co/distilgpt2-merges.txt",
|
||||
);
|
||||
/// Shared under MIT license by the Microsoft team at https://huggingface.co/microsoft/DialoGPT-medium. Modified with conversion to C-array format.
|
||||
pub const DIALOGPT_MEDIUM: (&'static str, &'static str) = (
|
||||
"dialogpt-medium/merges.txt",
|
||||
"https://cdn.huggingface.co/microsoft/DialoGPT-medium/merges.txt",
|
||||
);
|
||||
}
|
||||
|
||||
#[allow(non_camel_case_types)]
|
||||
|
@ -9,6 +9,7 @@
|
||||
//! - Ready-to-use NLP pipelines for:
|
||||
//! - Translation
|
||||
//! - Summarization
|
||||
//! - Multi-turn dialogue
|
||||
//! - Sentiment Analysis
|
||||
//! - Named Entity Recognition
|
||||
//! - Question-Answering
|
||||
|
797
src/pipelines/conversation.rs
Normal file
797
src/pipelines/conversation.rs
Normal file
@ -0,0 +1,797 @@
|
||||
// Copyright 2019-present Microsoft
|
||||
// Copyright 2020-present, the HuggingFace Inc. team.
|
||||
// Copyright 2020 Guillaume Becquin
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! # Multi-turn dialogue
|
||||
//! Conversation model based on Microsoft's [DialoGPT](https://github.com/microsoft/DialoGPT).
|
||||
//! This pipeline allows the generation of single or multi-turn conversations between a human and a model.
|
||||
//! The DialoGPT's page states that
|
||||
//! > The human evaluation results indicate that the response generated from DialoGPT is comparable to human response quality
|
||||
//! > under a single-turn conversation Turing test. ([DialoGPT repository](https://github.com/microsoft/DialoGPT))
|
||||
//!
|
||||
//!
|
||||
//! The dependencies will be downloaded to the user's home directory, under ~/.cache/.rustbert/dialogpt-medium
|
||||
//!
|
||||
//! ```no_run
|
||||
//! # fn main() -> failure::Fallible<()> {
|
||||
//! use rust_bert::pipelines::conversation::{ConversationManager, ConversationModel};
|
||||
//! let conversation_model = ConversationModel::new(Default::default())?;
|
||||
//! let mut conversation_manager = ConversationManager::new();
|
||||
//!
|
||||
//! let conversation_id =
|
||||
//! conversation_manager.create("Going to the movies tonight - any suggestions?");
|
||||
//! let output = conversation_model.generate_responses(&mut conversation_manager);
|
||||
//! # Ok(())
|
||||
//! # }
|
||||
//! ```
|
||||
//!
|
||||
//! Example output: \
|
||||
//! ```no_run
|
||||
//! # let output =
|
||||
//! "The Big Lebowski."
|
||||
//! # ;
|
||||
//! ```
|
||||
//!
|
||||
//! # Disclaimer
|
||||
//! The authors of this repository are not responsible for any generation
|
||||
//! from the 3rd party utilization of the pretrained system.
|
||||
//!
|
||||
use crate::common::resources::{RemoteResource, Resource};
|
||||
use crate::gpt2::{
|
||||
Gpt2ConfigResources, Gpt2MergesResources, Gpt2ModelResources, Gpt2VocabResources,
|
||||
};
|
||||
use crate::pipelines::generation::private_generation_utils::PrivateLanguageGenerator;
|
||||
use crate::pipelines::generation::{GPT2Generator, GenerateConfig, LanguageGenerator};
|
||||
use itertools::Itertools;
|
||||
use rust_tokenizers::preprocessing::tokenizer::tokenization_utils::truncate_sequences;
|
||||
use rust_tokenizers::{Tokenizer, TruncationStrategy};
|
||||
use std::collections::HashMap;
|
||||
use tch::{Device, Tensor};
|
||||
use uuid::Uuid;
|
||||
|
||||
/// # Configuration for multi-turn classification
|
||||
/// Contains information regarding the model to load, mirrors the GenerationConfig, with a
|
||||
/// different set of default parameters and sets the device to place the model on.
|
||||
pub struct ConversationConfig {
|
||||
/// Model weights resource (default: DialoGPT-medium)
|
||||
pub model_resource: Resource,
|
||||
/// Config resource (default: DialoGPT-medium)
|
||||
pub config_resource: Resource,
|
||||
/// Vocab resource (default: DialoGPT-medium)
|
||||
pub vocab_resource: Resource,
|
||||
/// Merges resource (default: DialoGPT-medium)
|
||||
pub merges_resource: Resource,
|
||||
/// Minimum sequence length (default: 0)
|
||||
pub min_length: u64,
|
||||
/// Maximum sequence length (default: 20)
|
||||
pub max_length: u64,
|
||||
/// Sampling flag. If true, will perform top-k and/or nucleus sampling on generated tokens, otherwise greedy (deterministic) decoding (default: true)
|
||||
pub do_sample: bool,
|
||||
/// Early stopping flag indicating if the beam search should stop as soon as `num_beam` hypotheses have been generated (default: false)
|
||||
pub early_stopping: bool,
|
||||
/// Number of beams for beam search (default: 5)
|
||||
pub num_beams: u64,
|
||||
/// Temperature setting. Values higher than 1 will improve originality at the risk of reducing relevance (default: 1.0)
|
||||
pub temperature: f64,
|
||||
/// Top_k values for sampling tokens. Value higher than 0 will enable the feature (default: 0)
|
||||
pub top_k: u64,
|
||||
/// Top_p value for [Nucleus sampling, Holtzman et al.](http://arxiv.org/abs/1904.09751). Keep top tokens until cumulative probability reaches top_p (default: 0.9)
|
||||
pub top_p: f64,
|
||||
/// Repetition penalty (mostly useful for CTRL decoders). Values higher than 1 will penalize tokens that have been already generated. (default: 1.0)
|
||||
pub repetition_penalty: f64,
|
||||
/// Exponential penalty based on the length of the hypotheses generated (default: 1.0)
|
||||
pub length_penalty: f64,
|
||||
/// Number of allowed repetitions of n-grams. Values higher than 0 turn on this feature (default: 3)
|
||||
pub no_repeat_ngram_size: u64,
|
||||
/// Number of sequences to return for each prompt text (default: 1)
|
||||
pub num_return_sequences: u64,
|
||||
/// Device to place the model on (default: CUDA/GPU when available)
|
||||
pub device: Device,
|
||||
}
|
||||
|
||||
impl Default for ConversationConfig {
|
||||
fn default() -> ConversationConfig {
|
||||
ConversationConfig {
|
||||
model_resource: Resource::Remote(RemoteResource::from_pretrained(
|
||||
Gpt2ModelResources::DIALOGPT_MEDIUM,
|
||||
)),
|
||||
config_resource: Resource::Remote(RemoteResource::from_pretrained(
|
||||
Gpt2ConfigResources::DIALOGPT_MEDIUM,
|
||||
)),
|
||||
vocab_resource: Resource::Remote(RemoteResource::from_pretrained(
|
||||
Gpt2VocabResources::DIALOGPT_MEDIUM,
|
||||
)),
|
||||
merges_resource: Resource::Remote(RemoteResource::from_pretrained(
|
||||
Gpt2MergesResources::DIALOGPT_MEDIUM,
|
||||
)),
|
||||
min_length: 0,
|
||||
max_length: 1000,
|
||||
do_sample: true,
|
||||
early_stopping: false,
|
||||
num_beams: 1,
|
||||
temperature: 1.0,
|
||||
top_k: 50,
|
||||
top_p: 0.9,
|
||||
repetition_penalty: 1.0,
|
||||
length_penalty: 1.0,
|
||||
no_repeat_ngram_size: 0,
|
||||
num_return_sequences: 1,
|
||||
device: Device::cuda_if_available(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Data structure keeping track of a conversation in the system. It contains past user inputs
/// and generated answers, a history of the tokens generated and a placeholder for new user
/// inputs to be processed by the system if submitted for prediction.
#[derive(Debug, Clone)]
pub struct Conversation {
    /// Past user inputs that have already been processed
    pub past_user_inputs: Vec<String>,
    /// Past system generated responses
    pub generated_responses: Vec<String>,
    /// New user input that needs to be processed
    pub new_user_input: Option<String>,
    /// History of the tokens passed as an input and generated so far used as context for next turn generation
    pub history: Vec<i64>,
}
|
||||
|
||||
impl Conversation {
|
||||
/// Build a new `Conversation` with an initial user input
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `text` - `String` with the initial user input to start a conversation
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```no_run
|
||||
/// use rust_bert::pipelines::conversation::Conversation;
|
||||
///
|
||||
/// let conversation = Conversation::new("Hi there!");
|
||||
/// ```
|
||||
pub fn new(text: &str) -> Conversation {
|
||||
Conversation {
|
||||
past_user_inputs: vec![],
|
||||
generated_responses: vec![],
|
||||
new_user_input: Some(text.to_string()),
|
||||
history: vec![],
|
||||
}
|
||||
}
|
||||
|
||||
/// Build a new `Conversation` placeholder without user input
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```no_run
|
||||
/// use rust_bert::pipelines::conversation::Conversation;
|
||||
///
|
||||
/// let conversation = Conversation::new_empty();
|
||||
/// ```
|
||||
pub fn new_empty() -> Conversation {
|
||||
Conversation {
|
||||
past_user_inputs: vec![],
|
||||
generated_responses: vec![],
|
||||
new_user_input: None,
|
||||
history: vec![],
|
||||
}
|
||||
}
|
||||
|
||||
/// Adds a new user input to the conversation. This method returns an error if an unprocessed
|
||||
/// user input already exists
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `text` - `&str` with the additional user input to continue a conversation
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```no_run
|
||||
/// use rust_bert::pipelines::conversation::Conversation;
|
||||
///
|
||||
/// let mut conversation = Conversation::new_empty();
|
||||
/// conversation.add_user_input("Hi there!");
|
||||
/// ```
|
||||
pub fn add_user_input(&mut self, text: &str) -> Result<(), &'static str> {
|
||||
if self.new_user_input.is_some() {
|
||||
Err("User input already provided for this conversation")
|
||||
} else {
|
||||
self.new_user_input = Some(text.to_string());
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Adds a new user input to the conversation. If an unprocessed user input already exists,
|
||||
/// its contents are overwritten by the new value provided.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `text` - `&str` with the additional user input to continue a conversation
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// * `Option<String>` containing overwritten string if applicable
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```no_run
|
||||
/// use rust_bert::pipelines::conversation::Conversation;
|
||||
///
|
||||
/// let mut conversation = Conversation::new_empty();
|
||||
/// conversation.add_user_input("This input will not be used");
|
||||
/// let unused_string = conversation.add_user_input_with_overwrite("Hi there!");
|
||||
/// ```
|
||||
pub fn add_user_input_with_overwrite(&mut self, text: &str) -> Option<String> {
|
||||
let old_user_input = if self.new_user_input.is_some() {
|
||||
self.new_user_input.clone()
|
||||
} else {
|
||||
None
|
||||
};
|
||||
self.new_user_input = Some(text.to_string());
|
||||
old_user_input
|
||||
}
|
||||
|
||||
/// Returns `true` if the conversation contains new user inputs to process
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// * `bool` flag indicating if the conversation contains new inputs to process
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```no_run
|
||||
/// use rust_bert::pipelines::conversation::Conversation;
|
||||
///
|
||||
/// let mut conversation = Conversation::new_empty();
|
||||
/// let false_value = conversation.contains_new_input();
|
||||
/// conversation.add_user_input("This input will not be used");
|
||||
/// let true_value = conversation.contains_new_input();
|
||||
/// ```
|
||||
pub fn contains_new_input(&self) -> bool {
|
||||
self.new_user_input.is_some()
|
||||
}
|
||||
|
||||
/// Marks the conversation as processed and moves the user input that was up for
|
||||
/// processing to the past user inputs.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```no_run
|
||||
/// use rust_bert::pipelines::conversation::Conversation;
|
||||
///
|
||||
/// let mut conversation = Conversation::new_empty();
|
||||
/// let false_value = conversation.contains_new_input();
|
||||
/// conversation.add_user_input("This input will not be used");
|
||||
/// let true_value = conversation.contains_new_input();
|
||||
/// conversation.mark_processed();
|
||||
/// let false_value = conversation.contains_new_input();
|
||||
/// assert_eq!(conversation.past_user_inputs.len(), 1usize);
|
||||
/// ```
|
||||
pub fn mark_processed(&mut self) {
|
||||
if self.new_user_input.is_some() {
|
||||
self.past_user_inputs
|
||||
.push(self.new_user_input.clone().unwrap());
|
||||
self.new_user_input = None;
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the last user input provided (including non-processed inputs).
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// * `Option<&str>` representation of the last user input provided
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```no_run
|
||||
/// use rust_bert::pipelines::conversation::Conversation;
|
||||
///
|
||||
/// let mut conversation = Conversation::new_empty();
|
||||
/// let none_value = conversation.get_last_input();
|
||||
/// conversation.add_user_input("This input will not be used");
|
||||
/// let last_provided_input = conversation.get_last_input();
|
||||
/// assert_eq!(last_provided_input, Some("This input will not be used"));
|
||||
/// ```
|
||||
pub fn get_last_input(&self) -> Option<&str> {
|
||||
if self.new_user_input.is_some() {
|
||||
Some(self.new_user_input.as_ref().unwrap().as_str())
|
||||
} else {
|
||||
if self.past_user_inputs.len() > 0 {
|
||||
Some(self.past_user_inputs.last().unwrap().as_str())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the last response generated by the system.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// * `Option<&str>` representation of the last response generated by the system.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```no_run
|
||||
/// use rust_bert::pipelines::conversation::Conversation;
|
||||
///
|
||||
/// let mut conversation = Conversation::new("Hi There");
|
||||
/// let non_value = conversation.get_last_response();
|
||||
/// ```
|
||||
pub fn get_last_response(&self) -> Option<&str> {
|
||||
if !self.generated_responses.is_empty() {
|
||||
Some(self.generated_responses.last().unwrap().as_str())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Data structure allowing the management of conversations and main input to the dialogue model.
|
||||
/// It contains a `HashMap` of conversations with `UUID` keys
|
||||
#[derive(Debug)]
|
||||
pub struct ConversationManager {
|
||||
conversations: HashMap<Uuid, Conversation>,
|
||||
}
|
||||
|
||||
impl ConversationManager {
|
||||
/// Build a new `ConversationManager`
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```no_run
|
||||
/// use rust_bert::pipelines::conversation::ConversationManager;
|
||||
///
|
||||
/// let conversation_manager = ConversationManager::new();
|
||||
/// ```
|
||||
pub fn new() -> ConversationManager {
|
||||
ConversationManager {
|
||||
conversations: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns a list of the active conversations (containing new inputs to be processed by the model)
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// * `(Vec<&Uuid>, Vec<&mut Conversation>)` Tuple of vectors with the active `UUID` and `Conversations`
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```no_run
|
||||
/// use rust_bert::pipelines::conversation::{Conversation, ConversationManager};
|
||||
///
|
||||
/// let mut conversation_manager = ConversationManager::new();
|
||||
///
|
||||
/// let conversation = Conversation::new("Hi there!");
|
||||
/// let empty_conversation = Conversation::new_empty();
|
||||
/// let conversation_id = conversation_manager.add(conversation);
|
||||
/// let empty_conversation_id = conversation_manager.add(empty_conversation);
|
||||
///
|
||||
/// let active_conversations = conversation_manager.get_active_conversations();
|
||||
/// assert_eq!(active_conversations.0.len(), 1usize);
|
||||
/// ```
|
||||
pub fn get_active_conversations(&mut self) -> (Vec<&Uuid>, Vec<&mut Conversation>) {
|
||||
let mut active_uuid = vec![];
|
||||
let mut active_conversations = vec![];
|
||||
for (uuid, conversation) in self.conversations.iter_mut() {
|
||||
if conversation.new_user_input.is_some() {
|
||||
active_uuid.push(uuid);
|
||||
active_conversations.push(conversation)
|
||||
}
|
||||
}
|
||||
(active_uuid, active_conversations)
|
||||
}
|
||||
|
||||
/// Returns a mutable reference to the conversation wih the provided UUID
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `uuid` - `&Uuid` of the conversation to retrieve
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// * `Option<&mut Conversation>` Optional mutable reference to the conversation matching the UUID provided
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```no_run
|
||||
/// use rust_bert::pipelines::conversation::{Conversation, ConversationManager};
|
||||
///
|
||||
/// let mut conversation_manager = ConversationManager::new();
|
||||
///
|
||||
/// let conversation = Conversation::new("Hi there!");
|
||||
/// let conversation_id = conversation_manager.add(conversation);
|
||||
///
|
||||
/// let conversation_ref = conversation_manager.get(&conversation_id);
|
||||
/// ```
|
||||
pub fn get(&mut self, uuid: &Uuid) -> Option<&mut Conversation> {
|
||||
self.conversations.get_mut(uuid)
|
||||
}
|
||||
|
||||
/// Returns a HashMap containing references to all conversations stored in the manager
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```no_run
|
||||
/// use rust_bert::pipelines::conversation::{Conversation, ConversationManager};
|
||||
///
|
||||
/// let mut conversation_manager = ConversationManager::new();
|
||||
///
|
||||
/// let conversation = Conversation::new("Hi there!");
|
||||
/// let conversation_id = conversation_manager.add(conversation);
|
||||
///
|
||||
/// let all_conversations = conversation_manager.get_all();
|
||||
/// ```
|
||||
pub fn get_all(&mut self) -> HashMap<&Uuid, &Conversation> {
|
||||
let mut output = HashMap::with_capacity(self.conversations.len());
|
||||
for (uuid, conversation) in self.conversations.iter() {
|
||||
output.insert(uuid, conversation);
|
||||
}
|
||||
output
|
||||
}
|
||||
|
||||
/// Creates a conversation and add it to the conversation manager
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `text` - `&str` string slice with an original user input
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// * `Uuid` for the conversation created
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```no_run
|
||||
/// use rust_bert::pipelines::conversation::{Conversation, ConversationManager};
|
||||
///
|
||||
/// let mut conversation_manager = ConversationManager::new();
|
||||
///
|
||||
/// let conversation_id = conversation_manager.create("Hi there!");
|
||||
/// ```
|
||||
pub fn create(&mut self, text: &str) -> Uuid {
|
||||
let conversation = Conversation::new(text);
|
||||
self.add(conversation)
|
||||
}
|
||||
|
||||
/// Creates an empty conversation and add it to the conversation manager
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// * `Uuid` for the conversation created
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```no_run
|
||||
/// use rust_bert::pipelines::conversation::{Conversation, ConversationManager};
|
||||
///
|
||||
/// let mut conversation_manager = ConversationManager::new();
|
||||
///
|
||||
/// let conversation_id = conversation_manager.create_empty();
|
||||
/// ```
|
||||
pub fn create_empty(&mut self) -> Uuid {
|
||||
let conversation = Conversation::new_empty();
|
||||
self.add(conversation)
|
||||
}
|
||||
|
||||
/// Adds an existing conversation to the conversation manager
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `conversation` - `Conversation` to be added to the conversation manager
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// * `Uuid` for the conversation created
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```no_run
|
||||
/// use rust_bert::pipelines::conversation::{Conversation, ConversationManager};
|
||||
///
|
||||
/// let mut conversation_manager = ConversationManager::new();
|
||||
///
|
||||
/// let conversation = Conversation::new("Hi there!");
|
||||
/// let conversation_id = conversation_manager.add(conversation);
|
||||
/// ```
|
||||
pub fn add(&mut self, conversation: Conversation) -> Uuid {
|
||||
let mut uuid = Uuid::new_v4();
|
||||
while self.conversations.contains_key(&uuid) {
|
||||
uuid = Uuid::new_v4();
|
||||
}
|
||||
self.conversations.insert(uuid, conversation);
|
||||
uuid
|
||||
}
|
||||
|
||||
/// Deregister a conversation from the conversation manager
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `uuid` - `&Uuid` of the conversation to deregister from the conversation manager
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// * `Option<Conversation>` deregistered conversation
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```no_run
|
||||
/// use rust_bert::pipelines::conversation::{Conversation, ConversationManager};
|
||||
///
|
||||
/// let mut conversation_manager = ConversationManager::new();
|
||||
///
|
||||
/// let conversation_id = conversation_manager.create("Hi there!");
|
||||
/// conversation_manager.remove(&conversation_id);
|
||||
/// ```
|
||||
pub fn remove(&mut self, uuid: &Uuid) -> Option<Conversation> {
|
||||
self.conversations.remove(uuid)
|
||||
}
|
||||
|
||||
/// Clear all conversations from the conversation manager, and returns the conversations and their
|
||||
/// former UUID.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// * `HashMap<Uuid, Conversation>` deregistered conversations
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```no_run
|
||||
/// use rust_bert::pipelines::conversation::{Conversation, ConversationManager};
|
||||
///
|
||||
/// let mut conversation_manager = ConversationManager::new();
|
||||
///
|
||||
/// let conversation_id = conversation_manager.create("Hi there!");
|
||||
/// let conversations = conversation_manager.clear();
|
||||
/// ```
|
||||
pub fn clear(&mut self) -> HashMap<Uuid, Conversation> {
|
||||
let mut output = HashMap::with_capacity(self.conversations.len());
|
||||
for (uuid, conversation) in self.conversations.iter() {
|
||||
output.insert(*uuid, conversation.clone());
|
||||
}
|
||||
self.conversations = HashMap::new();
|
||||
output
|
||||
}
|
||||
}
|
||||
|
||||
/// # Conversation model
/// Processes a ConversationManager and generate system responses for active conversations.
pub struct ConversationModel {
    // Underlying GPT2-architecture generator used to produce responses.
    model: GPT2Generator,
    // End-of-sequence token id: appended after each encoded user prompt, and used
    // as the padding value whenever the tokenizer defines no dedicated pad token.
    eos_token_id: i64,
}
|
||||
|
||||
impl ConversationModel {
|
||||
/// Build a new `ConversationModel`
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `conversation_config` - `ConversationConfig` object containing the resource references (model, vocabulary, configuration), conversation options and device placement (CPU/GPU)
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```no_run
|
||||
/// # fn main() -> failure::Fallible<()> {
|
||||
/// use rust_bert::pipelines::conversation::ConversationModel;
|
||||
///
|
||||
/// let conversation_model = ConversationModel::new(Default::default())?;
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
pub fn new(conversation_config: ConversationConfig) -> failure::Fallible<ConversationModel> {
|
||||
let generate_config = GenerateConfig {
|
||||
model_resource: conversation_config.model_resource,
|
||||
config_resource: conversation_config.config_resource,
|
||||
merges_resource: conversation_config.merges_resource,
|
||||
vocab_resource: conversation_config.vocab_resource,
|
||||
min_length: conversation_config.min_length,
|
||||
max_length: conversation_config.max_length,
|
||||
do_sample: conversation_config.do_sample,
|
||||
early_stopping: conversation_config.early_stopping,
|
||||
num_beams: conversation_config.num_beams,
|
||||
temperature: conversation_config.temperature,
|
||||
top_k: conversation_config.top_k,
|
||||
top_p: conversation_config.top_p,
|
||||
repetition_penalty: conversation_config.repetition_penalty,
|
||||
length_penalty: conversation_config.length_penalty,
|
||||
no_repeat_ngram_size: conversation_config.no_repeat_ngram_size,
|
||||
num_return_sequences: conversation_config.num_return_sequences,
|
||||
device: conversation_config.device,
|
||||
};
|
||||
|
||||
let model = GPT2Generator::new(generate_config)?;
|
||||
let eos_token_id = *model.get_eos_ids().as_ref().unwrap().first().unwrap();
|
||||
Ok(ConversationModel {
|
||||
model,
|
||||
eos_token_id,
|
||||
})
|
||||
}
|
||||
|
||||
/// Perform a multi-turn conversation based on user input
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `conversation_manager` - `&mut ConversationManager` Conversation manager keeping track of active conversations
|
||||
///
|
||||
/// # Returns
|
||||
/// * `HashMap<&Uuid, &str>` Responses from the model for each active conversation, referenced by Uuid
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```no_run
|
||||
/// # fn main() -> failure::Fallible<()> {
|
||||
/// use rust_bert::pipelines::conversation::{ConversationManager, ConversationModel};
|
||||
/// use rust_bert::pipelines::generation::LanguageGenerator;
|
||||
/// let model = ConversationModel::new(Default::default())?;
|
||||
///
|
||||
/// let mut conversation_manager = ConversationManager::new();
|
||||
/// conversation_manager.create("Hello, how are you?");
|
||||
///
|
||||
/// let output = model.generate_responses(&mut conversation_manager);
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
pub fn generate_responses<'a>(
|
||||
&self,
|
||||
conversation_manager: &'a mut ConversationManager,
|
||||
) -> HashMap<&'a Uuid, &'a str> {
|
||||
let (active_uuid, active_conversations) = conversation_manager.get_active_conversations();
|
||||
if !active_uuid.is_empty() {
|
||||
let texts = active_conversations
|
||||
.iter()
|
||||
.map(|c| c.new_user_input.as_ref().unwrap().as_str())
|
||||
.collect_vec();
|
||||
|
||||
let history = active_conversations
|
||||
.iter()
|
||||
.map(|c| &c.history)
|
||||
.collect_vec();
|
||||
|
||||
let prompt_ids = self.encode_prompts(texts.as_slice());
|
||||
let input_tensor = self.concat_input_history(prompt_ids, history);
|
||||
let input_length = *input_tensor.size().last().unwrap() as usize;
|
||||
let mut generated = self.model.generate_from_ids_and_past(input_tensor, None);
|
||||
self.clean_padding_indices(&mut generated);
|
||||
|
||||
let mut output = HashMap::with_capacity(active_uuid.len());
|
||||
|
||||
for ((conversation, generated_sequence), uuid) in active_conversations
|
||||
.into_iter()
|
||||
.zip(generated.into_iter())
|
||||
.zip(active_uuid.into_iter())
|
||||
{
|
||||
conversation
|
||||
.generated_responses
|
||||
.push(self.model.get_tokenizer().decode(
|
||||
generated_sequence[input_length..].to_vec(),
|
||||
true,
|
||||
true,
|
||||
));
|
||||
conversation.history = generated_sequence;
|
||||
conversation.mark_processed();
|
||||
output.insert(uuid, conversation.get_last_response().unwrap());
|
||||
}
|
||||
output
|
||||
} else {
|
||||
HashMap::new()
|
||||
}
|
||||
}
|
||||
|
||||
fn clean_padding_indices(&self, model_output: &mut Vec<Vec<i64>>) {
|
||||
// In case inputs are sent as batch, this cleans the padding indices in the history for shorter outputs
|
||||
let pad_token = match self.model.get_pad_id() {
|
||||
Some(value) => *value,
|
||||
None => self.eos_token_id,
|
||||
};
|
||||
for sequence_history in model_output {
|
||||
let index = sequence_history
|
||||
.iter()
|
||||
.rev()
|
||||
.position(|&r| r != pad_token)
|
||||
.unwrap();
|
||||
sequence_history.drain(sequence_history.len() - index + 1..);
|
||||
}
|
||||
}
|
||||
|
||||
fn concat_input_history(&self, inputs: Vec<Vec<i64>>, history: Vec<&Vec<i64>>) -> Tensor {
|
||||
// Concatenates the history token indices with new user input
|
||||
let max_len = self.model.get_config().max_length;
|
||||
let pad_token = match self.model.get_pad_id() {
|
||||
Some(value) => *value,
|
||||
None => self.eos_token_id,
|
||||
};
|
||||
|
||||
assert_eq!(
|
||||
inputs.len(),
|
||||
history.len(),
|
||||
"Length of inputs shoudl equal length of history"
|
||||
);
|
||||
|
||||
let mut concatenated_inputs = Vec::with_capacity(inputs.len());
|
||||
for (input, history) in inputs.iter().zip(history.iter()) {
|
||||
let mut concatenated_element = Vec::with_capacity(input.len() + history.len());
|
||||
concatenated_element.extend_from_slice(history);
|
||||
concatenated_element.extend_from_slice(input);
|
||||
concatenated_inputs.push(concatenated_element);
|
||||
}
|
||||
|
||||
let num_truncated_tokens = concatenated_inputs
|
||||
.iter()
|
||||
.map(|token_ids| {
|
||||
if token_ids.len() > max_len as usize {
|
||||
token_ids.len() - max_len as usize
|
||||
} else {
|
||||
0
|
||||
}
|
||||
})
|
||||
.collect::<Vec<usize>>();
|
||||
|
||||
let concatenated_inputs = concatenated_inputs
|
||||
.into_iter()
|
||||
.zip(num_truncated_tokens)
|
||||
.map(|(tokens, num_truncated_tokens)| {
|
||||
truncate_sequences(
|
||||
tokens,
|
||||
None,
|
||||
vec![],
|
||||
None,
|
||||
vec![],
|
||||
None,
|
||||
vec![],
|
||||
None,
|
||||
num_truncated_tokens,
|
||||
&TruncationStrategy::LongestFirst,
|
||||
0,
|
||||
)
|
||||
.unwrap()
|
||||
.0
|
||||
})
|
||||
.collect::<Vec<Vec<i64>>>();
|
||||
|
||||
let max_len = concatenated_inputs
|
||||
.iter()
|
||||
.map(|input| input.len())
|
||||
.max()
|
||||
.unwrap();
|
||||
|
||||
let concatenated_inputs = concatenated_inputs
|
||||
.into_iter()
|
||||
.map(|input| {
|
||||
let mut temp = vec![pad_token; max_len - input.len()];
|
||||
temp.extend(input);
|
||||
temp
|
||||
})
|
||||
.map(|tokens| Tensor::of_slice(&tokens).to(self.model.get_var_store().device()))
|
||||
.collect::<Vec<Tensor>>();
|
||||
|
||||
Tensor::stack(&concatenated_inputs, 0)
|
||||
}
|
||||
|
||||
fn encode_prompts(&self, texts: &[&str]) -> Vec<Vec<i64>> {
|
||||
// Encode the user prompt into token ids
|
||||
let tokens = self.model.get_tokenizer().tokenize_list(texts.to_vec());
|
||||
|
||||
tokens
|
||||
.into_iter()
|
||||
.map(|prompt_tokens| {
|
||||
self.model
|
||||
.get_tokenizer()
|
||||
.convert_tokens_to_ids(&prompt_tokens)
|
||||
})
|
||||
.map(|mut tokens| {
|
||||
tokens.push(self.eos_token_id);
|
||||
tokens
|
||||
})
|
||||
.collect::<Vec<Vec<i64>>>()
|
||||
}
|
||||
}
|
@ -1071,7 +1071,7 @@ pub enum Cache {
|
||||
None,
|
||||
}
|
||||
|
||||
mod private_generation_utils {
|
||||
pub(crate) mod private_generation_utils {
|
||||
use super::ordered_float::OrderedFloat;
|
||||
use crate::pipelines::generation::{BeamHypotheses, Cache, GenerateConfig, LMHeadModel};
|
||||
use itertools::Itertools;
|
||||
@ -1485,7 +1485,12 @@ mod private_generation_utils {
|
||||
i64::from(sentence_lengths.get(hypothesis_index)),
|
||||
(Int64, input_ids.device()),
|
||||
),
|
||||
&input_ids.get(hypothesis_index),
|
||||
&input_ids.get(hypothesis_index).slice(
|
||||
0,
|
||||
0,
|
||||
i64::from(sentence_lengths.get(hypothesis_index)),
|
||||
1,
|
||||
),
|
||||
);
|
||||
}
|
||||
decoded
|
||||
@ -1949,24 +1954,12 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
|
||||
let eos_token_ids = PrivateLanguageGenerator::get_eos_ids(self).clone();
|
||||
|
||||
let config = PrivateLanguageGenerator::get_config(self);
|
||||
let do_sample = config.do_sample;
|
||||
let num_return_sequences = config.num_return_sequences;
|
||||
let num_beams = config.num_beams;
|
||||
let min_length = config.min_length;
|
||||
let max_length = config.max_length;
|
||||
let encoding_max_len = if self.is_encoder_decoder() {
|
||||
1024u64
|
||||
} else {
|
||||
max_length
|
||||
};
|
||||
let early_stopping = config.early_stopping;
|
||||
let temperature = config.temperature;
|
||||
let top_k = config.top_k;
|
||||
let top_p = config.top_p;
|
||||
let repetition_penalty = config.repetition_penalty;
|
||||
let length_penalty = config.length_penalty;
|
||||
let no_repeat_ngram_size = config.no_repeat_ngram_size;
|
||||
|
||||
let pad_token_id = match self.get_pad_id() {
|
||||
Some(value) => Some(*value),
|
||||
None => match &eos_token_ids {
|
||||
@ -1986,6 +1979,42 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
|
||||
),
|
||||
},
|
||||
};
|
||||
let generated = self.generate_from_ids_and_past(input_ids, attention_mask);
|
||||
let mut output = Vec::with_capacity(generated.len());
|
||||
for generated_sequence in generated {
|
||||
output.push(self.get_tokenizer().decode(generated_sequence, true, true));
|
||||
}
|
||||
output
|
||||
}
|
||||
|
||||
fn generate_from_ids_and_past(
|
||||
&self,
|
||||
input_ids: Tensor,
|
||||
attention_mask: Option<Tensor>,
|
||||
) -> Vec<Vec<i64>> {
|
||||
let eos_token_ids = PrivateLanguageGenerator::get_eos_ids(self).clone();
|
||||
|
||||
let config = PrivateLanguageGenerator::get_config(self);
|
||||
let do_sample = config.do_sample;
|
||||
let num_return_sequences = config.num_return_sequences;
|
||||
let num_beams = config.num_beams;
|
||||
let min_length = config.min_length;
|
||||
let max_length = config.max_length;
|
||||
let early_stopping = config.early_stopping;
|
||||
let temperature = config.temperature;
|
||||
let top_k = config.top_k;
|
||||
let top_p = config.top_p;
|
||||
let repetition_penalty = config.repetition_penalty;
|
||||
let length_penalty = config.length_penalty;
|
||||
let no_repeat_ngram_size = config.no_repeat_ngram_size;
|
||||
|
||||
let pad_token_id = match self.get_pad_id() {
|
||||
Some(value) => Some(*value),
|
||||
None => match &eos_token_ids {
|
||||
Some(eos_ids) => Some(eos_ids[0]),
|
||||
None => None,
|
||||
},
|
||||
};
|
||||
|
||||
let cur_len = if !self.is_encoder_decoder() {
|
||||
*input_ids.size().last().unwrap()
|
||||
@ -2055,7 +2084,6 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
|
||||
);
|
||||
(input_ids, attention_mask)
|
||||
};
|
||||
|
||||
let decoded = no_grad(|| {
|
||||
if num_beams > 1 {
|
||||
self.generate_beam_search(
|
||||
@ -2099,24 +2127,18 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
|
||||
)
|
||||
}
|
||||
});
|
||||
|
||||
let num_sequences = *decoded.size().first().unwrap();
|
||||
let mut output = Vec::with_capacity(num_sequences as usize);
|
||||
let mut output_ids = Vec::with_capacity(num_sequences as usize);
|
||||
for sequence_index in 0..num_sequences {
|
||||
output.push(
|
||||
self.get_tokenizer().decode(
|
||||
decoded
|
||||
.as_ref()
|
||||
.get(sequence_index)
|
||||
.iter::<i64>()
|
||||
.unwrap()
|
||||
.collect::<Vec<i64>>(),
|
||||
true,
|
||||
true,
|
||||
),
|
||||
);
|
||||
let sequence_output_ids = decoded
|
||||
.as_ref()
|
||||
.get(sequence_index)
|
||||
.iter::<i64>()
|
||||
.unwrap()
|
||||
.collect::<Vec<i64>>();
|
||||
output_ids.push(sequence_output_ids.clone());
|
||||
}
|
||||
output
|
||||
output_ids
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -2,6 +2,9 @@
|
||||
//!
|
||||
//! Based on Huggingface's pipelines, ready to use end-to-end NLP pipelines are available as part of this crate. The following capabilities are currently available:
|
||||
//!
|
||||
//! **Disclaimer**
|
||||
//! The contributors of this repository are not responsible for any generation from the 3rd party utilization of the pretrained systems proposed herein.
|
||||
//!
|
||||
//! #### 1. Question Answering
|
||||
//! Extractive question answering from a given question and context. DistilBERT model finetuned on SQuAD (Stanford Question Answering Dataset)
|
||||
//!
|
||||
@ -114,7 +117,35 @@
|
||||
//! ```
|
||||
//!
|
||||
//!
|
||||
//! #### 4. Natural Language Generation
|
||||
//! #### 4. Dialogue Model
|
||||
//! Conversation model based on Microsoft's [DialoGPT](https://github.com/microsoft/DialoGPT).
|
||||
//! This pipeline allows the generation of single or multi-turn conversations between a human and a model.
|
||||
//! DialoGPT's page states that
|
||||
//! > The human evaluation results indicate that the response generated from DialoGPT is comparable to human response quality
|
||||
//! > under a single-turn conversation Turing test. ([DialoGPT repository](https://github.com/microsoft/DialoGPT))
|
||||
//!
|
||||
//! The model uses a `ConversationManager` to keep track of active conversations and generate responses to them.
|
||||
//!
|
||||
//! ```no_run
|
||||
//! # fn main() -> failure::Fallible<()> {
|
||||
//! use rust_bert::pipelines::conversation::{ConversationManager, ConversationModel};
|
||||
//! let conversation_model = ConversationModel::new(Default::default())?;
|
||||
//! let mut conversation_manager = ConversationManager::new();
|
||||
//!
|
||||
//! let conversation_id =
|
||||
//! conversation_manager.create("Going to the movies tonight - any suggestions?");
|
||||
//! let output = conversation_model.generate_responses(&mut conversation_manager);
|
||||
//! # Ok(())
|
||||
//! # }
|
||||
//! ```
|
||||
//! Example output: \
|
||||
//! ```no_run
|
||||
//! # let output =
|
||||
//! "The Big Lebowski."
|
||||
//! # ;
|
||||
//! ```
|
||||
//!
|
||||
//! #### 5. Natural Language Generation
|
||||
//! Generate language based on a prompt. GPT2 and GPT available as base models.
|
||||
//! Include techniques such as beam search, top-k and nucleus sampling, temperature setting and repetition penalty.
|
||||
//! Supports batch generation of sentences from several prompts. Sequences will be left-padded with the model's padding token if present, the unknown token otherwise.
|
||||
@ -145,7 +176,7 @@
|
||||
//! # ;
|
||||
//! ```
|
||||
//!
|
||||
//! #### 5. Sentiment analysis
|
||||
//! #### 6. Sentiment analysis
|
||||
//! Predicts the binary sentiment for a sentence. DistilBERT model finetuned on SST-2.
|
||||
//! ```no_run
|
||||
//! use rust_bert::pipelines::sentiment::SentimentModel;
|
||||
@ -184,7 +215,7 @@
|
||||
//! # ;
|
||||
//! ```
|
||||
//!
|
||||
//! #### 6. Named Entity Recognition
|
||||
//! #### 7. Named Entity Recognition
|
||||
//! Extracts entities (Person, Location, Organization, Miscellaneous) from text. BERT cased large model finetuned on CoNNL03, contributed by the [MDZ Digital Library team at the Bavarian State Library](https://github.com/dbmdz)
|
||||
//! ```no_run
|
||||
//! use rust_bert::pipelines::ner::NERModel;
|
||||
@ -230,6 +261,7 @@
|
||||
//!
|
||||
|
||||
pub mod common;
|
||||
pub mod conversation;
|
||||
pub mod generation;
|
||||
pub mod ner;
|
||||
pub mod question_answering;
|
||||
|
128
tests/gpt2.rs
128
tests/gpt2.rs
@ -2,6 +2,9 @@ use rust_bert::gpt2::{
|
||||
GPT2LMHeadModel, Gpt2Config, Gpt2ConfigResources, Gpt2MergesResources, Gpt2ModelResources,
|
||||
Gpt2VocabResources,
|
||||
};
|
||||
use rust_bert::pipelines::conversation::{
|
||||
ConversationConfig, ConversationManager, ConversationModel,
|
||||
};
|
||||
use rust_bert::pipelines::generation::{
|
||||
Cache, GPT2Generator, GenerateConfig, LMHeadModel, LanguageGenerator,
|
||||
};
|
||||
@ -299,3 +302,128 @@ fn gpt2_generation_beam_search_multiple_prompts_with_padding() -> failure::Falli
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
#[cfg_attr(not(feature = "all-tests"), ignore)]
fn dialogpt_single_multi_turn_conversation() -> failure::Fallible<()> {
    // Greedy decoding on CPU so the generated replies are deterministic
    let config = ConversationConfig {
        do_sample: false,
        device: Device::Cpu,
        ..Default::default()
    };
    let model = ConversationModel::new(config)?;

    // Register a single conversation with an initial user prompt
    let mut manager = ConversationManager::new();
    let conversation_id = manager.create("Going to the movies tonight - any suggestions?");

    // First exchange
    let responses = model.generate_responses(&mut manager);
    assert_eq!(responses.len(), 1);
    assert_eq!(responses.get(&conversation_id).unwrap(), &"The Big Lebowski");

    // Second exchange: follow-up question on the same conversation
    let _ = manager
        .get(&conversation_id)
        .unwrap()
        .add_user_input("Is it an action movie?");
    let responses = model.generate_responses(&mut manager);
    assert_eq!(responses.len(), 1);
    assert_eq!(responses.get(&conversation_id).unwrap(), &"It\'s a comedy.");

    // Third pass: no pending user input, so nothing is generated
    let responses = model.generate_responses(&mut manager);
    assert_eq!(responses.len(), 0);

    Ok(())
}
|
||||
|
||||
#[test]
#[cfg_attr(not(feature = "all-tests"), ignore)]
fn dialogpt_multiple_multi_turn_conversation() -> failure::Fallible<()> {
    // Greedy decoding on CPU so the generated replies are deterministic
    let config = ConversationConfig {
        do_sample: false,
        device: Device::Cpu,
        ..Default::default()
    };
    let model = ConversationModel::new(config)?;

    // Register two concurrent conversations
    let mut manager = ConversationManager::new();
    let conversation_1_id = manager.create("Going to the movies tonight - any suggestions?");
    let conversation_2_id = manager.create("What's the last book you have read?");

    // First exchange: both conversations receive a reply
    let responses = model.generate_responses(&mut manager);
    assert_eq!(responses.len(), 2);
    assert_eq!(responses.get(&conversation_1_id).unwrap(), &"The Big Lebowski");
    assert_eq!(
        responses.get(&conversation_2_id).unwrap(),
        &"The Last Question"
    );

    // Second exchange: only conversation 1 has new user input
    let _ = manager
        .get(&conversation_1_id)
        .unwrap()
        .add_user_input("Is it an action movie?");
    let responses = model.generate_responses(&mut manager);
    assert_eq!(responses.len(), 1);
    assert_eq!(responses.get(&conversation_1_id).unwrap(), &"It\'s a comedy.");

    // Third pass: no pending user input, so nothing is generated
    let responses = model.generate_responses(&mut manager);
    assert_eq!(responses.len(), 0);

    Ok(())
}
|
||||
|
||||
#[test]
#[cfg_attr(not(feature = "all-tests"), ignore)]
fn dialogpt_multiple_multi_turn_conversation_with_conversation_deletion() -> failure::Fallible<()> {
    // Greedy decoding on CPU so the generated replies are deterministic
    let config = ConversationConfig {
        do_sample: false,
        device: Device::Cpu,
        ..Default::default()
    };
    let model = ConversationModel::new(config)?;

    // Register two concurrent conversations
    let mut manager = ConversationManager::new();
    let conversation_1_id = manager.create("Going to the movies tonight - any suggestions?");
    let conversation_2_id = manager.create("What's the last book you have read?");

    // First exchange: both conversations receive a reply
    let responses = model.generate_responses(&mut manager);
    assert_eq!(responses.len(), 2);
    assert_eq!(responses.get(&conversation_1_id).unwrap(), &"The Big Lebowski");
    assert_eq!(
        responses.get(&conversation_2_id).unwrap(),
        &"The Last Question"
    );

    // Second exchange: conversation 1 is deleted, conversation 2 continues
    let _ = manager.remove(&conversation_1_id);
    let _ = manager
        .get(&conversation_2_id)
        .unwrap()
        .add_user_input("Why do you recommend it?");
    let responses = model.generate_responses(&mut manager);
    assert_eq!(responses.len(), 1);
    assert_eq!(
        responses.get(&conversation_2_id).unwrap(),
        &"It's a good book."
    );

    // Third pass: no pending user input, so nothing is generated
    let responses = model.generate_responses(&mut manager);
    assert_eq!(responses.len(), 0);

    Ok(())
}
|
||||
|
53
utils/download-dependencies_dialogpt-medium.py
Normal file
53
utils/download-dependencies_dialogpt-medium.py
Normal file
@ -0,0 +1,53 @@
|
||||
"""Download the DialoGPT-medium pretrained artifacts and convert them for rust-bert.

Fetches the model configuration, vocabulary, merges and PyTorch weights from the
Huggingface S3 bucket, re-serializes the weight tensors to a NumPy archive and
invokes the crate's `convert-tensor` binary to produce the `model.ot` file
expected by the Rust library.
"""

from transformers.file_utils import get_from_cache, S3_BUCKET_PREFIX
from pathlib import Path
import shutil
import os
import numpy as np
import torch
import subprocess

ROOT_PATH = S3_BUCKET_PREFIX + '/' + 'microsoft/DialoGPT-medium'

config_path = ROOT_PATH + '/config.json'
vocab_path = ROOT_PATH + '/vocab.json'
merges_path = ROOT_PATH + '/merges.txt'
weights_path = ROOT_PATH + '/pytorch_model.bin'

target_path = Path.home() / 'rustbert' / 'dialogpt-medium'

# Download (or reuse locally cached copies of) the remote artifacts.
temp_config = get_from_cache(config_path)
temp_vocab = get_from_cache(vocab_path)
temp_merges = get_from_cache(merges_path)
temp_weights = get_from_cache(weights_path)

os.makedirs(str(target_path), exist_ok=True)

config_path = str(target_path / 'config.json')
vocab_path = str(target_path / 'vocab.json')
merges_path = str(target_path / 'merges.txt')

shutil.copy(temp_config, config_path)
shutil.copy(temp_vocab, vocab_path)
shutil.copy(temp_merges, merges_path)
# NOTE: the PyTorch checkpoint is deliberately NOT copied into the target
# directory - it is only read from the cache below. The previous version copied
# it to `model.bin` and then deleted that copy at the end, wasting a large file
# copy that nothing ever read.

# Export every weight tensor as a contiguous float32 array; also expose the
# token embedding under `lm_head.weight`, which the converted model reads as a
# separate tensor.
weights = torch.load(temp_weights, map_location='cpu')
nps = {}
for k, v in weights.items():
    nps[k] = np.ascontiguousarray(v.cpu().numpy()).astype(np.float32)
    if k == 'wte.weight':
        nps['lm_head.weight'] = np.ascontiguousarray(v.cpu().numpy()).astype(np.float32)

np.savez(target_path / 'model.npz', **nps)

source = str(target_path / 'model.npz')
target = str(target_path / 'model.ot')

toml_location = (Path(__file__).resolve() / '..' / '..' / 'Cargo.toml').resolve()

# Convert the npz archive to the `.ot` format used by the Rust side.
# `check=True` raises if the cargo build/run fails, instead of silently
# continuing and deleting the intermediate archive below.
subprocess.run(
    ['cargo', 'run', '--bin=convert-tensor', '--manifest-path=%s' % toml_location, '--', source, target],
    check=True)

os.remove(str(target_path / 'model.npz'))
|
Loading…
Reference in New Issue
Block a user