mirror of
https://github.com/guillaume-be/rust-bert.git
synced 2024-10-26 14:07:25 +03:00
Merge remote-tracking branch 'origin/master' into kind_reword
# Conflicts: # Cargo.toml # src/t5/layer_norm.rs
This commit is contained in:
commit
73f017d0f7
@ -2,6 +2,10 @@
|
||||
All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
|
||||
|
||||
## [Unreleased]
|
||||
## Changed
|
||||
- Updated to `tch` 1.6.0 (libtorch 1.10)
|
||||
- (BREAKING) Simplified the generics for multiple library traits taking as a rule `&[AsRef<str>]` or `&str` as inputs (no longer accepts owned types `Vec` and `String`)
|
||||
|
||||
## Added
|
||||
- (BREAKING) Support for `bad_word_ids` generation, allowing to ban a set of word ids for all model supporting text generation
|
||||
- Support for half-precision mode for all models (reducing memory footprint). A model can be converted to half-precision by calling the `half()` method on the `VarStore` is it currently stored in. Half-precision Torch kernels are not available for CPU (limited to CUDA devices)
|
||||
|
18
Cargo.toml
18
Cargo.toml
@ -57,21 +57,21 @@ all-tests = []
|
||||
features = ["doc-only"]
|
||||
|
||||
[dependencies]
|
||||
rust_tokenizers = "~6.2.4"
|
||||
tch = { version = "0.5.0", path = "E:/Coding/tch-rs" }
|
||||
serde_json = "1.0.66"
|
||||
serde = { version = "1.0.129", features = ["derive"] }
|
||||
dirs = "3.0.2"
|
||||
ordered-float = "2.7.0"
|
||||
rust_tokenizers = "~7.0.0"
|
||||
tch = "~0.6.1"
|
||||
serde_json = "1.0.68"
|
||||
serde = { version = "1.0.130", features = ["derive"] }
|
||||
dirs = "4.0.0"
|
||||
ordered-float = "2.8.0"
|
||||
cached-path = "0.5.1"
|
||||
lazy_static = "1.4.0"
|
||||
uuid = { version = "0.8.2", features = ["v4"] }
|
||||
thiserror = "1.0.26"
|
||||
thiserror = "1.0.30"
|
||||
half = "1.7.1"
|
||||
|
||||
[dev-dependencies]
|
||||
anyhow = "1.0.43"
|
||||
anyhow = "1.0.44"
|
||||
csv = "1.1.6"
|
||||
criterion = "0.3.5"
|
||||
torch-sys = { version = "0.5.0", path = "E:/Coding/tch-rs/torch-sys" }
|
||||
torch-sys = "~0.6.1"
|
||||
tempfile = "3.2.0"
|
||||
|
@ -71,8 +71,8 @@ This cache location defaults to `~/.cache/.rustbert`, but can be changed by sett
|
||||
|
||||
### Manual installation (recommended)
|
||||
|
||||
1. Download `libtorch` from https://pytorch.org/get-started/locally/. This package requires `v1.9.0`: if this version is no longer available on the "get started" page,
|
||||
the file should be accessible by modifying the target link, for example `https://download.pytorch.org/libtorch/cu111/libtorch-shared-with-deps-1.9.0%2Bcu111.zip` for a Linux version with CUDA11.
|
||||
1. Download `libtorch` from https://pytorch.org/get-started/locally/. This package requires `v1.10.0`: if this version is no longer available on the "get started" page,
|
||||
the file should be accessible by modifying the target link, for example `https://download.pytorch.org/libtorch/cu111/libtorch-shared-with-deps-1.10.0%2Bcu111.zip` for a Linux version with CUDA11.
|
||||
2. Extract the library to a location of your choice
|
||||
3. Set the following environment variables
|
||||
##### Linux:
|
||||
|
@ -53,9 +53,9 @@ fn main() -> anyhow::Result<()> {
|
||||
let source_sentence = "This sentence will be translated in multiple languages.";
|
||||
|
||||
let mut outputs = Vec::new();
|
||||
outputs.extend(model.translate([source_sentence], Language::English, Language::French)?);
|
||||
outputs.extend(model.translate([source_sentence], Language::English, Language::Spanish)?);
|
||||
outputs.extend(model.translate([source_sentence], Language::English, Language::Hindi)?);
|
||||
outputs.extend(model.translate(&[source_sentence], Language::English, Language::French)?);
|
||||
outputs.extend(model.translate(&[source_sentence], Language::English, Language::Spanish)?);
|
||||
outputs.extend(model.translate(&[source_sentence], Language::English, Language::Hindi)?);
|
||||
|
||||
for sentence in outputs {
|
||||
println!("{}", sentence);
|
||||
|
@ -53,9 +53,9 @@ fn main() -> anyhow::Result<()> {
|
||||
let source_sentence = "This sentence will be translated in multiple languages.";
|
||||
|
||||
let mut outputs = Vec::new();
|
||||
outputs.extend(model.translate([source_sentence], Language::English, Language::French)?);
|
||||
outputs.extend(model.translate([source_sentence], Language::English, Language::Spanish)?);
|
||||
outputs.extend(model.translate([source_sentence], Language::English, Language::Hindi)?);
|
||||
outputs.extend(model.translate(&[source_sentence], Language::English, Language::French)?);
|
||||
outputs.extend(model.translate(&[source_sentence], Language::English, Language::Spanish)?);
|
||||
outputs.extend(model.translate(&[source_sentence], Language::English, Language::Hindi)?);
|
||||
|
||||
for sentence in outputs {
|
||||
println!("{}", sentence);
|
||||
|
@ -56,9 +56,9 @@ fn main() -> anyhow::Result<()> {
|
||||
let source_sentence = "This sentence will be translated in multiple languages.";
|
||||
|
||||
let mut outputs = Vec::new();
|
||||
outputs.extend(model.translate([source_sentence], Language::English, Language::French)?);
|
||||
outputs.extend(model.translate([source_sentence], Language::English, Language::German)?);
|
||||
outputs.extend(model.translate([source_sentence], Language::English, Language::Romanian)?);
|
||||
outputs.extend(model.translate(&[source_sentence], Language::English, Language::French)?);
|
||||
outputs.extend(model.translate(&[source_sentence], Language::English, Language::German)?);
|
||||
outputs.extend(model.translate(&[source_sentence], Language::English, Language::Romanian)?);
|
||||
|
||||
for sentence in outputs {
|
||||
println!("{}", sentence);
|
||||
|
@ -1205,17 +1205,17 @@ impl PrivateLanguageGenerator<BartForConditionalGeneration, RobertaVocab, Robert
|
||||
}
|
||||
}
|
||||
|
||||
fn encode_prompt_text<'a, S>(
|
||||
fn encode_prompt_text<S>(
|
||||
&self,
|
||||
prompt_text: S,
|
||||
prompt_text: &[S],
|
||||
max_len: i64,
|
||||
pad_token_id: Option<i64>,
|
||||
) -> Tensor
|
||||
where
|
||||
S: AsRef<[&'a str]>,
|
||||
S: AsRef<str> + Sync,
|
||||
{
|
||||
let tokens = self._get_tokenizer().encode_list(
|
||||
prompt_text.as_ref(),
|
||||
prompt_text,
|
||||
max_len as usize,
|
||||
&TruncationStrategy::LongestFirst,
|
||||
0,
|
||||
|
@ -19,7 +19,7 @@ pub fn _mish(x: &Tensor) -> Tensor {
|
||||
}
|
||||
|
||||
pub fn _gelu_new(x: &Tensor) -> Tensor {
|
||||
x * 0.5 * (((x.pow(3.0f64) * 0.044715 + x) * ((2f64 / PI).sqrt())).tanh() + 1)
|
||||
x * 0.5 * (((x.pow_tensor_scalar(3.0f64) * 0.044715 + x) * ((2f64 / PI).sqrt())).tanh() + 1)
|
||||
}
|
||||
|
||||
pub fn _tanh(x: &Tensor) -> Tensor {
|
||||
|
@ -81,8 +81,8 @@
|
||||
//!
|
||||
//! ### Manual installation (recommended)
|
||||
//!
|
||||
//! 1. Download `libtorch` from <https://pytorch.org/get-started/locally/>. This package requires `v1.8.1`: if this version is no longer available on the "get started" page,
|
||||
//! the file should be accessible by modifying the target link, for example `https://download.pytorch.org/libtorch/cu111/libtorch-shared-with-deps-1.8.1%2Bcu111.zip` for a Linux version with CUDA11.
|
||||
//! 1. Download `libtorch` from <https://pytorch.org/get-started/locally/>. This package requires `v1.10.0`: if this version is no longer available on the "get started" page,
|
||||
//! the file should be accessible by modifying the target link, for example `https://download.pytorch.org/libtorch/cu111/libtorch-shared-with-deps-1.10.0%2Bcu111.zip` for a Linux version with CUDA11.
|
||||
//! 2. Extract the library to a location of your choice
|
||||
//! 3. Set the following environment variables
|
||||
//! ##### Linux:
|
||||
|
@ -801,17 +801,17 @@ impl PrivateLanguageGenerator<M2M100ForConditionalGeneration, M2M100Vocab, M2M10
|
||||
}
|
||||
}
|
||||
|
||||
fn encode_prompt_text<'a, S>(
|
||||
fn encode_prompt_text<S>(
|
||||
&self,
|
||||
prompt_text: S,
|
||||
prompt_text: &[S],
|
||||
max_len: i64,
|
||||
pad_token_id: Option<i64>,
|
||||
) -> Tensor
|
||||
where
|
||||
S: AsRef<[&'a str]>,
|
||||
S: AsRef<str> + Sync,
|
||||
{
|
||||
let tokens = self._get_tokenizer().encode_list(
|
||||
prompt_text.as_ref(),
|
||||
prompt_text,
|
||||
max_len as usize,
|
||||
&TruncationStrategy::LongestFirst,
|
||||
0,
|
||||
|
@ -979,17 +979,17 @@ impl PrivateLanguageGenerator<MarianForConditionalGeneration, MarianVocab, Maria
|
||||
}
|
||||
}
|
||||
|
||||
fn encode_prompt_text<'a, T>(
|
||||
fn encode_prompt_text<S>(
|
||||
&self,
|
||||
prompt_text: T,
|
||||
prompt_text: &[S],
|
||||
max_len: i64,
|
||||
pad_token_id: Option<i64>,
|
||||
) -> Tensor
|
||||
where
|
||||
T: AsRef<[&'a str]>,
|
||||
S: AsRef<str> + Sync,
|
||||
{
|
||||
let tokens = self._get_tokenizer().encode_list(
|
||||
prompt_text.as_ref(),
|
||||
prompt_text,
|
||||
max_len as usize,
|
||||
&TruncationStrategy::LongestFirst,
|
||||
0,
|
||||
|
@ -1011,17 +1011,17 @@ impl PrivateLanguageGenerator<MBartForConditionalGeneration, MBart50Vocab, MBart
|
||||
}
|
||||
}
|
||||
|
||||
fn encode_prompt_text<'a, S>(
|
||||
fn encode_prompt_text<S>(
|
||||
&self,
|
||||
prompt_text: S,
|
||||
prompt_text: &[S],
|
||||
max_len: i64,
|
||||
pad_token_id: Option<i64>,
|
||||
) -> Tensor
|
||||
where
|
||||
S: AsRef<[&'a str]>,
|
||||
S: AsRef<str> + Sync,
|
||||
{
|
||||
let tokens = self._get_tokenizer().encode_list(
|
||||
prompt_text.as_ref(),
|
||||
prompt_text,
|
||||
max_len as usize,
|
||||
&TruncationStrategy::LongestFirst,
|
||||
0,
|
||||
|
@ -774,17 +774,17 @@ impl PrivateLanguageGenerator<PegasusForConditionalGeneration, PegasusVocab, Peg
|
||||
}
|
||||
}
|
||||
|
||||
fn encode_prompt_text<'a, S>(
|
||||
fn encode_prompt_text<S>(
|
||||
&self,
|
||||
prompt_text: S,
|
||||
prompt_text: &[S],
|
||||
max_len: i64,
|
||||
pad_token_id: Option<i64>,
|
||||
) -> Tensor
|
||||
where
|
||||
S: AsRef<[&'a str]>,
|
||||
S: AsRef<str> + Sync,
|
||||
{
|
||||
let tokens = self._get_tokenizer().encode_list(
|
||||
prompt_text.as_ref(),
|
||||
prompt_text,
|
||||
max_len as usize,
|
||||
&TruncationStrategy::LongestFirst,
|
||||
0,
|
||||
|
@ -494,13 +494,16 @@ impl TokenizerOption {
|
||||
}
|
||||
|
||||
/// Interface method
|
||||
pub fn encode_list(
|
||||
pub fn encode_list<S>(
|
||||
&self,
|
||||
text_list: &[&str],
|
||||
text_list: &[S],
|
||||
max_len: usize,
|
||||
truncation_strategy: &TruncationStrategy,
|
||||
stride: usize,
|
||||
) -> Vec<TokenizedInput> {
|
||||
) -> Vec<TokenizedInput>
|
||||
where
|
||||
S: AsRef<str> + Sync,
|
||||
{
|
||||
match *self {
|
||||
Self::Bert(ref tokenizer) => MultiThreadedTokenizer::encode_list(
|
||||
tokenizer,
|
||||
@ -809,7 +812,10 @@ impl TokenizerOption {
|
||||
}
|
||||
|
||||
/// Interface method to tokenization
|
||||
pub fn tokenize_list(&self, text: &[&str]) -> Vec<Vec<String>> {
|
||||
pub fn tokenize_list<S>(&self, text: &[S]) -> Vec<Vec<String>>
|
||||
where
|
||||
S: AsRef<str> + Sync,
|
||||
{
|
||||
match *self {
|
||||
Self::Bert(ref tokenizer) => MultiThreadedTokenizer::tokenize_list(tokenizer, text),
|
||||
Self::Roberta(ref tokenizer) => MultiThreadedTokenizer::tokenize_list(tokenizer, text),
|
||||
@ -837,7 +843,7 @@ impl TokenizerOption {
|
||||
/// Interface method to decoding
|
||||
pub fn decode(
|
||||
&self,
|
||||
token_ids: Vec<i64>,
|
||||
token_ids: &[i64],
|
||||
skip_special_tokens: bool,
|
||||
clean_up_tokenization_spaces: bool,
|
||||
) -> String {
|
||||
@ -964,10 +970,9 @@ impl TokenizerOption {
|
||||
}
|
||||
|
||||
/// Interface method to convert tokens to ids
|
||||
pub fn convert_tokens_to_ids<S, ST>(&self, tokens: S) -> Vec<i64>
|
||||
pub fn convert_tokens_to_ids<S>(&self, tokens: &[S]) -> Vec<i64>
|
||||
where
|
||||
S: AsRef<[ST]>,
|
||||
ST: AsRef<str>,
|
||||
S: AsRef<str>,
|
||||
{
|
||||
match *self {
|
||||
Self::Bert(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),
|
||||
|
@ -416,22 +416,20 @@ impl Conversation {
|
||||
/// let _ = conversation_manager
|
||||
/// .get(&conversation_1_id)
|
||||
/// .unwrap()
|
||||
/// .load_from_history(history, encoded_history);
|
||||
/// .load_from_history(&history, &encoded_history);
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
pub fn load_from_history<ST, SI, STR, SIN>(&mut self, texts: ST, ids: SI)
|
||||
pub fn load_from_history<S, SI>(&mut self, texts: &[S], ids: &[SI])
|
||||
where
|
||||
ST: AsRef<[STR]>,
|
||||
SI: AsRef<[SIN]>,
|
||||
STR: AsRef<str>,
|
||||
SIN: AsRef<[i64]>,
|
||||
S: AsRef<str>,
|
||||
SI: AsRef<[i64]>,
|
||||
{
|
||||
for (round_text, round_ids) in texts.as_ref().iter().zip(ids.as_ref().iter()) {
|
||||
for (round_text, round_ids) in texts.iter().zip(ids.iter()) {
|
||||
self.append(round_text.as_ref(), round_ids.as_ref());
|
||||
}
|
||||
|
||||
if texts.as_ref().len() / 2 == 1 {
|
||||
if texts.len() / 2 == 1 {
|
||||
self.history.pop();
|
||||
}
|
||||
}
|
||||
@ -840,11 +838,11 @@ impl ConversationModel {
|
||||
let generated_response = &generated_sequence[input_length - removed_padding.0..];
|
||||
conversation
|
||||
.generated_responses
|
||||
.push(self.model.get_tokenizer().decode(
|
||||
generated_response.to_vec(),
|
||||
true,
|
||||
true,
|
||||
));
|
||||
.push(
|
||||
self.model
|
||||
.get_tokenizer()
|
||||
.decode(generated_response, true, true),
|
||||
);
|
||||
conversation.history.push(conversation_promp_ids);
|
||||
conversation.history.push(generated_response.to_vec());
|
||||
conversation.mark_processed();
|
||||
|
@ -41,7 +41,7 @@
|
||||
//! let input_context = "The dog";
|
||||
//! let second_input_context = "The cat was";
|
||||
//! let output = gpt2_generator.generate(
|
||||
//! Some(vec![input_context, second_input_context]),
|
||||
//! Some(&[input_context, second_input_context]),
|
||||
//! None,
|
||||
//! min_length,
|
||||
//! max_length,
|
||||
@ -321,16 +321,16 @@ pub(crate) mod private_generation_utils {
|
||||
}
|
||||
}
|
||||
|
||||
fn encode_prompt_text<'a, S>(
|
||||
fn encode_prompt_text<S>(
|
||||
&self,
|
||||
prompt_text: S,
|
||||
prompt_text: &[S],
|
||||
max_len: i64,
|
||||
pad_token_id: Option<i64>,
|
||||
) -> Tensor
|
||||
where
|
||||
S: AsRef<[&'a str]>,
|
||||
S: AsRef<str> + Sync,
|
||||
{
|
||||
let tokens = self._get_tokenizer().tokenize_list(prompt_text.as_ref());
|
||||
let tokens = self._get_tokenizer().tokenize_list(prompt_text);
|
||||
let token_ids = tokens
|
||||
.into_iter()
|
||||
.map(|prompt_tokens| self._get_tokenizer().convert_tokens_to_ids(&prompt_tokens))
|
||||
@ -934,7 +934,7 @@ pub(crate) mod private_generation_utils {
|
||||
current_length += 1;
|
||||
}
|
||||
let scores_output = scores_output.map(|scores_tensor| {
|
||||
(scores_tensor / sentence_lengths.pow(gen_opt.length_penalty))
|
||||
(scores_tensor / sentence_lengths.pow_tensor_scalar(gen_opt.length_penalty))
|
||||
.iter::<f64>()
|
||||
.unwrap()
|
||||
.collect::<Vec<f64>>()
|
||||
@ -1504,7 +1504,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
|
||||
/// }
|
||||
///
|
||||
/// let output = gpt2_generator.generate(
|
||||
/// Some(vec![input_context, second_input_context]),
|
||||
/// Some(&[input_context, second_input_context]),
|
||||
/// attention_mask,
|
||||
/// min_length,
|
||||
/// max_length,
|
||||
@ -1530,9 +1530,9 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
|
||||
/// ]
|
||||
/// # ;
|
||||
/// ```
|
||||
fn generate<'a, S>(
|
||||
fn generate<S>(
|
||||
&self,
|
||||
prompt_texts: Option<S>,
|
||||
prompt_texts: Option<&[S]>,
|
||||
attention_mask: Option<Tensor>,
|
||||
min_length: impl Into<Option<i64>>,
|
||||
max_length: impl Into<Option<i64>>,
|
||||
@ -1543,7 +1543,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
|
||||
output_scores: bool,
|
||||
) -> Vec<GeneratedTextOutput>
|
||||
where
|
||||
S: AsRef<[&'a str]>,
|
||||
S: AsRef<str> + Sync,
|
||||
{
|
||||
let indices_outputs = self.generate_indices(
|
||||
prompt_texts,
|
||||
@ -1561,7 +1561,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
|
||||
output.push(GeneratedTextOutput {
|
||||
text: self
|
||||
._get_tokenizer()
|
||||
.decode(generated_sequence.indices, true, true),
|
||||
.decode(&generated_sequence.indices, true, true),
|
||||
score: generated_sequence.score,
|
||||
});
|
||||
}
|
||||
@ -1636,7 +1636,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
|
||||
/// }
|
||||
///
|
||||
/// let output = gpt2_generator.generate_indices(
|
||||
/// Some(vec![input_context, second_input_context]),
|
||||
/// Some(&[input_context, second_input_context]),
|
||||
/// attention_mask,
|
||||
/// min_length,
|
||||
/// max_length,
|
||||
@ -1649,9 +1649,9 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
fn generate_indices<'a, S>(
|
||||
fn generate_indices<S>(
|
||||
&self,
|
||||
prompt_texts: Option<S>,
|
||||
prompt_texts: Option<&[S]>,
|
||||
attention_mask: Option<Tensor>,
|
||||
min_length: impl Into<Option<i64>>,
|
||||
max_length: impl Into<Option<i64>>,
|
||||
@ -1662,7 +1662,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
|
||||
output_scores: bool,
|
||||
) -> Vec<GeneratedIndicesOutput>
|
||||
where
|
||||
S: AsRef<[&'a str]>,
|
||||
S: AsRef<str> + Sync,
|
||||
{
|
||||
let eos_token_ids = PrivateLanguageGenerator::get_eos_ids(self).clone();
|
||||
|
||||
@ -1771,7 +1771,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>:
|
||||
/// }
|
||||
///
|
||||
/// let output = gpt2_generator.generate_indices(
|
||||
/// Some(vec![input_context, second_input_context]),
|
||||
/// Some(&[input_context, second_input_context]),
|
||||
/// attention_mask,
|
||||
/// min_length,
|
||||
/// max_length,
|
||||
|
@ -191,9 +191,9 @@ impl NERModel {
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
pub fn predict<'a, S>(&self, input: S) -> Vec<Vec<Entity>>
|
||||
pub fn predict<S>(&self, input: &[S]) -> Vec<Vec<Entity>>
|
||||
where
|
||||
S: AsRef<[&'a str]>,
|
||||
S: AsRef<str>,
|
||||
{
|
||||
self.token_classification_model
|
||||
.predict(input, true, false)
|
||||
|
@ -195,9 +195,9 @@ impl POSModel {
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
pub fn predict<'a, S>(&self, input: S) -> Vec<Vec<POSTag>>
|
||||
pub fn predict<S>(&self, input: &[S]) -> Vec<Vec<POSTag>>
|
||||
where
|
||||
S: AsRef<[&'a str]>,
|
||||
S: AsRef<str>,
|
||||
{
|
||||
self.token_classification_model
|
||||
.predict(input, true, false)
|
||||
|
@ -254,13 +254,13 @@ impl SummarizationOption {
|
||||
}
|
||||
|
||||
/// Interface method to generate() of the particular models.
|
||||
pub fn generate<'a, S>(
|
||||
pub fn generate<S>(
|
||||
&self,
|
||||
prompt_texts: Option<S>,
|
||||
prompt_texts: Option<&[S]>,
|
||||
attention_mask: Option<Tensor>,
|
||||
) -> Vec<String>
|
||||
where
|
||||
S: AsRef<[&'a str]>,
|
||||
S: AsRef<str> + Sync,
|
||||
{
|
||||
match *self {
|
||||
Self::Bart(ref model) => model
|
||||
@ -406,22 +406,18 @@ impl SummarizationModel {
|
||||
/// # }
|
||||
/// ```
|
||||
/// (New sample credits: [WikiNews](https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b))
|
||||
pub fn summarize<'a, S>(&self, texts: S) -> Vec<String>
|
||||
pub fn summarize<S>(&self, texts: &[S]) -> Vec<String>
|
||||
where
|
||||
S: AsRef<[&'a str]>,
|
||||
S: AsRef<str> + Sync,
|
||||
{
|
||||
match &self.prefix {
|
||||
None => self.model.generate(Some(texts), None),
|
||||
Some(prefix) => {
|
||||
let texts = texts
|
||||
.as_ref()
|
||||
.iter()
|
||||
.map(|text| format!("{}{}", prefix, text))
|
||||
.map(|text| format!("{}{}", prefix, text.as_ref()))
|
||||
.collect::<Vec<String>>();
|
||||
self.model.generate(
|
||||
Some(texts.iter().map(|x| &**x).collect::<Vec<&str>>()),
|
||||
None,
|
||||
)
|
||||
self.model.generate(Some(&texts), None)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -238,15 +238,15 @@ impl TextGenerationOption {
|
||||
}
|
||||
|
||||
/// Interface method to generate() of the particular models.
|
||||
pub fn generate_indices<'a, S>(
|
||||
pub fn generate_indices<S>(
|
||||
&self,
|
||||
prompt_texts: Option<S>,
|
||||
prompt_texts: Option<&[S]>,
|
||||
attention_mask: Option<Tensor>,
|
||||
min_length: Option<i64>,
|
||||
max_length: Option<i64>,
|
||||
) -> Vec<Vec<i64>>
|
||||
where
|
||||
S: AsRef<[&'a str]>,
|
||||
S: AsRef<str> + Sync,
|
||||
{
|
||||
match *self {
|
||||
Self::GPT(ref model) => model
|
||||
@ -460,9 +460,9 @@ with people, even a bishop, begging for his blessing. <eod> </s> <eos>"
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
pub fn generate<'a, S>(&self, texts: S, prefix: impl Into<Option<&'a str>>) -> Vec<String>
|
||||
pub fn generate<'a, S>(&self, texts: &[S], prefix: impl Into<Option<&'a str>>) -> Vec<String>
|
||||
where
|
||||
S: AsRef<[&'a str]>,
|
||||
S: AsRef<str> + Sync,
|
||||
{
|
||||
let (prefix, prefix_length) = match (prefix.into(), &self.prefix) {
|
||||
(Some(query_prefix), _) => (
|
||||
@ -478,10 +478,10 @@ with people, even a bishop, begging for his blessing. <eod> </s> <eos>"
|
||||
let texts = texts
|
||||
.as_ref()
|
||||
.iter()
|
||||
.map(|text| format!("{} {}", prefix, text))
|
||||
.map(|text| format!("{} {}", prefix, text.as_ref()))
|
||||
.collect::<Vec<String>>();
|
||||
self.model.generate_indices(
|
||||
Some(texts.iter().map(|x| &**x).collect::<Vec<&str>>()),
|
||||
Some(&texts),
|
||||
None,
|
||||
Some(self.min_length + prefix_length),
|
||||
Some(self.max_length + prefix_length),
|
||||
@ -493,14 +493,7 @@ with people, even a bishop, begging for his blessing. <eod> </s> <eos>"
|
||||
let mut output = Vec::with_capacity(generated_indices.len());
|
||||
for generated_sequence in generated_indices {
|
||||
output.push(self.model.get_tokenizer().decode(
|
||||
if prefix_length.is_some() {
|
||||
generated_sequence
|
||||
.into_iter()
|
||||
.skip(prefix_length.unwrap_or(0) as usize)
|
||||
.collect::<Vec<i64>>()
|
||||
} else {
|
||||
generated_sequence
|
||||
},
|
||||
&generated_sequence[prefix_length.unwrap_or(0) as usize..],
|
||||
true,
|
||||
true,
|
||||
));
|
||||
|
@ -776,17 +776,16 @@ impl TokenClassificationModel {
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
pub fn predict<'a, S>(
|
||||
pub fn predict<S>(
|
||||
&self,
|
||||
input: S,
|
||||
input: &[S],
|
||||
consolidate_sub_tokens: bool,
|
||||
return_special: bool,
|
||||
) -> Vec<Vec<Token>>
|
||||
where
|
||||
S: AsRef<[&'a str]>,
|
||||
S: AsRef<str>,
|
||||
{
|
||||
let mut features: Vec<InputFeature> = input
|
||||
.as_ref()
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(example_index, example)| self.generate_features(example, example_index))
|
||||
@ -794,7 +793,7 @@ impl TokenClassificationModel {
|
||||
.collect();
|
||||
|
||||
let mut example_tokens_map: HashMap<usize, Vec<Token>> = HashMap::new();
|
||||
for example_idx in 0..input.as_ref().len() {
|
||||
for example_idx in 0..input.len() {
|
||||
example_tokens_map.insert(example_idx, Vec::new());
|
||||
}
|
||||
let mut start = 0usize;
|
||||
@ -820,7 +819,8 @@ impl TokenClassificationModel {
|
||||
let labels = label_indices.get(sentence_idx);
|
||||
let feature = &features[sentence_idx as usize];
|
||||
let sentence_reference_flag = &feature.reference_feature;
|
||||
let original_chars = input.as_ref()[feature.example_index]
|
||||
let original_chars = input[feature.example_index]
|
||||
.as_ref()
|
||||
.chars()
|
||||
.collect::<Vec<char>>();
|
||||
let mut word_idx: u16 = 0;
|
||||
@ -935,19 +935,19 @@ impl TokenClassificationModel {
|
||||
let text = match offsets {
|
||||
None => match self.tokenizer {
|
||||
TokenizerOption::Bert(ref tokenizer) => {
|
||||
Tokenizer::decode(tokenizer, vec![token_id], false, false)
|
||||
Tokenizer::decode(tokenizer, &[token_id], false, false)
|
||||
}
|
||||
TokenizerOption::Roberta(ref tokenizer) => {
|
||||
Tokenizer::decode(tokenizer, vec![token_id], false, false)
|
||||
Tokenizer::decode(tokenizer, &[token_id], false, false)
|
||||
}
|
||||
TokenizerOption::XLMRoberta(ref tokenizer) => {
|
||||
Tokenizer::decode(tokenizer, vec![token_id], false, false)
|
||||
Tokenizer::decode(tokenizer, &[token_id], false, false)
|
||||
}
|
||||
TokenizerOption::Albert(ref tokenizer) => {
|
||||
Tokenizer::decode(tokenizer, vec![token_id], false, false)
|
||||
Tokenizer::decode(tokenizer, &[token_id], false, false)
|
||||
}
|
||||
TokenizerOption::XLNet(ref tokenizer) => {
|
||||
Tokenizer::decode(tokenizer, vec![token_id], false, false)
|
||||
Tokenizer::decode(tokenizer, &[token_id], false, false)
|
||||
}
|
||||
_ => panic!(
|
||||
"Token classification not implemented for {:?}!",
|
||||
|
@ -55,9 +55,13 @@
|
||||
//! let source_sentence = "This sentence will be translated in multiple languages.";
|
||||
//!
|
||||
//! let mut outputs = Vec::new();
|
||||
//! outputs.extend(model.translate([source_sentence], Language::English, Language::French)?);
|
||||
//! outputs.extend(model.translate([source_sentence], Language::English, Language::Spanish)?);
|
||||
//! outputs.extend(model.translate([source_sentence], Language::English, Language::Hindi)?);
|
||||
//! outputs.extend(model.translate(&[source_sentence], Language::English, Language::French)?);
|
||||
//! outputs.extend(model.translate(
|
||||
//! &[source_sentence],
|
||||
//! Language::English,
|
||||
//! Language::Spanish,
|
||||
//! )?);
|
||||
//! outputs.extend(model.translate(&[source_sentence], Language::English, Language::Hindi)?);
|
||||
//!
|
||||
//! for sentence in outputs {
|
||||
//! println!("{}", sentence);
|
||||
|
@ -364,7 +364,7 @@ impl TranslationModelBuilder {
|
||||
{
|
||||
match (source_languages.as_slice(), target_languages.as_slice()) {
|
||||
([Language::English], [Language::German]) => {
|
||||
get_marian_resources!(ENGLISH2RUSSIAN)
|
||||
get_marian_resources!(ENGLISH2GERMAN)
|
||||
}
|
||||
([Language::English], [Language::Russian]) => {
|
||||
get_marian_resources!(ENGLISH2RUSSIAN)
|
||||
|
@ -586,8 +586,7 @@ impl TranslationOption {
|
||||
if !supported_source_languages.contains(source_language) {
|
||||
return Err(RustBertError::ValueError(format!(
|
||||
"{} not in list of supported languages: {:?}",
|
||||
source_language.to_string(),
|
||||
supported_source_languages
|
||||
source_language, supported_source_languages
|
||||
)));
|
||||
}
|
||||
}
|
||||
@ -596,8 +595,7 @@ impl TranslationOption {
|
||||
if !supported_target_languages.contains(target_language) {
|
||||
return Err(RustBertError::ValueError(format!(
|
||||
"{} not in list of supported languages: {:?}",
|
||||
target_language.to_string(),
|
||||
supported_target_languages
|
||||
target_language, supported_target_languages
|
||||
)));
|
||||
}
|
||||
}
|
||||
@ -630,7 +628,7 @@ impl TranslationOption {
|
||||
Some(format!(
|
||||
"translate {} to {}:",
|
||||
match source_language {
|
||||
Some(value) => value.to_string(),
|
||||
Some(value) => value,
|
||||
None => {
|
||||
return Err(RustBertError::ValueError(
|
||||
"Missing source language for T5".to_string(),
|
||||
@ -638,7 +636,7 @@ impl TranslationOption {
|
||||
}
|
||||
},
|
||||
match target_language {
|
||||
Some(value) => value.to_string(),
|
||||
Some(value) => value,
|
||||
None => {
|
||||
return Err(RustBertError::ValueError(
|
||||
"Missing target language for T5".to_string(),
|
||||
@ -665,7 +663,7 @@ impl TranslationOption {
|
||||
)),
|
||||
if let Some(target_language) = target_language {
|
||||
Some(
|
||||
model._get_tokenizer().convert_tokens_to_ids([format!(
|
||||
model._get_tokenizer().convert_tokens_to_ids(&[format!(
|
||||
">>{}<<",
|
||||
target_language.get_iso_639_1_code()
|
||||
)])[0],
|
||||
@ -705,9 +703,8 @@ impl TranslationOption {
|
||||
if let Some(target_language) = target_language {
|
||||
let language_code = target_language.get_iso_639_1_code();
|
||||
Some(
|
||||
model
|
||||
._get_tokenizer()
|
||||
.convert_tokens_to_ids([match language_code.len() {
|
||||
model._get_tokenizer().convert_tokens_to_ids(&[
|
||||
match language_code.len() {
|
||||
2 => format!(">>{}.<<", language_code),
|
||||
3 => format!(">>{}<<", language_code),
|
||||
_ => {
|
||||
@ -715,7 +712,8 @@ impl TranslationOption {
|
||||
"Invalid ISO 639-3 code".to_string(),
|
||||
));
|
||||
}
|
||||
}])[0],
|
||||
},
|
||||
])[0],
|
||||
)
|
||||
} else {
|
||||
return Err(RustBertError::ValueError(format!(
|
||||
@ -730,14 +728,14 @@ impl TranslationOption {
|
||||
}
|
||||
|
||||
/// Interface method to generate() of the particular models.
|
||||
pub fn generate<'a, S>(
|
||||
pub fn generate<S>(
|
||||
&self,
|
||||
prompt_texts: Option<S>,
|
||||
prompt_texts: Option<&[S]>,
|
||||
attention_mask: Option<Tensor>,
|
||||
forced_bos_token_id: Option<i64>,
|
||||
) -> Vec<String>
|
||||
where
|
||||
S: AsRef<[&'a str]>,
|
||||
S: AsRef<str> + Sync,
|
||||
{
|
||||
match *self {
|
||||
Self::Marian(ref model) => model
|
||||
@ -927,14 +925,14 @@ impl TranslationModel {
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
pub fn translate<'a, S>(
|
||||
pub fn translate<S>(
|
||||
&self,
|
||||
texts: S,
|
||||
texts: &[S],
|
||||
source_language: impl Into<Option<Language>>,
|
||||
target_language: impl Into<Option<Language>>,
|
||||
) -> Result<Vec<String>, RustBertError>
|
||||
where
|
||||
S: AsRef<[&'a str]>,
|
||||
S: AsRef<str> + Sync,
|
||||
{
|
||||
let (prefix, forced_bos_token_id) = self.model.validate_and_get_prefix_and_forced_bos_id(
|
||||
source_language.into().as_ref(),
|
||||
@ -946,15 +944,10 @@ impl TranslationModel {
|
||||
Ok(match prefix {
|
||||
Some(value) => {
|
||||
let texts = texts
|
||||
.as_ref()
|
||||
.iter()
|
||||
.map(|&v| format!("{}{}", value, v))
|
||||
.map(|v| format!("{}{}", value, v.as_ref()))
|
||||
.collect::<Vec<String>>();
|
||||
self.model.generate(
|
||||
Some(texts.iter().map(AsRef::as_ref).collect::<Vec<&str>>()),
|
||||
None,
|
||||
forced_bos_token_id,
|
||||
)
|
||||
self.model.generate(Some(&texts), None, forced_bos_token_id)
|
||||
}
|
||||
None => self.model.generate(Some(texts), None, forced_bos_token_id),
|
||||
})
|
||||
|
@ -1063,17 +1063,17 @@ impl
|
||||
}
|
||||
}
|
||||
|
||||
fn encode_prompt_text<'a, S>(
|
||||
fn encode_prompt_text<S>(
|
||||
&self,
|
||||
prompt_text: S,
|
||||
prompt_text: &[S],
|
||||
max_len: i64,
|
||||
pad_token_id: Option<i64>,
|
||||
) -> Tensor
|
||||
where
|
||||
S: AsRef<[&'a str]>,
|
||||
S: AsRef<str> + Sync,
|
||||
{
|
||||
let tokens = self._get_tokenizer().encode_list(
|
||||
prompt_text.as_ref(),
|
||||
prompt_text,
|
||||
max_len as usize,
|
||||
&TruncationStrategy::LongestFirst,
|
||||
0,
|
||||
|
@ -33,9 +33,9 @@ impl T5LayerNorm {
|
||||
impl Module for T5LayerNorm {
|
||||
fn forward(&self, x: &Tensor) -> Tensor {
|
||||
let input_type = x.kind();
|
||||
let variance = x
|
||||
.to_kind(Kind::Float)
|
||||
.pow(2.0_f64)
|
||||
let variance =
|
||||
x.to_kind(Kind::Float)
|
||||
.pow_tensor_scalar(2.0_f64)
|
||||
.mean_dim(&[-1], true, Kind::Float);
|
||||
let x = x * (variance + self.epsilon).rsqrt();
|
||||
if input_type != Kind::Float {
|
||||
|
@ -858,17 +858,17 @@ impl PrivateLanguageGenerator<T5ForConditionalGeneration, T5Vocab, T5Tokenizer>
|
||||
}
|
||||
}
|
||||
|
||||
fn encode_prompt_text<'a, S>(
|
||||
fn encode_prompt_text<S>(
|
||||
&self,
|
||||
prompt_text: S,
|
||||
prompt_text: &[S],
|
||||
max_len: i64,
|
||||
pad_token_id: Option<i64>,
|
||||
) -> Tensor
|
||||
where
|
||||
S: AsRef<[&'a str]>,
|
||||
S: AsRef<str> + Sync,
|
||||
{
|
||||
let tokens = self._get_tokenizer().encode_list(
|
||||
prompt_text.as_ref(),
|
||||
prompt_text,
|
||||
max_len as usize,
|
||||
&TruncationStrategy::LongestFirst,
|
||||
0,
|
||||
|
@ -238,7 +238,6 @@ impl XLNetRelativeAttention {
|
||||
target_mapping: Option<&Tensor>,
|
||||
train: bool,
|
||||
) -> (Tensor, Option<Tensor>, Option<Tensor>, Option<Tensor>) {
|
||||
if let Some(g) = g {
|
||||
let cat_value = if let Some(mems) = &layer_state {
|
||||
if mems.prev_content.size().len() > 1 {
|
||||
Some(Tensor::cat(&[&mems.prev_content, h], 0))
|
||||
@ -252,7 +251,6 @@ impl XLNetRelativeAttention {
|
||||
Some(value) => value,
|
||||
None => h,
|
||||
};
|
||||
|
||||
let q_head_h = Tensor::einsum("ibh,hnd->ibnd", &[h, &self.query]);
|
||||
let k_head_h = Tensor::einsum("ibh,hnd->ibnd", &[cat, &self.key]);
|
||||
let v_head_h = Tensor::einsum("ibh,hnd->ibnd", &[cat, &self.value]);
|
||||
@ -267,8 +265,9 @@ impl XLNetRelativeAttention {
|
||||
attn_mask_h,
|
||||
train,
|
||||
);
|
||||
|
||||
let output_h = self.post_attention(h, &attention_vec_h, true, train);
|
||||
|
||||
let (output_g, attention_probas_g) = if let Some(g) = g {
|
||||
let q_head_g = Tensor::einsum("ibh,hnd->ibnd", &[g, &self.query]);
|
||||
|
||||
let (attention_vec_g, attention_probas_g) = match target_mapping {
|
||||
@ -299,44 +298,10 @@ impl XLNetRelativeAttention {
|
||||
};
|
||||
|
||||
let output_g = self.post_attention(g, &attention_vec_g, true, train);
|
||||
(
|
||||
output_h,
|
||||
Some(output_g),
|
||||
attention_probas_h,
|
||||
attention_probas_g,
|
||||
)
|
||||
(Some(output_g), attention_probas_g)
|
||||
} else {
|
||||
let cat_value = if let Some(mems) = &layer_state {
|
||||
if mems.prev_content.size().len() > 1 {
|
||||
Some(Tensor::cat(&[&mems.prev_content, h], 0))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
} else {
|
||||
None
|
||||
(None, None)
|
||||
};
|
||||
let cat = match &cat_value {
|
||||
Some(value) => value,
|
||||
None => h,
|
||||
};
|
||||
|
||||
let q_head_h = Tensor::einsum("ibh,hnd->ibnd", &[h, &self.query]);
|
||||
let k_head_h = Tensor::einsum("ibh,hnd->ibnd", &[cat, &self.key]);
|
||||
let v_head_h = Tensor::einsum("ibh,hnd->ibnd", &[cat, &self.value]);
|
||||
let k_head_r = Tensor::einsum("ibh,hnd->ibnd", &[r, &self.pos]);
|
||||
|
||||
let (attention_vec, attention_probas) = self.rel_attention_core(
|
||||
&q_head_h,
|
||||
&k_head_h,
|
||||
&v_head_h,
|
||||
&k_head_r,
|
||||
seg_mat,
|
||||
attn_mask_h,
|
||||
train,
|
||||
);
|
||||
|
||||
let output_h = self.post_attention(h, &attention_vec, true, train);
|
||||
(output_h, None, attention_probas, None)
|
||||
}
|
||||
(output_h, output_g, attention_probas_h, attention_probas_g)
|
||||
}
|
||||
}
|
||||
|
@ -80,7 +80,7 @@ fn distilgpt2_lm_model() -> anyhow::Result<()> {
|
||||
.get(-1)
|
||||
.argmax(-1, true)
|
||||
.int64_value(&[0]);
|
||||
let next_word = tokenizer.decode(vec![next_word_id], true, true);
|
||||
let next_word = tokenizer.decode(&[next_word_id], true, true);
|
||||
|
||||
assert_eq!(model_output.lm_logits.size(), vec!(1, 11, 50257));
|
||||
match model_output.cache {
|
||||
|
@ -82,7 +82,7 @@ fn gpt2_lm_model() -> anyhow::Result<()> {
|
||||
.get(-1)
|
||||
.argmax(-1, true)
|
||||
.int64_value(&[0]);
|
||||
let next_word = tokenizer.decode(vec![next_word_id], true, true);
|
||||
let next_word = tokenizer.decode(&[next_word_id], true, true);
|
||||
|
||||
assert_eq!(model_output.lm_logits.size(), vec!(1, 4, 50257));
|
||||
match model_output.cache {
|
||||
|
@ -72,7 +72,7 @@ fn gpt_neo_lm() -> anyhow::Result<()> {
|
||||
.get(-1)
|
||||
.argmax(-1, true)
|
||||
.int64_value(&[0]);
|
||||
let next_word = tokenizer.decode(vec![next_word_id], true, true);
|
||||
let next_word = tokenizer.decode(&[next_word_id], true, true);
|
||||
let next_score = model_output
|
||||
.lm_logits
|
||||
.get(0)
|
||||
|
@ -274,7 +274,7 @@ fn longformer_for_multiple_choice() -> anyhow::Result<()> {
|
||||
let prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced.";
|
||||
let inputs = ["Very positive sentence", "Second sentence input"];
|
||||
let tokenized_input = tokenizer.encode_pair_list(
|
||||
inputs
|
||||
&inputs
|
||||
.iter()
|
||||
.map(|&inp| (prompt, inp))
|
||||
.collect::<Vec<(&str, &str)>>(),
|
||||
|
@ -107,9 +107,9 @@ fn m2m100_translation() -> anyhow::Result<()> {
|
||||
let source_sentence = "This sentence will be translated in multiple languages.";
|
||||
|
||||
let mut outputs = Vec::new();
|
||||
outputs.extend(model.translate([source_sentence], Language::English, Language::French)?);
|
||||
outputs.extend(model.translate([source_sentence], Language::English, Language::Spanish)?);
|
||||
outputs.extend(model.translate([source_sentence], Language::English, Language::Hindi)?);
|
||||
outputs.extend(model.translate(&[source_sentence], Language::English, Language::French)?);
|
||||
outputs.extend(model.translate(&[source_sentence], Language::English, Language::Spanish)?);
|
||||
outputs.extend(model.translate(&[source_sentence], Language::English, Language::Hindi)?);
|
||||
|
||||
assert_eq!(outputs.len(), 3);
|
||||
assert_eq!(
|
||||
|
@ -73,9 +73,9 @@ fn mbart_translation() -> anyhow::Result<()> {
|
||||
let source_sentence = "This sentence will be translated in multiple languages.";
|
||||
|
||||
let mut outputs = Vec::new();
|
||||
outputs.extend(model.translate([source_sentence], Language::English, Language::French)?);
|
||||
outputs.extend(model.translate([source_sentence], Language::English, Language::Spanish)?);
|
||||
outputs.extend(model.translate([source_sentence], Language::English, Language::Hindi)?);
|
||||
outputs.extend(model.translate(&[source_sentence], Language::English, Language::French)?);
|
||||
outputs.extend(model.translate(&[source_sentence], Language::English, Language::Spanish)?);
|
||||
outputs.extend(model.translate(&[source_sentence], Language::English, Language::Hindi)?);
|
||||
|
||||
assert_eq!(outputs.len(), 3);
|
||||
assert_eq!(
|
||||
|
@ -182,7 +182,7 @@ fn mobilebert_for_multiple_choice() -> anyhow::Result<()> {
|
||||
let prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced.";
|
||||
let inputs = ["Very positive sentence", "Second sentence input"];
|
||||
let tokenized_input = tokenizer.encode_pair_list(
|
||||
inputs
|
||||
&inputs
|
||||
.iter()
|
||||
.map(|&inp| (prompt, inp))
|
||||
.collect::<Vec<(&str, &str)>>(),
|
||||
|
@ -83,7 +83,7 @@ fn openai_gpt_lm_model() -> anyhow::Result<()> {
|
||||
.get(-1)
|
||||
.argmax(-1, true)
|
||||
.int64_value(&[0]);
|
||||
let next_word = tokenizer.decode(vec![next_word_id], true, true);
|
||||
let next_word = tokenizer.decode(&[next_word_id], true, true);
|
||||
|
||||
assert_eq!(model_output.lm_logits.size(), vec!(1, 6, 40478));
|
||||
assert!(
|
||||
|
@ -44,9 +44,9 @@ fn test_translation_t5() -> anyhow::Result<()> {
|
||||
let source_sentence = "This sentence will be translated in multiple languages.";
|
||||
|
||||
let mut outputs = Vec::new();
|
||||
outputs.extend(model.translate([source_sentence], Language::English, Language::French)?);
|
||||
outputs.extend(model.translate([source_sentence], Language::English, Language::German)?);
|
||||
outputs.extend(model.translate([source_sentence], Language::English, Language::Romanian)?);
|
||||
outputs.extend(model.translate(&[source_sentence], Language::English, Language::French)?);
|
||||
outputs.extend(model.translate(&[source_sentence], Language::English, Language::German)?);
|
||||
outputs.extend(model.translate(&[source_sentence], Language::English, Language::Romanian)?);
|
||||
|
||||
assert_eq!(outputs.len(), 3);
|
||||
assert_eq!(
|
||||
|
@ -331,7 +331,7 @@ fn xlnet_for_multiple_choice() -> anyhow::Result<()> {
|
||||
let prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced.";
|
||||
let inputs = ["Very positive sentence", "Second sentence input"];
|
||||
let tokenized_input = tokenizer.encode_pair_list(
|
||||
inputs
|
||||
&inputs
|
||||
.iter()
|
||||
.map(|&inp| (prompt, inp))
|
||||
.collect::<Vec<(&str, &str)>>(),
|
||||
|
Loading…
Reference in New Issue
Block a user