Added documentation, updated tests & doctests

This commit is contained in:
Guillaume B 2020-06-28 11:04:36 +02:00
parent 68def5a912
commit 23310dfc1e
7 changed files with 444 additions and 44 deletions

View File

@ -25,6 +25,11 @@ Translation | | | | | |✅| |✅ | |
## Ready-to-use pipelines
Based on Huggingface's pipelines, ready to use end-to-end NLP pipelines are available as part of this crate. The following capabilities are currently available:
**Disclaimer**
The contributors of this repository are not responsible for any generation from the 3rd party utilization of the pretrained systems proposed herein.
#### 1. Question Answering
Extractive question answering from a given question and context. DistilBERT model finetuned on SQuAD (Stanford Question Answering Dataset)
@ -106,7 +111,30 @@ This is the first such discovery in a planet in its star's habitable zone.
The planet is not too hot and not too cold for liquid water to exist."
```
#### 4. Natural Language Generation
#### 4. Dialogue Model
Conversation model based on Microsoft's [DialoGPT](https://github.com/microsoft/DialoGPT).
This pipeline allows the generation of single or multi-turn conversations between a human and a model.
The DialoGPT's page states that
> The human evaluation results indicate that the response generated from DialoGPT is comparable to human response quality
> under a single-turn conversation Turing test. ([DialoGPT repository](https://github.com/microsoft/DialoGPT))
The model uses a `ConversationManager` to keep track of active conversations and generate responses to them.
```rust
use rust_bert::pipelines::conversation::{ConversationModel, ConversationManager};
let conversation_model = ConversationModel::new(Default::default());
let mut conversation_manager = ConversationManager::new();
let conversation_id = conversation_manager.create("Going to the movies tonight - any suggestions?");
let output = conversation_model.generate_responses(&mut conversation_manager);
```
Example output:
```
"The Big Lebowski."
```
#### 5. Natural Language Generation
Generate language based on a prompt. GPT2 and GPT available as base models.
Include techniques such as beam search, top-k and nucleus sampling, temperature setting and repetition penalty.
Supports batch generation of sentences from several prompts. Sequences will be left-padded with the model's padding token if present, the unknown token otherwise.
@ -133,7 +161,7 @@ Example output:
]
```
#### 5. Sentiment analysis
#### 6. Sentiment analysis
Predicts the binary sentiment for a sentence. DistilBERT model finetuned on SST-2.
```rust
let sentiment_classifier = SentimentModel::new(Default::default())?;
@ -157,7 +185,7 @@ Output:
]
```
#### 6. Named Entity Recognition
#### 7. Named Entity Recognition
Extracts entities (Person, Location, Organization, Miscellaneous) from text. BERT cased large model finetuned on CoNNL03, contributed by the [MDZ Digital Library team at the Bavarian State Library](https://github.com/dbmdz)
```rust
let ner_model = NERModel::new(default::default())?;

View File

@ -12,17 +12,10 @@
extern crate failure;
use rust_bert::pipelines::conversation::{
ConversationConfig, ConversationManager, ConversationModel,
};
use rust_bert::pipelines::conversation::{ConversationManager, ConversationModel};
fn main() -> failure::Fallible<()> {
let conversation_config = ConversationConfig {
do_sample: false,
..Default::default()
};
let conversation_model = ConversationModel::new(conversation_config)?;
let conversation_model = ConversationModel::new(Default::default())?;
let mut conversation_manager = ConversationManager::new();
let conversation_1_id =
@ -36,7 +29,7 @@ fn main() -> failure::Fallible<()> {
let _ = conversation_manager
.get(&conversation_1_id)
.unwrap()
.add_user_input(String::from("Is it an action movie?"));
.add_user_input("Is it an action movie?");
let output = conversation_model.generate_responses(&mut conversation_manager);

View File

@ -276,7 +276,7 @@ fn download_albert_base_v2() -> failure::Fallible<()> {
Ok(())
}
fn download_dialogpt() -> failure::Fallible<()> {
fn _download_dialogpt() -> failure::Fallible<()> {
// Shared under MIT license by the Microsoft team at https://huggingface.co/microsoft/DialoGPT-medium. Modified with conversion to C-array format.
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
Gpt2ConfigResources::DIALOGPT_MEDIUM,
@ -312,7 +312,6 @@ fn main() -> failure::Fallible<()> {
let _ = download_electra_generator();
let _ = download_electra_discriminator();
let _ = download_albert_base_v2();
let _ = download_dialogpt();
Ok(())
}

View File

@ -9,6 +9,7 @@
//! - Ready-to-use NLP pipelines for:
//! - Translation
//! - Summarization
//! - Multi-turn dialogue
//! - Sentiment Analysis
//! - Named Entity Recognition
//! - Question-Answering

View File

@ -11,18 +11,39 @@
// See the License for the specific language governing permissions and
// limitations under the License.
/// # Disclaimer
/// This repository aims to facilitate research in large-scale pre-training for conversational data.
/// This toolkit contains only part of the modeling machinery needed to actually produce a model
/// weight file in a running dialog. On its own, this model provides only information about the
/// weights of various text spans; in order for a researcher to actually use it, they will need
/// to bring conversational data of their own and decode the response generation from the pretrained
/// system. Neither the author of this repository or Microsoft are responsible for any generation
/// from the 3rd party utilization of the pretrained system.
///
///
///
///
//! # Multi-turn dialogue
//! Conversation model based on Microsoft's [DialoGPT](https://github.com/microsoft/DialoGPT).
//! This pipeline allows the generation of single or multi-turn conversations between a human and a model.
//! The DialoGPT's page states that
//! > The human evaluation results indicate that the response generated from DialoGPT is comparable to human response quality
//! > under a single-turn conversation Turing test. ([DialoGPT repository](https://github.com/microsoft/DialoGPT))
//!
//!
//! The dependencies will be downloaded to the user's home directory, under ~/.cache/.rustbert/dialgpt-medium
//!
//! ```no_run
//! # fn main() -> failure::Fallible<()> {
//! use rust_bert::pipelines::conversation::{ConversationModel, ConversationManager};
//! let conversation_model = ConversationModel::new(Default::default())?;
//! let mut conversation_manager = ConversationManager::new();
//!
//! let conversation_id = conversation_manager.create("Going to the movies tonight - any suggestions?");
//! let output = conversation_model.generate_responses(&mut conversation_manager);
//! # Ok(())
//! # }
//! ```
//!
//! Example output: \
//! ```no_run
//! # let output =
//! "The Big Lebowski."
//! # ;
//! ```
//!
//! # Disclaimer
//! The authors of this repository are not responsible for any generation
//! from the 3rd party utilization of the pretrained system.
//!
use crate::common::resources::{RemoteResource, Resource};
use crate::gpt2::{
Gpt2ConfigResources, Gpt2MergesResources, Gpt2ModelResources, Gpt2VocabResources,
@ -108,37 +129,152 @@ impl Default for ConversationConfig {
}
}
#[derive(Debug)]
#[derive(Debug, Clone)]
/// Data structure keeping track of a conversation in the system. It contains past user inputs and
/// generated answers, a history of the tokens generated and a placeholder for new user inputs to be
/// processed by the system if submitted for prediction
pub struct Conversation {
/// Past user inputs that have already been processed
pub past_user_inputs: Vec<String>,
/// Past system generated responses
pub generated_responses: Vec<String>,
/// New user input that needs to be processed
pub new_user_input: Option<String>,
/// History of the tokens passed as an input and generated so far used as context for next turn generation
pub history: Vec<i64>,
}
impl Conversation {
pub fn new(text: String) -> Conversation {
/// Build a new `Conversation` with an initial user input
///
/// # Arguments
///
/// * `text` - `String` with the initial user input to start a conversation
///
/// # Example
///
/// ```no_run
/// use rust_bert::pipelines::conversation::Conversation;
///
/// let conversation = Conversation::new("Hi there!");
/// ```
pub fn new(text: &str) -> Conversation {
Conversation {
past_user_inputs: vec![],
generated_responses: vec![],
new_user_input: Some(text),
new_user_input: Some(text.to_string()),
history: vec![],
}
}
pub fn add_user_input(&mut self, text: String) -> Result<(), &'static str> {
/// Build a new `Conversation` placeholder without user input
///
/// # Example
///
/// ```no_run
/// use rust_bert::pipelines::conversation::Conversation;
///
/// let conversation = Conversation::new_empty();
/// ```
pub fn new_empty() -> Conversation {
Conversation {
past_user_inputs: vec![],
generated_responses: vec![],
new_user_input: None,
history: vec![],
}
}
/// Adds a new user input to the conversation. This method returns an error if an unprocessed
/// user input already exists
///
/// # Arguments
///
/// * `text` - `&str` with the additional user input to continue a conversation
///
/// # Example
///
/// ```no_run
/// use rust_bert::pipelines::conversation::Conversation;
///
/// let mut conversation = Conversation::new_empty();
/// conversation.add_user_input("Hi there!");
/// ```
pub fn add_user_input(&mut self, text: &str) -> Result<(), &'static str> {
if self.new_user_input.is_some() {
Err("User input already provided for this conversation")
} else {
self.new_user_input = Some(text);
self.new_user_input = Some(text.to_string());
Ok(())
}
}
/// Adds a new user input to the conversation. If an unprocessed user input already exists,
/// its contents are overwritten by the new value provided.
///
/// # Arguments
///
/// * `text` - `&str` with the additional user input to continue a conversation
///
/// # Returns
///
/// * `Option<String>` containing overwritten string if applicable
///
/// # Example
///
/// ```no_run
/// use rust_bert::pipelines::conversation::Conversation;
///
/// let mut conversation = Conversation::new_empty();
/// conversation.add_user_input("This input will not be used");
/// let unused_string = conversation.add_user_input_with_overwrite("Hi there!");
/// ```
pub fn add_user_input_with_overwrite(&mut self, text: &str) -> Option<String> {
let old_user_input = if self.new_user_input.is_some() {
self.new_user_input.clone()
} else {
None
};
self.new_user_input = Some(text.to_string());
old_user_input
}
/// Returns `true` if the conversation contains new user inputs to process
///
/// # Returns
///
/// * `bool` flag indicating if the conversation contains new inputs to process
///
/// # Example
///
/// ```no_run
/// use rust_bert::pipelines::conversation::Conversation;
///
/// let mut conversation = Conversation::new_empty();
/// let false_value = conversation.contains_new_input();
/// conversation.add_user_input("This input will not be used");
/// let true_value = conversation.contains_new_input();
/// ```
pub fn contains_new_input(&self) -> bool {
self.new_user_input.is_some()
}
/// Marks the conversation as processed and moves the user input that was up for
/// processing to the past user inputs.
///
/// # Example
///
/// ```no_run
/// use rust_bert::pipelines::conversation::Conversation;
///
/// let mut conversation = Conversation::new_empty();
/// let false_value = conversation.contains_new_input();
/// conversation.add_user_input("This input will not be used");
/// let true_value = conversation.contains_new_input();
/// conversation.mark_processed();
/// let false_value = conversation.contains_new_input();
/// assert_eq!(conversation.past_user_inputs.len(), 1usize);
/// ```
pub fn mark_processed(&mut self) {
if self.new_user_input.is_some() {
self.past_user_inputs
@ -147,14 +283,49 @@ impl Conversation {
}
}
pub fn get_last_input(&self) -> &str {
/// Returns the last user input provided (including non-processed inputs).
///
/// # Returns
///
/// * `Option<&str>` representation of the last user input provided
///
/// # Example
///
/// ```no_run
/// use rust_bert::pipelines::conversation::Conversation;
///
/// let mut conversation = Conversation::new_empty();
/// let none_value = conversation.get_last_input();
/// conversation.add_user_input("This input will not be used");
/// let last_provided_input = conversation.get_last_input();
/// assert_eq!(last_provided_input, Some("This input will not be used"));
/// ```
pub fn get_last_input(&self) -> Option<&str> {
if self.new_user_input.is_some() {
self.new_user_input.as_ref().unwrap().as_str()
Some(self.new_user_input.as_ref().unwrap().as_str())
} else {
self.past_user_inputs.last().unwrap().as_str()
if self.past_user_inputs.len() > 0 {
Some(self.past_user_inputs.last().unwrap().as_str())
} else {
None
}
}
}
/// Returns the last response generated by the system.
///
/// # Returns
///
/// * `Option<&str>` representation of the last response generated by the system.
///
/// # Example
///
/// ```no_run
/// use rust_bert::pipelines::conversation::Conversation;
///
/// let mut conversation = Conversation::new("Hi There");
/// let non_value = conversation.get_last_response();
/// ```
pub fn get_last_response(&self) -> Option<&str> {
if !self.generated_responses.is_empty() {
Some(self.generated_responses.last().unwrap().as_str())
@ -164,18 +335,50 @@ impl Conversation {
}
}
/// Data structure allowing the management of conversations and main input to the dialogue model.
/// It contains a `HashMap` of conversations with `UUID` keys
#[derive(Debug)]
pub struct ConversationManager {
conversations: HashMap<Uuid, Conversation>,
}
impl ConversationManager {
/// Build a new `ConversationManager`
///
/// # Example
///
/// ```no_run
/// use rust_bert::pipelines::conversation::ConversationManager;
///
/// let conversation_manager = ConversationManager::new();
/// ```
pub fn new() -> ConversationManager {
ConversationManager {
conversations: HashMap::new(),
}
}
/// Returns a list of the active conversations (containing new inputs to be processed by the model)
///
/// # Returns
///
/// * `(Vec<&Uuid>, Vec<&mut Conversation>)` Tuple of vectors with the active `UUID` and `Conversations`
///
/// # Example
///
/// ```no_run
/// use rust_bert::pipelines::conversation::{ConversationManager, Conversation};
///
/// let mut conversation_manager = ConversationManager::new();
///
/// let conversation = Conversation::new("Hi there!");
/// let empty_conversation = Conversation::new_empty();
/// let conversation_id = conversation_manager.add(conversation);
/// let empty_conversation_id = conversation_manager.add(empty_conversation);
///
/// let active_conversations = conversation_manager.get_active_conversations();
/// assert_eq!(active_conversations.0.len(), 1usize);
/// ```
pub fn get_active_conversations(&mut self) -> (Vec<&Uuid>, Vec<&mut Conversation>) {
let mut active_uuid = vec![];
let mut active_conversations = vec![];
@ -188,10 +391,46 @@ impl ConversationManager {
(active_uuid, active_conversations)
}
/// Returns a mutable reference to the conversation wih the provided UUID
///
/// # Arguments
///
/// * `uuid` - `&Uuid` of the conversation to retrieve
///
/// # Returns
///
/// * `Option<&mut Conversation>` Optional mutable reference to the conversation matching the UUID provided
///
/// # Example
///
/// ```no_run
/// use rust_bert::pipelines::conversation::{ConversationManager, Conversation};
///
/// let mut conversation_manager = ConversationManager::new();
///
/// let conversation = Conversation::new("Hi there!");
/// let conversation_id = conversation_manager.add(conversation);
///
/// let conversation_ref = conversation_manager.get(&conversation_id);
/// ```
pub fn get(&mut self, uuid: &Uuid) -> Option<&mut Conversation> {
self.conversations.get_mut(uuid)
}
/// Returns a HashMap containing references to all conversations stored in the manager
///
/// # Example
///
/// ```no_run
/// use rust_bert::pipelines::conversation::{ConversationManager, Conversation};
///
/// let mut conversation_manager = ConversationManager::new();
///
/// let conversation = Conversation::new("Hi there!");
/// let conversation_id = conversation_manager.add(conversation);
///
/// let all_conversations = conversation_manager.get_all();
/// ```
pub fn get_all(&mut self) -> HashMap<&Uuid, &Conversation> {
let mut output = HashMap::with_capacity(self.conversations.len());
for (uuid, conversation) in self.conversations.iter() {
@ -200,11 +439,70 @@ impl ConversationManager {
output
}
/// Creates a conversation and add it to the conversation manager
///
/// # Arguments
///
/// * `text` - `&str` string slice with an original user input
///
/// # Returns
///
/// * `Uuid` for the conversation created
///
/// # Example
///
/// ```no_run
/// use rust_bert::pipelines::conversation::{ConversationManager, Conversation};
///
/// let mut conversation_manager = ConversationManager::new();
///
/// let conversation_id = conversation_manager.create("Hi there!");
/// ```
pub fn create(&mut self, text: &str) -> Uuid {
let conversation = Conversation::new(text.to_string());
let conversation = Conversation::new(text);
self.add(conversation)
}
/// Creates an empty conversation and add it to the conversation manager
///
/// # Returns
///
/// * `Uuid` for the conversation created
///
/// # Example
///
/// ```no_run
/// use rust_bert::pipelines::conversation::{ConversationManager, Conversation};
///
/// let mut conversation_manager = ConversationManager::new();
///
/// let conversation_id = conversation_manager.create_empty();
/// ```
pub fn create_empty(&mut self) -> Uuid {
let conversation = Conversation::new_empty();
self.add(conversation)
}
/// Adds an existing conversation to the conversation manager
///
/// # Arguments
///
/// * `conversation` - `Conversation` to be added to the conversation manager
///
/// # Returns
///
/// * `Uuid` for the conversation created
///
/// # Example
///
/// ```no_run
/// use rust_bert::pipelines::conversation::{ConversationManager, Conversation};
///
/// let mut conversation_manager = ConversationManager::new();
///
/// let conversation = Conversation::new("Hi there!");
/// let conversation_id = conversation_manager.add(conversation);
/// ```
pub fn add(&mut self, conversation: Conversation) -> Uuid {
let mut uuid = Uuid::new_v4();
while self.conversations.contains_key(&uuid) {
@ -214,16 +512,59 @@ impl ConversationManager {
uuid
}
/// Deregister a conversation from the conversation manager
///
/// # Arguments
///
/// * `uuid` - `&Uuid` of the conversation to deregister from the conversation manager
///
/// # Returns
///
/// * `Option<Conversation>` deregistered conversation
///
/// # Example
///
/// ```no_run
/// use rust_bert::pipelines::conversation::{ConversationManager, Conversation};
///
/// let mut conversation_manager = ConversationManager::new();
///
/// let conversation_id = conversation_manager.create("Hi there!");
/// conversation_manager.remove(&conversation_id);
/// ```
pub fn remove(&mut self, uuid: &Uuid) -> Option<Conversation> {
self.conversations.remove(uuid)
}
pub fn clear(&mut self) {
/// Clear all conversations from the conversation manager, and returns the conversations and their
/// former UUID.
///
/// # Returns
///
/// * `HashMap<Uuid, Conversation>` deregistered conversations
///
/// # Example
///
/// ```no_run
/// use rust_bert::pipelines::conversation::{ConversationManager, Conversation};
///
/// let mut conversation_manager = ConversationManager::new();
///
/// let conversation_id = conversation_manager.create("Hi there!");
/// let conversations = conversation_manager.clear();
/// ```
pub fn clear(&mut self) -> HashMap<Uuid, Conversation> {
let mut output = HashMap::with_capacity(self.conversations.len());
for (uuid, conversation) in self.conversations.iter() {
output.insert(*uuid, conversation.clone());
}
self.conversations = HashMap::new();
output
}
}
/// # Conversation model
/// Processes a ConversationManager and generate system responses for active conversations.
pub struct ConversationModel {
model: GPT2Generator,
eos_token_id: i64,
@ -282,7 +623,7 @@ impl ConversationModel {
/// * `conversation_manager` - `&mut ConversationManager` Conversation manager keeping track of active conversations
///
/// # Returns
/// * `HashMap<&Uuid, &str>` Responses from the model for each acttive conversation, referenced by Uuid
/// * `HashMap<&Uuid, &str>` Responses from the model for each active conversation, referenced by Uuid
///
/// # Example
///
@ -362,6 +703,7 @@ impl ConversationModel {
}
fn concat_input_history(&self, inputs: Vec<Vec<i64>>, history: Vec<&Vec<i64>>) -> Tensor {
// Concatenates the history token indices with new user input
let max_len = self.model.get_config().max_length;
let pad_token = match self.model.get_pad_id() {
Some(value) => *value,
@ -435,6 +777,7 @@ impl ConversationModel {
}
fn encode_prompts(&self, texts: &[&str]) -> Vec<Vec<i64>> {
// Encode the user prompt into token ids
let tokens = self.model.get_tokenizer().tokenize_list(texts.to_vec());
tokens

View File

@ -2,6 +2,9 @@
//!
//! Based on Huggingface's pipelines, ready to use end-to-end NLP pipelines are available as part of this crate. The following capabilities are currently available:
//!
//! **Disclaimer**
//! The contributors of this repository are not responsible for any generation from the 3rd party utilization of the pretrained systems proposed herein.
//!
//! #### 1. Question Answering
//! Extractive question answering from a given question and context. DistilBERT model finetuned on SQuAD (Stanford Question Answering Dataset)
//!
@ -114,7 +117,34 @@
//! ```
//!
//!
//! #### 4. Natural Language Generation
//! #### 4. Dialogue Model
//! Conversation model based on Microsoft's [DialoGPT](https://github.com/microsoft/DialoGPT).
//! This pipeline allows the generation of single or multi-turn conversations between a human and a model.
//! The DialoGPT's page states that
//! > The human evaluation results indicate that the response generated from DialoGPT is comparable to human response quality
//! > under a single-turn conversation Turing test. ([DialoGPT repository](https://github.com/microsoft/DialoGPT))
//!
//! The model uses a `ConversationManager` to keep track of active conversations and generate responses to them.
//!
//! ```no_run
//! # fn main() -> failure::Fallible<()> {
//! use rust_bert::pipelines::conversation::{ConversationModel, ConversationManager};
//! let conversation_model = ConversationModel::new(Default::default())?;
//! let mut conversation_manager = ConversationManager::new();
//!
//! let conversation_id = conversation_manager.create("Going to the movies tonight - any suggestions?");
//! let output = conversation_model.generate_responses(&mut conversation_manager);
//! # Ok(())
//! # }
//! ```
//! Example output: \
//! ```no_run
//! # let output =
//! "The Big Lebowski."
//! # ;
//! ```
//!
//! #### 5. Natural Language Generation
//! Generate language based on a prompt. GPT2 and GPT available as base models.
//! Include techniques such as beam search, top-k and nucleus sampling, temperature setting and repetition penalty.
//! Supports batch generation of sentences from several prompts. Sequences will be left-padded with the model's padding token if present, the unknown token otherwise.
@ -145,7 +175,7 @@
//! # ;
//! ```
//!
//! #### 5. Sentiment analysis
//! #### 6. Sentiment analysis
//! Predicts the binary sentiment for a sentence. DistilBERT model finetuned on SST-2.
//! ```no_run
//! use rust_bert::pipelines::sentiment::SentimentModel;
@ -184,7 +214,7 @@
//! # ;
//! ```
//!
//! #### 6. Named Entity Recognition
//! #### 7. Named Entity Recognition
//! Extracts entities (Person, Location, Organization, Miscellaneous) from text. BERT cased large model finetuned on CoNNL03, contributed by the [MDZ Digital Library team at the Bavarian State Library](https://github.com/dbmdz)
//! ```no_run
//! use rust_bert::pipelines::ner::NERModel;

View File

@ -304,10 +304,12 @@ fn gpt2_generation_beam_search_multiple_prompts_with_padding() -> failure::Falli
}
#[test]
#[cfg_attr(not(feature = "all-tests"), ignore)]
fn dialogpt_single_multi_turn_conversation() -> failure::Fallible<()> {
// Set-up conversation model
let conversation_config = ConversationConfig {
do_sample: false,
device: Device::Cpu,
..Default::default()
};
let conversation_model = ConversationModel::new(conversation_config)?;
@ -326,7 +328,7 @@ fn dialogpt_single_multi_turn_conversation() -> failure::Fallible<()> {
let _ = conversation_manager
.get(&conversation_id)
.unwrap()
.add_user_input(String::from("Is it an action movie?"));
.add_user_input("Is it an action movie?");
let output = conversation_model.generate_responses(&mut conversation_manager);
assert_eq!(output.len(), 1);
assert_eq!(output.get(&conversation_id).unwrap(), &"It\'s a comedy.");
@ -339,10 +341,12 @@ fn dialogpt_single_multi_turn_conversation() -> failure::Fallible<()> {
}
#[test]
#[cfg_attr(not(feature = "all-tests"), ignore)]
fn dialogpt_multiple_multi_turn_conversation() -> failure::Fallible<()> {
// Set-up conversation model
let conversation_config = ConversationConfig {
do_sample: false,
device: Device::Cpu,
..Default::default()
};
let conversation_model = ConversationModel::new(conversation_config)?;
@ -366,7 +370,7 @@ fn dialogpt_multiple_multi_turn_conversation() -> failure::Fallible<()> {
let _ = conversation_manager
.get(&conversation_1_id)
.unwrap()
.add_user_input(String::from("Is it an action movie?"));
.add_user_input("Is it an action movie?");
let output = conversation_model.generate_responses(&mut conversation_manager);
assert_eq!(output.len(), 1);
assert_eq!(output.get(&conversation_1_id).unwrap(), &"It\'s a comedy.");
@ -379,10 +383,12 @@ fn dialogpt_multiple_multi_turn_conversation() -> failure::Fallible<()> {
}
#[test]
#[cfg_attr(not(feature = "all-tests"), ignore)]
fn dialogpt_multiple_multi_turn_conversation_with_conversation_deletion() -> failure::Fallible<()> {
// Set-up conversation model
let conversation_config = ConversationConfig {
do_sample: false,
device: Device::Cpu,
..Default::default()
};
let conversation_model = ConversationModel::new(conversation_config)?;
@ -407,7 +413,7 @@ fn dialogpt_multiple_multi_turn_conversation_with_conversation_deletion() -> fai
let _ = conversation_manager
.get(&conversation_2_id)
.unwrap()
.add_user_input(String::from("Why do you recommend it?"));
.add_user_input("Why do you recommend it?");
let output = conversation_model.generate_responses(&mut conversation_manager);
assert_eq!(output.len(), 1);
assert_eq!(