Update of conversation manager and conversations

Guillaume B 2020-06-27 18:27:43 +02:00
parent c946791b58
commit 372c357463
3 changed files with 100 additions and 31 deletions

Cargo.toml

@@ -41,4 +41,5 @@ ordered-float = "1.0.2"
 csv = "1.1.3"
 reqwest = "0.10.4"
 lazy_static = "1.4.0"
+uuid = {version="0.8.1", features =["v4"]}
 tokio = { version = "0.2.21", features = ["full"] }
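Note: the new `uuid` dependency is pulled in with the `v4` feature so that conversations can be keyed by randomly generated identifiers (see `ConversationManager::add` below). A minimal sketch of what the feature enables, assuming uuid 0.8:

    use uuid::Uuid;

    fn main() {
        // `Uuid::new_v4` is gated behind the "v4" feature enabled above.
        let id = Uuid::new_v4();
        println!("{}", id);
    }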

examples/conversation.rs

@@ -23,22 +23,21 @@ fn main() -> failure::Fallible<()> {
     };
     let conversation_model = ConversationModel::new(conversation_config)?;
+    let mut conversation_manager = ConversationManager::new();
     let conversation = Conversation::new(String::from(
         "If you had all the money in the world, what would you buy?",
     ));
-    let mut conversation_manager = ConversationManager {
-        conversations: vec![conversation],
-    };
+    let conversation_uuid = conversation_manager.add(conversation);
     let output = conversation_model.generate_responses(&mut conversation_manager);
     println!("{:?}", output);
-    conversation_manager.conversations[0]
-        .user_inputs
-        .push(String::from("Where?"));
+    let _ = conversation_manager
+        .get(&conversation_uuid)
+        .unwrap()
+        .add_user_input(String::from("Where?"));
     let output = conversation_model.generate_responses(&mut conversation_manager);
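Note: the example prints the returned map with Debug formatting. Since `generate_responses` now returns a `HashMap` keyed by conversation UUID rather than a positional `Vec`, callers can route each response back to its conversation explicitly. A small sketch of consuming the output (variable names follow the example above):

    let output = conversation_model.generate_responses(&mut conversation_manager);
    for (uuid, response) in output.iter() {
        println!("{}: {}", uuid, response);
    }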

src/pipelines/conversation.rs

@@ -32,7 +32,9 @@ use crate::pipelines::generation::{GPT2Generator, GenerateConfig, LanguageGenera
 use itertools::Itertools;
 use rust_tokenizers::preprocessing::tokenizer::tokenization_utils::truncate_sequences;
 use rust_tokenizers::{Tokenizer, TruncationStrategy};
+use std::collections::HashMap;
 use tch::{Device, Tensor};
+use uuid::Uuid;
 
 /// # Configuration for multi-turn classification
 /// Contains information regarding the model to load, mirrors the GenerationConfig, with a
@@ -108,43 +110,92 @@ impl Default for ConversationConfig {
 #[derive(Debug)]
 pub struct Conversation {
-    pub user_inputs: Vec<String>,
+    pub past_user_inputs: Vec<String>,
     pub generated_responses: Vec<String>,
+    pub new_user_input: Option<String>,
     pub history: Vec<i64>,
 }
 
 impl Conversation {
     pub fn new(text: String) -> Conversation {
         Conversation {
-            user_inputs: vec![text],
+            past_user_inputs: vec![],
             generated_responses: vec![],
+            new_user_input: Some(text),
             history: vec![],
         }
     }
+
+    pub fn add_user_input(&mut self, text: String) -> Result<(), &'static str> {
+        if self.new_user_input.is_some() {
+            Err("User input already provided for this conversation")
+        } else {
+            self.new_user_input = Some(text);
+            Ok(())
+        }
+    }
+
+    pub fn contains_new_input(&self) -> bool {
+        self.new_user_input.is_some()
+    }
+
+    pub fn mark_processed(&mut self) {
+        if self.new_user_input.is_some() {
+            self.past_user_inputs
+                .push(self.new_user_input.clone().unwrap());
+            self.new_user_input = None;
+        }
+    }
+
+    pub fn get_last_input(&self) -> &str {
+        if self.new_user_input.is_some() {
+            self.new_user_input.as_ref().unwrap().as_str()
+        } else {
+            self.past_user_inputs.last().unwrap().as_str()
+        }
+    }
+
+    pub fn get_last_response(&self) -> Option<&str> {
+        if !self.generated_responses.is_empty() {
+            Some(self.generated_responses.last().unwrap().as_str())
+        } else {
+            None
+        }
+    }
 }
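Note: `Conversation` now behaves as a small state machine around `new_user_input`: at most one pending input is held at a time, and `mark_processed` moves it into `past_user_inputs`. A sketch of the lifecycle, assuming the types above are in scope:

    let mut conversation = Conversation::new(String::from("Hello"));
    assert!(conversation.contains_new_input());
    // A second input is rejected while one is still pending.
    assert!(conversation.add_user_input(String::from("Again")).is_err());
    conversation.mark_processed();
    assert!(!conversation.contains_new_input());
    // With the pending slot cleared, the next turn is accepted.
    assert!(conversation.add_user_input(String::from("Again")).is_ok());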
 
 #[derive(Debug)]
 pub struct ConversationManager {
-    pub conversations: Vec<Conversation>,
+    conversations: HashMap<Uuid, Conversation>,
 }
 
 impl ConversationManager {
-    pub fn get_last_user_input(&self) -> Vec<&str> {
-        self.conversations
-            .iter()
-            .map(|c| c.user_inputs.last().unwrap().as_str())
-            .collect_vec()
-    }
+    pub fn new() -> ConversationManager {
+        ConversationManager {
+            conversations: HashMap::new(),
+        }
+    }
 
-    pub fn get_last_generated_responses(&self) -> Vec<&str> {
-        self.conversations
-            .iter()
-            .map(|c| c.generated_responses.last().unwrap().as_str())
-            .collect_vec()
-    }
+    pub fn get_active_conversations(&mut self) -> (Vec<&Uuid>, Vec<&mut Conversation>) {
+        let mut active_uuid = vec![];
+        let mut active_conversations = vec![];
+        for (uuid, conversation) in self.conversations.iter_mut() {
+            if conversation.new_user_input.is_some() {
+                active_uuid.push(uuid);
+                active_conversations.push(conversation)
+            }
+        }
+        (active_uuid, active_conversations)
+    }
 
-    pub fn get_last_history(&self) -> Vec<&Vec<i64>> {
-        self.conversations.iter().map(|c| &c.history).collect_vec()
-    }
+    pub fn get(&mut self, uuid: &Uuid) -> Option<&mut Conversation> {
+        self.conversations.get_mut(uuid)
+    }
+
+    pub fn add(&mut self, conversation: Conversation) -> Uuid {
+        let uuid = Uuid::new_v4();
+        self.conversations.insert(uuid, conversation);
+        uuid
+    }
 }
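Note: `ConversationManager` now owns its conversations in a `HashMap<Uuid, Conversation>`, so `add` hands back the key needed for later lookups. A usage sketch (identifiers as defined above):

    let mut manager = ConversationManager::new();
    let uuid = manager.add(Conversation::new(String::from("Hi there")));
    // Only conversations holding a pending input are considered active.
    let (active_uuid, active_conversations) = manager.get_active_conversations();
    assert_eq!(active_uuid.len(), 1);
    assert_eq!(active_conversations.len(), 1);
    // Mutable access by key, e.g. to modify a conversation directly.
    manager.get(&uuid).unwrap().mark_processed();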
@@ -227,9 +278,18 @@ impl ConversationModel {
     pub fn generate_responses<'a>(
         &self,
         conversation_manager: &'a mut ConversationManager,
-    ) -> Vec<&'a str> {
-        let texts = conversation_manager.get_last_user_input();
-        let history = conversation_manager.get_last_history();
+    ) -> HashMap<&'a Uuid, &'a str> {
+        let (active_uuid, active_conversations) = conversation_manager.get_active_conversations();
+        let texts = active_conversations
+            .iter()
+            .map(|c| c.new_user_input.as_ref().unwrap().as_str())
+            .collect_vec();
+        let history = active_conversations
+            .iter()
+            .map(|c| &c.history)
+            .collect_vec();
 
         let prompt_ids = self.encode_prompts(texts.as_slice());
         let input_tensor = self.concat_input_history(prompt_ids, history);
@@ -237,26 +297,35 @@
         let mut generated = self.model.generate_from_ids_and_past(input_tensor, None);
         self.clean_padding_indices(&mut generated);
 
-        for (conversation_index, generated_sequence) in generated.into_iter().enumerate() {
-            conversation_manager.conversations[conversation_index]
+        let mut output = HashMap::with_capacity(active_uuid.len());
+        for ((conversation, generated_sequence), uuid) in active_conversations
+            .into_iter()
+            .zip(generated.into_iter())
+            .zip(active_uuid.into_iter())
+        {
+            conversation
                 .generated_responses
                 .push(self.model.get_tokenizer().decode(
                     generated_sequence[input_length..].to_vec(),
                     true,
                     true,
                 ));
-            conversation_manager.conversations[conversation_index].history = generated_sequence;
+            conversation.history = generated_sequence;
+            conversation.mark_processed();
+            output.insert(uuid, conversation.get_last_response().unwrap());
         }
-        conversation_manager.get_last_generated_responses()
+        output
     }
 
-    fn clean_padding_indices(&self, history: &mut Vec<Vec<i64>>) {
+    fn clean_padding_indices(&self, model_output: &mut Vec<Vec<i64>>) {
         // In case inputs are sent as batch, this cleans the padding indices in the history for shorter outputs
         let pad_token = match self.model.get_pad_id() {
             Some(value) => *value,
             None => self.eos_token_id,
         };
-        for sequence_history in history {
+        for sequence_history in model_output {
             let index = sequence_history
                 .iter()
                 .rev()
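Note: the tail of `clean_padding_indices` is cut off in this view; from its comment and the reverse iteration above, it locates the trailing padding in each generated sequence. A hypothetical standalone equivalent of that scan, not part of the commit:

    // Illustrative helper: drops trailing pad tokens from a padded sequence.
    fn trim_trailing_padding(sequence: &[i64], pad_token: i64) -> &[i64] {
        let trailing = sequence
            .iter()
            .rev()
            .take_while(|&&token| token == pad_token)
            .count();
        &sequence[..sequence.len() - trailing]
    }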