Merge remote-tracking branch 'origin/master' into kind_reword

# Conflicts:
#	Cargo.toml
#	src/t5/layer_norm.rs
Guillaume Becquin 2021-11-09 16:00:21 +01:00
commit 73f017d0f7
38 changed files with 200 additions and 242 deletions


@ -2,6 +2,10 @@
All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
## [Unreleased]
## Changed
- Updated to `tch` 0.6.1 (libtorch 1.10)
- (BREAKING) Simplified the generics for multiple library traits: as a rule, inputs are now `&[AsRef<str>]` or `&str` (owned types such as `Vec` and `String` are no longer accepted; see the sketch below this changelog excerpt)
## Added
- (BREAKING) Support for `bad_word_ids` generation, allowing a set of word ids to be banned for all models supporting text generation
- Support for half-precision mode for all models (reducing the memory footprint). A model can be converted to half-precision by calling the `half()` method on the `VarStore` it is currently stored in. Half-precision Torch kernels are not available for CPU (limited to CUDA devices)
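The generics simplification above is the theme of most of the file diffs that follow: methods such as `translate`, `predict`, `summarize`, and `generate` now take `&[S]` where `S: AsRef<str>` (plus `Sync` where multi-threaded tokenization is involved) instead of the previous `S: AsRef<[&'a str]>`. A minimal sketch of the new calling convention, using a hypothetical `join_inputs` function rather than any actual rust-bert API:

```rust
// Illustrative only: a free function with the same input bound the library now
// uses. Both `&[&str]` and `&[String]` satisfy `&[S] where S: AsRef<str>`.
fn join_inputs<S>(inputs: &[S]) -> String
where
    S: AsRef<str>,
{
    inputs
        .iter()
        .map(|s| s.as_ref())
        .collect::<Vec<&str>>()
        .join(" | ")
}

fn main() {
    // A slice of string literals works directly...
    let borrowed = ["first input", "second input"];
    println!("{}", join_inputs(&borrowed));

    // ...and so does a slice of owned `String`s, with no conversion to `Vec<&str>`.
    let owned = vec!["first input".to_string(), "second input".to_string()];
    println!("{}", join_inputs(&owned));
}
```

Owned containers are now passed by reference, which is why the examples and doc comments below switch from `model.translate([source_sentence], ...)` and `Some(vec![...])` to `model.translate(&[source_sentence], ...)` and `Some(&[...])`.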


@ -57,21 +57,21 @@ all-tests = []
features = ["doc-only"]
[dependencies]
rust_tokenizers = "~6.2.4"
tch = { version = "0.5.0", path = "E:/Coding/tch-rs" }
serde_json = "1.0.66"
serde = { version = "1.0.129", features = ["derive"] }
dirs = "3.0.2"
ordered-float = "2.7.0"
rust_tokenizers = "~7.0.0"
tch = "~0.6.1"
serde_json = "1.0.68"
serde = { version = "1.0.130", features = ["derive"] }
dirs = "4.0.0"
ordered-float = "2.8.0"
cached-path = "0.5.1"
lazy_static = "1.4.0"
uuid = { version = "0.8.2", features = ["v4"] }
thiserror = "1.0.26"
thiserror = "1.0.30"
half = "1.7.1"
[dev-dependencies]
anyhow = "1.0.43"
anyhow = "1.0.44"
csv = "1.1.6"
criterion = "0.3.5"
torch-sys = { version = "0.5.0", path = "E:/Coding/tch-rs/torch-sys" }
torch-sys = "~0.6.1"
tempfile = "3.2.0"


@ -71,8 +71,8 @@ This cache location defaults to `~/.cache/.rustbert`, but can be changed by sett
### Manual installation (recommended)
1. Download `libtorch` from https://pytorch.org/get-started/locally/. This package requires `v1.9.0`: if this version is no longer available on the "get started" page,
the file should be accessible by modifying the target link, for example `https://download.pytorch.org/libtorch/cu111/libtorch-shared-with-deps-1.9.0%2Bcu111.zip` for a Linux version with CUDA11.
1. Download `libtorch` from https://pytorch.org/get-started/locally/. This package requires `v1.10.0`: if this version is no longer available on the "get started" page,
the file should be accessible by modifying the target link, for example `https://download.pytorch.org/libtorch/cu111/libtorch-shared-with-deps-1.10.0%2Bcu111.zip` for a Linux version with CUDA11.
2. Extract the library to a location of your choice
3. Set the following environment variables
##### Linux:


@ -53,9 +53,9 @@ fn main() -> anyhow::Result<()> {
let source_sentence = "This sentence will be translated in multiple languages.";
let mut outputs = Vec::new();
outputs.extend(model.translate([source_sentence], Language::English, Language::French)?);
outputs.extend(model.translate([source_sentence], Language::English, Language::Spanish)?);
outputs.extend(model.translate([source_sentence], Language::English, Language::Hindi)?);
outputs.extend(model.translate(&[source_sentence], Language::English, Language::French)?);
outputs.extend(model.translate(&[source_sentence], Language::English, Language::Spanish)?);
outputs.extend(model.translate(&[source_sentence], Language::English, Language::Hindi)?);
for sentence in outputs {
println!("{}", sentence);


@ -53,9 +53,9 @@ fn main() -> anyhow::Result<()> {
let source_sentence = "This sentence will be translated in multiple languages.";
let mut outputs = Vec::new();
outputs.extend(model.translate([source_sentence], Language::English, Language::French)?);
outputs.extend(model.translate([source_sentence], Language::English, Language::Spanish)?);
outputs.extend(model.translate([source_sentence], Language::English, Language::Hindi)?);
outputs.extend(model.translate(&[source_sentence], Language::English, Language::French)?);
outputs.extend(model.translate(&[source_sentence], Language::English, Language::Spanish)?);
outputs.extend(model.translate(&[source_sentence], Language::English, Language::Hindi)?);
for sentence in outputs {
println!("{}", sentence);


@ -56,9 +56,9 @@ fn main() -> anyhow::Result<()> {
let source_sentence = "This sentence will be translated in multiple languages.";
let mut outputs = Vec::new();
outputs.extend(model.translate([source_sentence], Language::English, Language::French)?);
outputs.extend(model.translate([source_sentence], Language::English, Language::German)?);
outputs.extend(model.translate([source_sentence], Language::English, Language::Romanian)?);
outputs.extend(model.translate(&[source_sentence], Language::English, Language::French)?);
outputs.extend(model.translate(&[source_sentence], Language::English, Language::German)?);
outputs.extend(model.translate(&[source_sentence], Language::English, Language::Romanian)?);
for sentence in outputs {
println!("{}", sentence);


@ -1205,17 +1205,17 @@ impl PrivateLanguageGenerator<BartForConditionalGeneration, RobertaVocab, Robert
}
}
fn encode_prompt_text<'a, S>(
fn encode_prompt_text<S>(
&self,
prompt_text: S,
prompt_text: &[S],
max_len: i64,
pad_token_id: Option<i64>,
) -> Tensor
where
S: AsRef<[&'a str]>,
S: AsRef<str> + Sync,
{
let tokens = self._get_tokenizer().encode_list(
prompt_text.as_ref(),
prompt_text,
max_len as usize,
&TruncationStrategy::LongestFirst,
0,


@ -19,7 +19,7 @@ pub fn _mish(x: &Tensor) -> Tensor {
}
pub fn _gelu_new(x: &Tensor) -> Tensor {
x * 0.5 * (((x.pow(3.0f64) * 0.044715 + x) * ((2f64 / PI).sqrt())).tanh() + 1)
x * 0.5 * (((x.pow_tensor_scalar(3.0f64) * 0.044715 + x) * ((2f64 / PI).sqrt())).tanh() + 1)
}
pub fn _tanh(x: &Tensor) -> Tensor {


@ -81,8 +81,8 @@
//!
//! ### Manual installation (recommended)
//!
//! 1. Download `libtorch` from <https://pytorch.org/get-started/locally/>. This package requires `v1.8.1`: if this version is no longer available on the "get started" page,
//! the file should be accessible by modifying the target link, for example `https://download.pytorch.org/libtorch/cu111/libtorch-shared-with-deps-1.8.1%2Bcu111.zip` for a Linux version with CUDA11.
//! 1. Download `libtorch` from <https://pytorch.org/get-started/locally/>. This package requires `v1.10.0`: if this version is no longer available on the "get started" page,
//! the file should be accessible by modifying the target link, for example `https://download.pytorch.org/libtorch/cu111/libtorch-shared-with-deps-1.10.0%2Bcu111.zip` for a Linux version with CUDA11.
//! 2. Extract the library to a location of your choice
//! 3. Set the following environment variables
//! ##### Linux:


@ -801,17 +801,17 @@ impl PrivateLanguageGenerator<M2M100ForConditionalGeneration, M2M100Vocab, M2M10
}
}
fn encode_prompt_text<'a, S>(
fn encode_prompt_text<S>(
&self,
prompt_text: S,
prompt_text: &[S],
max_len: i64,
pad_token_id: Option<i64>,
) -> Tensor
where
S: AsRef<[&'a str]>,
S: AsRef<str> + Sync,
{
let tokens = self._get_tokenizer().encode_list(
prompt_text.as_ref(),
prompt_text,
max_len as usize,
&TruncationStrategy::LongestFirst,
0,


@ -979,17 +979,17 @@ impl PrivateLanguageGenerator<MarianForConditionalGeneration, MarianVocab, Maria
}
}
fn encode_prompt_text<'a, T>(
fn encode_prompt_text<S>(
&self,
prompt_text: T,
prompt_text: &[S],
max_len: i64,
pad_token_id: Option<i64>,
) -> Tensor
where
T: AsRef<[&'a str]>,
S: AsRef<str> + Sync,
{
let tokens = self._get_tokenizer().encode_list(
prompt_text.as_ref(),
prompt_text,
max_len as usize,
&TruncationStrategy::LongestFirst,
0,


@ -1011,17 +1011,17 @@ impl PrivateLanguageGenerator<MBartForConditionalGeneration, MBart50Vocab, MBart
}
}
fn encode_prompt_text<'a, S>(
fn encode_prompt_text<S>(
&self,
prompt_text: S,
prompt_text: &[S],
max_len: i64,
pad_token_id: Option<i64>,
) -> Tensor
where
S: AsRef<[&'a str]>,
S: AsRef<str> + Sync,
{
let tokens = self._get_tokenizer().encode_list(
prompt_text.as_ref(),
prompt_text,
max_len as usize,
&TruncationStrategy::LongestFirst,
0,


@ -774,17 +774,17 @@ impl PrivateLanguageGenerator<PegasusForConditionalGeneration, PegasusVocab, Peg
}
}
fn encode_prompt_text<'a, S>(
fn encode_prompt_text<S>(
&self,
prompt_text: S,
prompt_text: &[S],
max_len: i64,
pad_token_id: Option<i64>,
) -> Tensor
where
S: AsRef<[&'a str]>,
S: AsRef<str> + Sync,
{
let tokens = self._get_tokenizer().encode_list(
prompt_text.as_ref(),
prompt_text,
max_len as usize,
&TruncationStrategy::LongestFirst,
0,


@ -494,13 +494,16 @@ impl TokenizerOption {
}
/// Interface method
pub fn encode_list(
pub fn encode_list<S>(
&self,
text_list: &[&str],
text_list: &[S],
max_len: usize,
truncation_strategy: &TruncationStrategy,
stride: usize,
) -> Vec<TokenizedInput> {
) -> Vec<TokenizedInput>
where
S: AsRef<str> + Sync,
{
match *self {
Self::Bert(ref tokenizer) => MultiThreadedTokenizer::encode_list(
tokenizer,
@ -809,7 +812,10 @@ impl TokenizerOption {
}
/// Interface method to tokenization
pub fn tokenize_list(&self, text: &[&str]) -> Vec<Vec<String>> {
pub fn tokenize_list<S>(&self, text: &[S]) -> Vec<Vec<String>>
where
S: AsRef<str> + Sync,
{
match *self {
Self::Bert(ref tokenizer) => MultiThreadedTokenizer::tokenize_list(tokenizer, text),
Self::Roberta(ref tokenizer) => MultiThreadedTokenizer::tokenize_list(tokenizer, text),
@ -837,7 +843,7 @@ impl TokenizerOption {
/// Interface method to decoding
pub fn decode(
&self,
token_ids: Vec<i64>,
token_ids: &[i64],
skip_special_tokens: bool,
clean_up_tokenization_spaces: bool,
) -> String {
@ -964,10 +970,9 @@ impl TokenizerOption {
}
/// Interface method to convert tokens to ids
pub fn convert_tokens_to_ids<S, ST>(&self, tokens: S) -> Vec<i64>
pub fn convert_tokens_to_ids<S>(&self, tokens: &[S]) -> Vec<i64>
where
S: AsRef<[ST]>,
ST: AsRef<str>,
S: AsRef<str>,
{
match *self {
Self::Bert(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),


@ -416,22 +416,20 @@ impl Conversation {
/// let _ = conversation_manager
/// .get(&conversation_1_id)
/// .unwrap()
/// .load_from_history(history, encoded_history);
/// .load_from_history(&history, &encoded_history);
/// # Ok(())
/// # }
/// ```
pub fn load_from_history<ST, SI, STR, SIN>(&mut self, texts: ST, ids: SI)
pub fn load_from_history<S, SI>(&mut self, texts: &[S], ids: &[SI])
where
ST: AsRef<[STR]>,
SI: AsRef<[SIN]>,
STR: AsRef<str>,
SIN: AsRef<[i64]>,
S: AsRef<str>,
SI: AsRef<[i64]>,
{
for (round_text, round_ids) in texts.as_ref().iter().zip(ids.as_ref().iter()) {
for (round_text, round_ids) in texts.iter().zip(ids.iter()) {
self.append(round_text.as_ref(), round_ids.as_ref());
}
if texts.as_ref().len() / 2 == 1 {
if texts.len() / 2 == 1 {
self.history.pop();
}
}
@ -840,11 +838,11 @@ impl ConversationModel {
let generated_response = &generated_sequence[input_length - removed_padding.0..];
conversation
.generated_responses
.push(self.model.get_tokenizer().decode(
generated_response.to_vec(),
true,
true,
));
.push(
self.model
.get_tokenizer()
.decode(generated_response, true, true),
);
conversation.history.push(conversation_promp_ids);
conversation.history.push(generated_response.to_vec());
conversation.mark_processed();


@ -41,7 +41,7 @@
//! let input_context = "The dog";
//! let second_input_context = "The cat was";
//! let output = gpt2_generator.generate(
//! Some(vec![input_context, second_input_context]),
//! Some(&[input_context, second_input_context]),
//! None,
//! min_length,
//! max_length,
@ -321,16 +321,16 @@ pub(crate) mod private_generation_utils {
}
}
fn encode_prompt_text<'a, S>(
fn encode_prompt_text<S>(
&self,
prompt_text: S,
prompt_text: &[S],
max_len: i64,
pad_token_id: Option<i64>,
) -> Tensor
where
S: AsRef<[&'a str]>,
S: AsRef<str> + Sync,
{
let tokens = self._get_tokenizer().tokenize_list(prompt_text.as_ref());
let tokens = self._get_tokenizer().tokenize_list(prompt_text);
let token_ids = tokens
.into_iter()
.map(|prompt_tokens| self._get_tokenizer().convert_tokens_to_ids(&prompt_tokens))
@ -934,7 +934,7 @@ pub(crate) mod private_generation_utils {
current_length += 1;
}
let scores_output = scores_output.map(|scores_tensor| {
(scores_tensor / sentence_lengths.pow(gen_opt.length_penalty))
(scores_tensor / sentence_lengths.pow_tensor_scalar(gen_opt.length_penalty))
.iter::<f64>()
.unwrap()
.collect::<Vec<f64>>()
@ -1504,7 +1504,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
/// }
///
/// let output = gpt2_generator.generate(
/// Some(vec![input_context, second_input_context]),
/// Some(&[input_context, second_input_context]),
/// attention_mask,
/// min_length,
/// max_length,
@ -1530,9 +1530,9 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
/// ]
/// # ;
/// ```
fn generate<'a, S>(
fn generate<S>(
&self,
prompt_texts: Option<S>,
prompt_texts: Option<&[S]>,
attention_mask: Option<Tensor>,
min_length: impl Into<Option<i64>>,
max_length: impl Into<Option<i64>>,
@ -1543,7 +1543,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
output_scores: bool,
) -> Vec<GeneratedTextOutput>
where
S: AsRef<[&'a str]>,
S: AsRef<str> + Sync,
{
let indices_outputs = self.generate_indices(
prompt_texts,
@ -1561,7 +1561,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
output.push(GeneratedTextOutput {
text: self
._get_tokenizer()
.decode(generated_sequence.indices, true, true),
.decode(&generated_sequence.indices, true, true),
score: generated_sequence.score,
});
}
@ -1636,7 +1636,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
/// }
///
/// let output = gpt2_generator.generate_indices(
/// Some(vec![input_context, second_input_context]),
/// Some(&[input_context, second_input_context]),
/// attention_mask,
/// min_length,
/// max_length,
@ -1649,9 +1649,9 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
/// # Ok(())
/// # }
/// ```
fn generate_indices<'a, S>(
fn generate_indices<S>(
&self,
prompt_texts: Option<S>,
prompt_texts: Option<&[S]>,
attention_mask: Option<Tensor>,
min_length: impl Into<Option<i64>>,
max_length: impl Into<Option<i64>>,
@ -1662,7 +1662,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
output_scores: bool,
) -> Vec<GeneratedIndicesOutput>
where
S: AsRef<[&'a str]>,
S: AsRef<str> + Sync,
{
let eos_token_ids = PrivateLanguageGenerator::get_eos_ids(self).clone();
@ -1771,7 +1771,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
/// }
///
/// let output = gpt2_generator.generate_indices(
/// Some(vec![input_context, second_input_context]),
/// Some(&[input_context, second_input_context]),
/// attention_mask,
/// min_length,
/// max_length,


@ -191,9 +191,9 @@ impl NERModel {
/// # Ok(())
/// # }
/// ```
pub fn predict<'a, S>(&self, input: S) -> Vec<Vec<Entity>>
pub fn predict<S>(&self, input: &[S]) -> Vec<Vec<Entity>>
where
S: AsRef<[&'a str]>,
S: AsRef<str>,
{
self.token_classification_model
.predict(input, true, false)


@ -195,9 +195,9 @@ impl POSModel {
/// # Ok(())
/// # }
/// ```
pub fn predict<'a, S>(&self, input: S) -> Vec<Vec<POSTag>>
pub fn predict<S>(&self, input: &[S]) -> Vec<Vec<POSTag>>
where
S: AsRef<[&'a str]>,
S: AsRef<str>,
{
self.token_classification_model
.predict(input, true, false)


@ -254,13 +254,13 @@ impl SummarizationOption {
}
/// Interface method to generate() of the particular models.
pub fn generate<'a, S>(
pub fn generate<S>(
&self,
prompt_texts: Option<S>,
prompt_texts: Option<&[S]>,
attention_mask: Option<Tensor>,
) -> Vec<String>
where
S: AsRef<[&'a str]>,
S: AsRef<str> + Sync,
{
match *self {
Self::Bart(ref model) => model
@ -406,22 +406,18 @@ impl SummarizationModel {
/// # }
/// ```
/// (New sample credits: [WikiNews](https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b))
pub fn summarize<'a, S>(&self, texts: S) -> Vec<String>
pub fn summarize<S>(&self, texts: &[S]) -> Vec<String>
where
S: AsRef<[&'a str]>,
S: AsRef<str> + Sync,
{
match &self.prefix {
None => self.model.generate(Some(texts), None),
Some(prefix) => {
let texts = texts
.as_ref()
.iter()
.map(|text| format!("{}{}", prefix, text))
.map(|text| format!("{}{}", prefix, text.as_ref()))
.collect::<Vec<String>>();
self.model.generate(
Some(texts.iter().map(|x| &**x).collect::<Vec<&str>>()),
None,
)
self.model.generate(Some(&texts), None)
}
}
}


@ -238,15 +238,15 @@ impl TextGenerationOption {
}
/// Interface method to generate() of the particular models.
pub fn generate_indices<'a, S>(
pub fn generate_indices<S>(
&self,
prompt_texts: Option<S>,
prompt_texts: Option<&[S]>,
attention_mask: Option<Tensor>,
min_length: Option<i64>,
max_length: Option<i64>,
) -> Vec<Vec<i64>>
where
S: AsRef<[&'a str]>,
S: AsRef<str> + Sync,
{
match *self {
Self::GPT(ref model) => model
@ -460,9 +460,9 @@ with people, even a bishop, begging for his blessing. <eod> </s> <eos>"
/// # Ok(())
/// # }
/// ```
pub fn generate<'a, S>(&self, texts: S, prefix: impl Into<Option<&'a str>>) -> Vec<String>
pub fn generate<'a, S>(&self, texts: &[S], prefix: impl Into<Option<&'a str>>) -> Vec<String>
where
S: AsRef<[&'a str]>,
S: AsRef<str> + Sync,
{
let (prefix, prefix_length) = match (prefix.into(), &self.prefix) {
(Some(query_prefix), _) => (
@ -478,10 +478,10 @@ with people, even a bishop, begging for his blessing. <eod> </s> <eos>"
let texts = texts
.as_ref()
.iter()
.map(|text| format!("{} {}", prefix, text))
.map(|text| format!("{} {}", prefix, text.as_ref()))
.collect::<Vec<String>>();
self.model.generate_indices(
Some(texts.iter().map(|x| &**x).collect::<Vec<&str>>()),
Some(&texts),
None,
Some(self.min_length + prefix_length),
Some(self.max_length + prefix_length),
@ -493,14 +493,7 @@ with people, even a bishop, begging for his blessing. <eod> </s> <eos>"
let mut output = Vec::with_capacity(generated_indices.len());
for generated_sequence in generated_indices {
output.push(self.model.get_tokenizer().decode(
if prefix_length.is_some() {
generated_sequence
.into_iter()
.skip(prefix_length.unwrap_or(0) as usize)
.collect::<Vec<i64>>()
} else {
generated_sequence
},
&generated_sequence[prefix_length.unwrap_or(0) as usize..],
true,
true,
));


@ -776,17 +776,16 @@ impl TokenClassificationModel {
/// # Ok(())
/// # }
/// ```
pub fn predict<'a, S>(
pub fn predict<S>(
&self,
input: S,
input: &[S],
consolidate_sub_tokens: bool,
return_special: bool,
) -> Vec<Vec<Token>>
where
S: AsRef<[&'a str]>,
S: AsRef<str>,
{
let mut features: Vec<InputFeature> = input
.as_ref()
.iter()
.enumerate()
.map(|(example_index, example)| self.generate_features(example, example_index))
@ -794,7 +793,7 @@ impl TokenClassificationModel {
.collect();
let mut example_tokens_map: HashMap<usize, Vec<Token>> = HashMap::new();
for example_idx in 0..input.as_ref().len() {
for example_idx in 0..input.len() {
example_tokens_map.insert(example_idx, Vec::new());
}
let mut start = 0usize;
@ -820,7 +819,8 @@ impl TokenClassificationModel {
let labels = label_indices.get(sentence_idx);
let feature = &features[sentence_idx as usize];
let sentence_reference_flag = &feature.reference_feature;
let original_chars = input.as_ref()[feature.example_index]
let original_chars = input[feature.example_index]
.as_ref()
.chars()
.collect::<Vec<char>>();
let mut word_idx: u16 = 0;
@ -935,19 +935,19 @@ impl TokenClassificationModel {
let text = match offsets {
None => match self.tokenizer {
TokenizerOption::Bert(ref tokenizer) => {
Tokenizer::decode(tokenizer, vec![token_id], false, false)
Tokenizer::decode(tokenizer, &[token_id], false, false)
}
TokenizerOption::Roberta(ref tokenizer) => {
Tokenizer::decode(tokenizer, vec![token_id], false, false)
Tokenizer::decode(tokenizer, &[token_id], false, false)
}
TokenizerOption::XLMRoberta(ref tokenizer) => {
Tokenizer::decode(tokenizer, vec![token_id], false, false)
Tokenizer::decode(tokenizer, &[token_id], false, false)
}
TokenizerOption::Albert(ref tokenizer) => {
Tokenizer::decode(tokenizer, vec![token_id], false, false)
Tokenizer::decode(tokenizer, &[token_id], false, false)
}
TokenizerOption::XLNet(ref tokenizer) => {
Tokenizer::decode(tokenizer, vec![token_id], false, false)
Tokenizer::decode(tokenizer, &[token_id], false, false)
}
_ => panic!(
"Token classification not implemented for {:?}!",


@ -55,9 +55,13 @@
//! let source_sentence = "This sentence will be translated in multiple languages.";
//!
//! let mut outputs = Vec::new();
//! outputs.extend(model.translate([source_sentence], Language::English, Language::French)?);
//! outputs.extend(model.translate([source_sentence], Language::English, Language::Spanish)?);
//! outputs.extend(model.translate([source_sentence], Language::English, Language::Hindi)?);
//! outputs.extend(model.translate(&[source_sentence], Language::English, Language::French)?);
//! outputs.extend(model.translate(
//! &[source_sentence],
//! Language::English,
//! Language::Spanish,
//! )?);
//! outputs.extend(model.translate(&[source_sentence], Language::English, Language::Hindi)?);
//!
//! for sentence in outputs {
//! println!("{}", sentence);


@ -364,7 +364,7 @@ impl TranslationModelBuilder {
{
match (source_languages.as_slice(), target_languages.as_slice()) {
([Language::English], [Language::German]) => {
get_marian_resources!(ENGLISH2RUSSIAN)
get_marian_resources!(ENGLISH2GERMAN)
}
([Language::English], [Language::Russian]) => {
get_marian_resources!(ENGLISH2RUSSIAN)


@ -586,8 +586,7 @@ impl TranslationOption {
if !supported_source_languages.contains(source_language) {
return Err(RustBertError::ValueError(format!(
"{} not in list of supported languages: {:?}",
source_language.to_string(),
supported_source_languages
source_language, supported_source_languages
)));
}
}
@ -596,8 +595,7 @@ impl TranslationOption {
if !supported_target_languages.contains(target_language) {
return Err(RustBertError::ValueError(format!(
"{} not in list of supported languages: {:?}",
target_language.to_string(),
supported_target_languages
target_language, supported_target_languages
)));
}
}
@ -630,7 +628,7 @@ impl TranslationOption {
Some(format!(
"translate {} to {}:",
match source_language {
Some(value) => value.to_string(),
Some(value) => value,
None => {
return Err(RustBertError::ValueError(
"Missing source language for T5".to_string(),
@ -638,7 +636,7 @@ impl TranslationOption {
}
},
match target_language {
Some(value) => value.to_string(),
Some(value) => value,
None => {
return Err(RustBertError::ValueError(
"Missing target language for T5".to_string(),
@ -665,7 +663,7 @@ impl TranslationOption {
)),
if let Some(target_language) = target_language {
Some(
model._get_tokenizer().convert_tokens_to_ids([format!(
model._get_tokenizer().convert_tokens_to_ids(&[format!(
">>{}<<",
target_language.get_iso_639_1_code()
)])[0],
@ -705,9 +703,8 @@ impl TranslationOption {
if let Some(target_language) = target_language {
let language_code = target_language.get_iso_639_1_code();
Some(
model
._get_tokenizer()
.convert_tokens_to_ids([match language_code.len() {
model._get_tokenizer().convert_tokens_to_ids(&[
match language_code.len() {
2 => format!(">>{}.<<", language_code),
3 => format!(">>{}<<", language_code),
_ => {
@ -715,7 +712,8 @@ impl TranslationOption {
"Invalid ISO 639-3 code".to_string(),
));
}
}])[0],
},
])[0],
)
} else {
return Err(RustBertError::ValueError(format!(
@ -730,14 +728,14 @@ impl TranslationOption {
}
/// Interface method to generate() of the particular models.
pub fn generate<'a, S>(
pub fn generate<S>(
&self,
prompt_texts: Option<S>,
prompt_texts: Option<&[S]>,
attention_mask: Option<Tensor>,
forced_bos_token_id: Option<i64>,
) -> Vec<String>
where
S: AsRef<[&'a str]>,
S: AsRef<str> + Sync,
{
match *self {
Self::Marian(ref model) => model
@ -927,14 +925,14 @@ impl TranslationModel {
/// # Ok(())
/// # }
/// ```
pub fn translate<'a, S>(
pub fn translate<S>(
&self,
texts: S,
texts: &[S],
source_language: impl Into<Option<Language>>,
target_language: impl Into<Option<Language>>,
) -> Result<Vec<String>, RustBertError>
where
S: AsRef<[&'a str]>,
S: AsRef<str> + Sync,
{
let (prefix, forced_bos_token_id) = self.model.validate_and_get_prefix_and_forced_bos_id(
source_language.into().as_ref(),
@ -946,15 +944,10 @@ impl TranslationModel {
Ok(match prefix {
Some(value) => {
let texts = texts
.as_ref()
.iter()
.map(|&v| format!("{}{}", value, v))
.map(|v| format!("{}{}", value, v.as_ref()))
.collect::<Vec<String>>();
self.model.generate(
Some(texts.iter().map(AsRef::as_ref).collect::<Vec<&str>>()),
None,
forced_bos_token_id,
)
self.model.generate(Some(&texts), None, forced_bos_token_id)
}
None => self.model.generate(Some(texts), None, forced_bos_token_id),
})


@ -1063,17 +1063,17 @@ impl
}
}
fn encode_prompt_text<'a, S>(
fn encode_prompt_text<S>(
&self,
prompt_text: S,
prompt_text: &[S],
max_len: i64,
pad_token_id: Option<i64>,
) -> Tensor
where
S: AsRef<[&'a str]>,
S: AsRef<str> + Sync,
{
let tokens = self._get_tokenizer().encode_list(
prompt_text.as_ref(),
prompt_text,
max_len as usize,
&TruncationStrategy::LongestFirst,
0,


@ -33,10 +33,10 @@ impl T5LayerNorm {
impl Module for T5LayerNorm {
fn forward(&self, x: &Tensor) -> Tensor {
let input_type = x.kind();
let variance = x
.to_kind(Kind::Float)
.pow(2.0_f64)
.mean_dim(&[-1], true, Kind::Float);
let variance =
x.to_kind(Kind::Float)
.pow_tensor_scalar(2.0_f64)
.mean_dim(&[-1], true, Kind::Float);
let x = x * (variance + self.epsilon).rsqrt();
if input_type != Kind::Float {
(&self.weight * x).to_kind(input_type)


@ -858,17 +858,17 @@ impl PrivateLanguageGenerator<T5ForConditionalGeneration, T5Vocab, T5Tokenizer>
}
}
fn encode_prompt_text<'a, S>(
fn encode_prompt_text<S>(
&self,
prompt_text: S,
prompt_text: &[S],
max_len: i64,
pad_token_id: Option<i64>,
) -> Tensor
where
S: AsRef<[&'a str]>,
S: AsRef<str> + Sync,
{
let tokens = self._get_tokenizer().encode_list(
prompt_text.as_ref(),
prompt_text,
max_len as usize,
&TruncationStrategy::LongestFirst,
0,


@ -238,37 +238,36 @@ impl XLNetRelativeAttention {
target_mapping: Option<&Tensor>,
train: bool,
) -> (Tensor, Option<Tensor>, Option<Tensor>, Option<Tensor>) {
if let Some(g) = g {
let cat_value = if let Some(mems) = &layer_state {
if mems.prev_content.size().len() > 1 {
Some(Tensor::cat(&[&mems.prev_content, h], 0))
} else {
None
}
let cat_value = if let Some(mems) = &layer_state {
if mems.prev_content.size().len() > 1 {
Some(Tensor::cat(&[&mems.prev_content, h], 0))
} else {
None
};
let cat = match &cat_value {
Some(value) => value,
None => h,
};
}
} else {
None
};
let cat = match &cat_value {
Some(value) => value,
None => h,
};
let q_head_h = Tensor::einsum("ibh,hnd->ibnd", &[h, &self.query]);
let k_head_h = Tensor::einsum("ibh,hnd->ibnd", &[cat, &self.key]);
let v_head_h = Tensor::einsum("ibh,hnd->ibnd", &[cat, &self.value]);
let k_head_r = Tensor::einsum("ibh,hnd->ibnd", &[r, &self.pos]);
let q_head_h = Tensor::einsum("ibh,hnd->ibnd", &[h, &self.query]);
let k_head_h = Tensor::einsum("ibh,hnd->ibnd", &[cat, &self.key]);
let v_head_h = Tensor::einsum("ibh,hnd->ibnd", &[cat, &self.value]);
let k_head_r = Tensor::einsum("ibh,hnd->ibnd", &[r, &self.pos]);
let (attention_vec_h, attention_probas_h) = self.rel_attention_core(
&q_head_h,
&k_head_h,
&v_head_h,
&k_head_r,
seg_mat,
attn_mask_h,
train,
);
let output_h = self.post_attention(h, &attention_vec_h, true, train);
let (attention_vec_h, attention_probas_h) = self.rel_attention_core(
&q_head_h,
&k_head_h,
&v_head_h,
&k_head_r,
seg_mat,
attn_mask_h,
train,
);
let output_h = self.post_attention(h, &attention_vec_h, true, train);
let (output_g, attention_probas_g) = if let Some(g) = g {
let q_head_g = Tensor::einsum("ibh,hnd->ibnd", &[g, &self.query]);
let (attention_vec_g, attention_probas_g) = match target_mapping {
@ -299,44 +298,10 @@ impl XLNetRelativeAttention {
};
let output_g = self.post_attention(g, &attention_vec_g, true, train);
(
output_h,
Some(output_g),
attention_probas_h,
attention_probas_g,
)
(Some(output_g), attention_probas_g)
} else {
let cat_value = if let Some(mems) = &layer_state {
if mems.prev_content.size().len() > 1 {
Some(Tensor::cat(&[&mems.prev_content, h], 0))
} else {
None
}
} else {
None
};
let cat = match &cat_value {
Some(value) => value,
None => h,
};
let q_head_h = Tensor::einsum("ibh,hnd->ibnd", &[h, &self.query]);
let k_head_h = Tensor::einsum("ibh,hnd->ibnd", &[cat, &self.key]);
let v_head_h = Tensor::einsum("ibh,hnd->ibnd", &[cat, &self.value]);
let k_head_r = Tensor::einsum("ibh,hnd->ibnd", &[r, &self.pos]);
let (attention_vec, attention_probas) = self.rel_attention_core(
&q_head_h,
&k_head_h,
&v_head_h,
&k_head_r,
seg_mat,
attn_mask_h,
train,
);
let output_h = self.post_attention(h, &attention_vec, true, train);
(output_h, None, attention_probas, None)
}
(None, None)
};
(output_h, output_g, attention_probas_h, attention_probas_g)
}
}


@ -80,7 +80,7 @@ fn distilgpt2_lm_model() -> anyhow::Result<()> {
.get(-1)
.argmax(-1, true)
.int64_value(&[0]);
let next_word = tokenizer.decode(vec![next_word_id], true, true);
let next_word = tokenizer.decode(&[next_word_id], true, true);
assert_eq!(model_output.lm_logits.size(), vec!(1, 11, 50257));
match model_output.cache {


@ -82,7 +82,7 @@ fn gpt2_lm_model() -> anyhow::Result<()> {
.get(-1)
.argmax(-1, true)
.int64_value(&[0]);
let next_word = tokenizer.decode(vec![next_word_id], true, true);
let next_word = tokenizer.decode(&[next_word_id], true, true);
assert_eq!(model_output.lm_logits.size(), vec!(1, 4, 50257));
match model_output.cache {


@ -72,7 +72,7 @@ fn gpt_neo_lm() -> anyhow::Result<()> {
.get(-1)
.argmax(-1, true)
.int64_value(&[0]);
let next_word = tokenizer.decode(vec![next_word_id], true, true);
let next_word = tokenizer.decode(&[next_word_id], true, true);
let next_score = model_output
.lm_logits
.get(0)


@ -274,7 +274,7 @@ fn longformer_for_multiple_choice() -> anyhow::Result<()> {
let prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced.";
let inputs = ["Very positive sentence", "Second sentence input"];
let tokenized_input = tokenizer.encode_pair_list(
inputs
&inputs
.iter()
.map(|&inp| (prompt, inp))
.collect::<Vec<(&str, &str)>>(),


@ -107,9 +107,9 @@ fn m2m100_translation() -> anyhow::Result<()> {
let source_sentence = "This sentence will be translated in multiple languages.";
let mut outputs = Vec::new();
outputs.extend(model.translate([source_sentence], Language::English, Language::French)?);
outputs.extend(model.translate([source_sentence], Language::English, Language::Spanish)?);
outputs.extend(model.translate([source_sentence], Language::English, Language::Hindi)?);
outputs.extend(model.translate(&[source_sentence], Language::English, Language::French)?);
outputs.extend(model.translate(&[source_sentence], Language::English, Language::Spanish)?);
outputs.extend(model.translate(&[source_sentence], Language::English, Language::Hindi)?);
assert_eq!(outputs.len(), 3);
assert_eq!(


@ -73,9 +73,9 @@ fn mbart_translation() -> anyhow::Result<()> {
let source_sentence = "This sentence will be translated in multiple languages.";
let mut outputs = Vec::new();
outputs.extend(model.translate([source_sentence], Language::English, Language::French)?);
outputs.extend(model.translate([source_sentence], Language::English, Language::Spanish)?);
outputs.extend(model.translate([source_sentence], Language::English, Language::Hindi)?);
outputs.extend(model.translate(&[source_sentence], Language::English, Language::French)?);
outputs.extend(model.translate(&[source_sentence], Language::English, Language::Spanish)?);
outputs.extend(model.translate(&[source_sentence], Language::English, Language::Hindi)?);
assert_eq!(outputs.len(), 3);
assert_eq!(


@ -182,7 +182,7 @@ fn mobilebert_for_multiple_choice() -> anyhow::Result<()> {
let prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced.";
let inputs = ["Very positive sentence", "Second sentence input"];
let tokenized_input = tokenizer.encode_pair_list(
inputs
&inputs
.iter()
.map(|&inp| (prompt, inp))
.collect::<Vec<(&str, &str)>>(),


@ -83,7 +83,7 @@ fn openai_gpt_lm_model() -> anyhow::Result<()> {
.get(-1)
.argmax(-1, true)
.int64_value(&[0]);
let next_word = tokenizer.decode(vec![next_word_id], true, true);
let next_word = tokenizer.decode(&[next_word_id], true, true);
assert_eq!(model_output.lm_logits.size(), vec!(1, 6, 40478));
assert!(


@ -44,9 +44,9 @@ fn test_translation_t5() -> anyhow::Result<()> {
let source_sentence = "This sentence will be translated in multiple languages.";
let mut outputs = Vec::new();
outputs.extend(model.translate([source_sentence], Language::English, Language::French)?);
outputs.extend(model.translate([source_sentence], Language::English, Language::German)?);
outputs.extend(model.translate([source_sentence], Language::English, Language::Romanian)?);
outputs.extend(model.translate(&[source_sentence], Language::English, Language::French)?);
outputs.extend(model.translate(&[source_sentence], Language::English, Language::German)?);
outputs.extend(model.translate(&[source_sentence], Language::English, Language::Romanian)?);
assert_eq!(outputs.len(), 3);
assert_eq!(


@ -331,7 +331,7 @@ fn xlnet_for_multiple_choice() -> anyhow::Result<()> {
let prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced.";
let inputs = ["Very positive sentence", "Second sentence input"];
let tokenized_input = tokenizer.encode_pair_list(
inputs
&inputs
.iter()
.map(|&inp| (prompt, inp))
.collect::<Vec<(&str, &str)>>(),