mirror of
https://github.com/guillaume-be/rust-bert.git
synced 2024-10-26 14:07:25 +03:00
Update of conversation manager and conversations
This commit is contained in:
parent
c946791b58
commit
372c357463
@ -41,4 +41,5 @@ ordered-float = "1.0.2"
|
|||||||
csv = "1.1.3"
|
csv = "1.1.3"
|
||||||
reqwest = "0.10.4"
|
reqwest = "0.10.4"
|
||||||
lazy_static = "1.4.0"
|
lazy_static = "1.4.0"
|
||||||
|
uuid = {version="0.8.1", features =["v4"]}
|
||||||
tokio = { version = "0.2.21", features = ["full"] }
|
tokio = { version = "0.2.21", features = ["full"] }
|
@ -23,22 +23,21 @@ fn main() -> failure::Fallible<()> {
|
|||||||
};
|
};
|
||||||
|
|
||||||
let conversation_model = ConversationModel::new(conversation_config)?;
|
let conversation_model = ConversationModel::new(conversation_config)?;
|
||||||
|
let mut conversation_manager = ConversationManager::new();
|
||||||
|
|
||||||
let conversation = Conversation::new(String::from(
|
let conversation = Conversation::new(String::from(
|
||||||
"If you had all the money in the world, what would you buy?",
|
"If you had all the money in the world, what would you buy?",
|
||||||
));
|
));
|
||||||
|
let conversation_uuid = conversation_manager.add(conversation);
|
||||||
let mut conversation_manager = ConversationManager {
|
|
||||||
conversations: vec![conversation],
|
|
||||||
};
|
|
||||||
|
|
||||||
let output = conversation_model.generate_responses(&mut conversation_manager);
|
let output = conversation_model.generate_responses(&mut conversation_manager);
|
||||||
|
|
||||||
println!("{:?}", output);
|
println!("{:?}", output);
|
||||||
|
|
||||||
conversation_manager.conversations[0]
|
let _ = conversation_manager
|
||||||
.user_inputs
|
.get(&conversation_uuid)
|
||||||
.push(String::from("Where?"));
|
.unwrap()
|
||||||
|
.add_user_input(String::from("Where?"));
|
||||||
|
|
||||||
let output = conversation_model.generate_responses(&mut conversation_manager);
|
let output = conversation_model.generate_responses(&mut conversation_manager);
|
||||||
|
|
||||||
|
@ -32,7 +32,9 @@ use crate::pipelines::generation::{GPT2Generator, GenerateConfig, LanguageGenera
|
|||||||
use itertools::Itertools;
|
use itertools::Itertools;
|
||||||
use rust_tokenizers::preprocessing::tokenizer::tokenization_utils::truncate_sequences;
|
use rust_tokenizers::preprocessing::tokenizer::tokenization_utils::truncate_sequences;
|
||||||
use rust_tokenizers::{Tokenizer, TruncationStrategy};
|
use rust_tokenizers::{Tokenizer, TruncationStrategy};
|
||||||
|
use std::collections::HashMap;
|
||||||
use tch::{Device, Tensor};
|
use tch::{Device, Tensor};
|
||||||
|
use uuid::Uuid;
|
||||||
|
|
||||||
/// # Configuration for multi-turn classification
|
/// # Configuration for multi-turn classification
|
||||||
/// Contains information regarding the model to load, mirrors the GenerationConfig, with a
|
/// Contains information regarding the model to load, mirrors the GenerationConfig, with a
|
||||||
@ -108,43 +110,92 @@ impl Default for ConversationConfig {
|
|||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct Conversation {
|
pub struct Conversation {
|
||||||
pub user_inputs: Vec<String>,
|
pub past_user_inputs: Vec<String>,
|
||||||
pub generated_responses: Vec<String>,
|
pub generated_responses: Vec<String>,
|
||||||
|
pub new_user_input: Option<String>,
|
||||||
pub history: Vec<i64>,
|
pub history: Vec<i64>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Conversation {
|
impl Conversation {
|
||||||
pub fn new(text: String) -> Conversation {
|
pub fn new(text: String) -> Conversation {
|
||||||
Conversation {
|
Conversation {
|
||||||
user_inputs: vec![text],
|
past_user_inputs: vec![],
|
||||||
generated_responses: vec![],
|
generated_responses: vec![],
|
||||||
|
new_user_input: Some(text),
|
||||||
history: vec![],
|
history: vec![],
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn add_user_input(&mut self, text: String) -> Result<(), &'static str> {
|
||||||
|
if self.new_user_input.is_some() {
|
||||||
|
Err("User input already provided for this conversation")
|
||||||
|
} else {
|
||||||
|
self.new_user_input = Some(text);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn contains_new_input(&self) -> bool {
|
||||||
|
self.new_user_input.is_some()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn mark_processed(&mut self) {
|
||||||
|
if self.new_user_input.is_some() {
|
||||||
|
self.past_user_inputs
|
||||||
|
.push(self.new_user_input.clone().unwrap());
|
||||||
|
self.new_user_input = None;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_last_input(&self) -> &str {
|
||||||
|
if self.new_user_input.is_some() {
|
||||||
|
self.new_user_input.as_ref().unwrap().as_str()
|
||||||
|
} else {
|
||||||
|
self.past_user_inputs.last().unwrap().as_str()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_last_response(&self) -> Option<&str> {
|
||||||
|
if !self.generated_responses.is_empty() {
|
||||||
|
Some(self.generated_responses.last().unwrap().as_str())
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct ConversationManager {
|
pub struct ConversationManager {
|
||||||
pub conversations: Vec<Conversation>,
|
conversations: HashMap<Uuid, Conversation>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ConversationManager {
|
impl ConversationManager {
|
||||||
pub fn get_last_user_input(&self) -> Vec<&str> {
|
pub fn new() -> ConversationManager {
|
||||||
self.conversations
|
ConversationManager {
|
||||||
.iter()
|
conversations: HashMap::new(),
|
||||||
.map(|c| c.user_inputs.last().unwrap().as_str())
|
}
|
||||||
.collect_vec()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_last_generated_responses(&self) -> Vec<&str> {
|
pub fn get_active_conversations(&mut self) -> (Vec<&Uuid>, Vec<&mut Conversation>) {
|
||||||
self.conversations
|
let mut active_uuid = vec![];
|
||||||
.iter()
|
let mut active_conversations = vec![];
|
||||||
.map(|c| c.generated_responses.last().unwrap().as_str())
|
for (uuid, conversation) in self.conversations.iter_mut() {
|
||||||
.collect_vec()
|
if conversation.new_user_input.is_some() {
|
||||||
|
active_uuid.push(uuid);
|
||||||
|
active_conversations.push(conversation)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
(active_uuid, active_conversations)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_last_history(&self) -> Vec<&Vec<i64>> {
|
pub fn get(&mut self, uuid: &Uuid) -> Option<&mut Conversation> {
|
||||||
self.conversations.iter().map(|c| &c.history).collect_vec()
|
self.conversations.get_mut(uuid)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn add(&mut self, conversation: Conversation) -> Uuid {
|
||||||
|
let uuid = Uuid::new_v4();
|
||||||
|
self.conversations.insert(uuid, conversation);
|
||||||
|
uuid
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -227,9 +278,18 @@ impl ConversationModel {
|
|||||||
pub fn generate_responses<'a>(
|
pub fn generate_responses<'a>(
|
||||||
&self,
|
&self,
|
||||||
conversation_manager: &'a mut ConversationManager,
|
conversation_manager: &'a mut ConversationManager,
|
||||||
) -> Vec<&'a str> {
|
) -> HashMap<&'a Uuid, &'a str> {
|
||||||
let texts = conversation_manager.get_last_user_input();
|
let (active_uuid, active_conversations) = conversation_manager.get_active_conversations();
|
||||||
let history = conversation_manager.get_last_history();
|
|
||||||
|
let texts = active_conversations
|
||||||
|
.iter()
|
||||||
|
.map(|c| c.new_user_input.as_ref().unwrap().as_str())
|
||||||
|
.collect_vec();
|
||||||
|
|
||||||
|
let history = active_conversations
|
||||||
|
.iter()
|
||||||
|
.map(|c| &c.history)
|
||||||
|
.collect_vec();
|
||||||
|
|
||||||
let prompt_ids = self.encode_prompts(texts.as_slice());
|
let prompt_ids = self.encode_prompts(texts.as_slice());
|
||||||
let input_tensor = self.concat_input_history(prompt_ids, history);
|
let input_tensor = self.concat_input_history(prompt_ids, history);
|
||||||
@ -237,26 +297,35 @@ impl ConversationModel {
|
|||||||
|
|
||||||
let mut generated = self.model.generate_from_ids_and_past(input_tensor, None);
|
let mut generated = self.model.generate_from_ids_and_past(input_tensor, None);
|
||||||
self.clean_padding_indices(&mut generated);
|
self.clean_padding_indices(&mut generated);
|
||||||
for (conversation_index, generated_sequence) in generated.into_iter().enumerate() {
|
|
||||||
conversation_manager.conversations[conversation_index]
|
let mut output = HashMap::with_capacity(active_uuid.len());
|
||||||
|
|
||||||
|
for ((conversation, generated_sequence), uuid) in active_conversations
|
||||||
|
.into_iter()
|
||||||
|
.zip(generated.into_iter())
|
||||||
|
.zip(active_uuid.into_iter())
|
||||||
|
{
|
||||||
|
conversation
|
||||||
.generated_responses
|
.generated_responses
|
||||||
.push(self.model.get_tokenizer().decode(
|
.push(self.model.get_tokenizer().decode(
|
||||||
generated_sequence[input_length..].to_vec(),
|
generated_sequence[input_length..].to_vec(),
|
||||||
true,
|
true,
|
||||||
true,
|
true,
|
||||||
));
|
));
|
||||||
conversation_manager.conversations[conversation_index].history = generated_sequence;
|
conversation.history = generated_sequence;
|
||||||
|
conversation.mark_processed();
|
||||||
|
output.insert(uuid, conversation.get_last_response().unwrap());
|
||||||
}
|
}
|
||||||
conversation_manager.get_last_generated_responses()
|
output
|
||||||
}
|
}
|
||||||
|
|
||||||
fn clean_padding_indices(&self, history: &mut Vec<Vec<i64>>) {
|
fn clean_padding_indices(&self, model_output: &mut Vec<Vec<i64>>) {
|
||||||
// In case inputs are sent as batch, this cleans the padding indices in the history for shorter outputs
|
// In case inputs are sent as batch, this cleans the padding indices in the history for shorter outputs
|
||||||
let pad_token = match self.model.get_pad_id() {
|
let pad_token = match self.model.get_pad_id() {
|
||||||
Some(value) => *value,
|
Some(value) => *value,
|
||||||
None => self.eos_token_id,
|
None => self.eos_token_id,
|
||||||
};
|
};
|
||||||
for sequence_history in history {
|
for sequence_history in model_output {
|
||||||
let index = sequence_history
|
let index = sequence_history
|
||||||
.iter()
|
.iter()
|
||||||
.rev()
|
.rev()
|
||||||
|
Loading…
Reference in New Issue
Block a user