Initial implementation of banned tokens detection

Guillaume B 2020-03-12 18:49:45 +01:00
parent 61721ee2b8
commit b6e186b7b0
3 changed files with 46 additions and 6 deletions

View File

@@ -27,4 +27,5 @@ tch = "0.1.6"
 serde_json = "1.0.45"
 serde = {version = "1.0.104", features = ["derive"]}
 failure = "0.1.6"
-dirs = "2.0"
+dirs = "2.0"
+itertools = "0.9.0"

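Note on the new dependency: the Itertools import added further down is not exercised by the code shown in this diff (the ngram windows use the standard slice::windows), so it is presumably intended for ngram iteration helpers such as tuple_windows. A minimal sketch of that combinator, for reference only:

use itertools::Itertools;

// tuple_windows yields overlapping fixed-size windows as tuples:
// over [0, 1, 2, 3, 4] it yields (0, 1, 2), (1, 2, 3), (2, 3, 4).
fn main() {
    let positions: Vec<i64> = (0..5).collect();
    let ngram_bounds: Vec<(i64, i64)> = positions
        .iter()
        .tuple_windows::<(_, _, _)>()
        .map(|(first, _, last)| (*first, *last))
        .collect();
    println!("{:?}", ngram_bounds); // [(0, 2), (1, 3), (2, 4)]
}
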
View File

@@ -35,9 +35,9 @@ fn main() -> failure::Fallible<()> {
     let model = OpenAIGenerator::new(vocab_path, merges_path, config_path, weights_path, device)?;
 //    let model = GPT2Generator::new(vocab_path, merges_path, config_path, weights_path, device)?;
-    let input_context = "The dog";
-    let output = model.generate(Some(input_context), 0, 40, true, false, 5, 1.0,
-                                50, 1.0, 1.1, 1.0, 0, 3, None);
+    let input_context = "Dog Dog Dog The The The The Dog Dog";
+    let output = model.generate(Some(input_context), 0, 40, true, false, 1, 1.0,
+                                50, 1.0, 1.1, 1.0, 3, 1, None);
 //    println!("{:?}", output);
 //    output.print();
     Ok(())

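For readability, here is the new call with its positional arguments labeled. Only the names through top_k are visible in the generate signature later in this diff; the names given for the remaining arguments are assumptions inferred from the surrounding code, not confirmed by it:

// Labeled sketch of the new call; names after top_k are assumptions.
let output = model.generate(
    Some("Dog Dog Dog The The The The Dog Dog"), // prompt_text: deliberately repetitive input
    0,     // min_length
    40,    // max_length
    true,  // do_sample
    false, // early_stopping
    1,     // num_beams
    1.0,   // temperature
    50,    // top_k
    1.0,   // top_p (assumed)
    1.1,   // repetition_penalty (assumed)
    1.0,   // length_penalty (assumed)
    3,     // no_repeat_ngram_size (assumed; drives the new banned tokens logic)
    1,     // num_return_sequences (assumed)
    None,  // attention_mask (assumed)
);
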
View File

@@ -21,6 +21,8 @@ use crate::{Gpt2Config, GPT2LMHeadModel};
 use crate::common::config::Config;
 use rust_tokenizers::tokenization_utils::truncate_sequences;
 use tch::kind::Kind::Int64;
+use std::collections::HashMap;
+use itertools::Itertools;
 
 pub struct OpenAIGenerator {
     model: OpenAIGPTLMHeadModel,
@@ -137,6 +139,40 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>> {
         }
     }
 
+    fn get_banned_tokens(&self, input_ids: Tensor, no_repeat_ngram_size: i64, cur_len: i64) -> Vec<Vec<i64>> {
+        if cur_len + 1 < no_repeat_ngram_size {
+            // Not enough tokens generated yet for a repeated ngram to exist
+            vec!(vec!())
+        } else {
+            // For each hypothesis, map every (n - 1)-token prefix to the
+            // tokens that have been observed following it
+            let mut generated_ngrams: Vec<HashMap<Vec<i64>, Vec<i64>>> = vec!();
+            for hypothesis_index in 0..*input_ids.size().first().unwrap() {
+                let hypothesis_input_ids = input_ids.get(hypothesis_index);
+                let mut generated_ngram: HashMap<Vec<i64>, Vec<i64>> = HashMap::new();
+                let input: Vec<i64> = (0..hypothesis_input_ids.size1().unwrap()).collect();
+                // (start, end) index pairs for each sliding ngram window
+                let ngram_indices: Vec<(i64, i64)> = input
+                    .windows(no_repeat_ngram_size as usize)
+                    .map(|win| (*win.first().unwrap(), *win.last().unwrap()))
+                    .collect();
+                for ngram in ngram_indices.into_iter() {
+                    let ngram = hypothesis_input_ids
+                        .slice(0, ngram.0, ngram.1 + 1, 1)
+                        .iter::<i64>()
+                        .unwrap()
+                        .collect::<Vec<i64>>();
+                    let key = ngram[..ngram.len() - 1].to_vec();
+                    let value = *ngram.last().unwrap();
+                    generated_ngram.entry(key).or_insert_with(Vec::new).push(value);
+                }
+                generated_ngrams.push(generated_ngram);
+            }
+            println!("{:?}", generated_ngrams);
+            // ToDo: look up the banned continuations for the current prefix
+            vec!(vec!())
+        }
+    }
+
 //    fn top_k_top_p_filtering(&self, logits: &mut Tensor, top_k: u64, top_p: f64, filter_value)
     fn generate(&self, prompt_text: Option<&str>, min_length: u64, max_length: u64, do_sample: bool, early_stopping: bool, num_beams: u64, temperature: f64, top_k: u64,
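get_banned_tokens stops short of the final step: the ngram maps are built and printed, but the function still returns an empty placeholder. A minimal sketch of the missing lookup, mirroring the usual no-repeat-ngram approach (the last n - 1 generated tokens form the prefix, and every token previously seen after that prefix is banned), reusing the names from the function above:

// Sketch only: replaces the placeholder return at the end of the else branch.
let mut banned_tokens: Vec<Vec<i64>> = vec!();
for (hypothesis_index, generated_ngram) in generated_ngrams.iter().enumerate() {
    let hypothesis_input_ids = input_ids.get(hypothesis_index as i64);
    // Prefix: the last (n - 1) tokens generated so far
    let prefix: Vec<i64> = hypothesis_input_ids
        .slice(0, cur_len + 1 - no_repeat_ngram_size, cur_len, 1)
        .iter::<i64>()
        .unwrap()
        .collect();
    // Any token already observed after this prefix would repeat an ngram
    banned_tokens.push(match generated_ngram.get(&prefix) {
        Some(banned) => banned.clone(),
        None => vec!(),
    });
}
banned_tokens
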
@@ -232,7 +268,8 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>> {
+        let input_ids_back = input_ids.copy();
 //        ToDo: change threshold to while cur_len < max_len
-        while cur_len < 1 {
+        let mut counter = 0;
+        while counter < 1 {
             let (prepared_input, prepared_past) = self.prepare_inputs_for_generation(input_ids.copy(), past, attention_mask.copy());
             let temp = self.get_model().forward_t(&Some(prepared_input), &prepared_past, &None, &None, &None, &None, false).unwrap();
             outputs = temp.0;
@@ -249,11 +286,13 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>> {
                 }
             };
+            self.get_banned_tokens(input_ids, no_repeat_ngram_size as i64, cur_len as i64);
+//            ToDo: remove when loop is fixed
+            input_ids = input_ids_back.copy();
             cur_len += 1;
+            counter += 1;
         }
     }
 }
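The banned token list is computed inside the loop but not yet applied to the model output. A typical integration, sketched here under the assumption that a (batch size, vocab size) tensor of last-step logits, called next_token_logits below, has been extracted from outputs, would mask every banned id before sampling:

// Sketch only: next_token_logits is an assumed name, not part of this diff.
let banned_tokens = self.get_banned_tokens(input_ids.copy(), no_repeat_ngram_size as i64, cur_len as i64);
for (batch_index, banned) in banned_tokens.iter().enumerate() {
    for &token_id in banned {
        // Set the banned token's logit to -inf so it can never be sampled
        let _ = next_token_logits
            .get(batch_index as i64)
            .narrow(0, token_id, 1)
            .fill_(std::f64::NEG_INFINITY);
    }
}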