Updated translation (clippy)

This commit is contained in:
Guillaume B 2020-09-13 11:33:27 +02:00
parent 6cef45787d
commit b42cc60409
6 changed files with 235 additions and 365 deletions

View File

@ -160,7 +160,7 @@ impl QaExample {
}
if !current_word.is_empty() {
doc_tokens.push(current_word.clone());
doc_tokens.push(current_word);
}
(doc_tokens, char_to_word_offset)
}
@ -437,12 +437,9 @@ impl QuestionAnsweringModel {
let mut var_store = VarStore::new(device);
let mut model_config =
ConfigOption::from_file(question_answering_config.model_type, config_path);
match model_config {
// The config for the current pre-trained question answering model indicates position embeddings which does not seem accurate
ConfigOption::DistilBert(ref mut config) => {
config.sinusoidal_pos_embds = false;
}
_ => (),
if let ConfigOption::DistilBert(ref mut config) = model_config {
config.sinusoidal_pos_embds = false;
};
let qa_model = QuestionAnsweringOption::new(
@ -605,7 +602,7 @@ impl QuestionAnsweringModel {
feature_id_start = max_feature_id;
let example_answers = example_top_k_answers_map
.entry(example_id)
.or_insert(vec![]);
.or_insert_with(Vec::new);
example_answers.extend(answers);
}
});
@ -792,8 +789,8 @@ impl QuestionAnsweringModel {
fn encode_qa_pair(
&self,
truncated_query: &Vec<i64>,
spans_token_ids: &Vec<i64>,
truncated_query: &[i64],
spans_token_ids: &[i64],
max_seq_length: usize,
doc_stride: usize,
sequence_pair_added_tokens: usize,
@ -809,8 +806,8 @@ impl QuestionAnsweringModel {
let (truncated_query, truncated_context, _, _, _, _, _, _, overflowing_tokens, _) =
truncate_sequences(
truncated_query.clone(),
Some(spans_token_ids.clone()),
truncated_query.into(),
Some(spans_token_ids.into()),
vec![],
None,
vec![],

View File

@ -569,7 +569,7 @@ impl SequenceClassificationModel {
};
sequence_labels.push(label);
}
if sequence_labels.len() > 0 {
if !sequence_labels.is_empty() {
labels.push(sequence_labels);
}
Ok(labels)

View File

@ -632,7 +632,7 @@ impl TokenClassificationModel {
fn decode_token(
&self,
original_sentence_chars: &Vec<char>,
original_sentence_chars: &[char],
sentence_tokens: &TokenizedInput,
input_tensor: &Tensor,
labels: &Tensor,
@ -700,10 +700,10 @@ impl TokenClassificationModel {
label_aggregation_function: &LabelAggregationOption,
) {
let mut tokens_to_replace = vec![];
let mut token_iter = tokens.iter_consolidate_tokens();
let token_iter = tokens.iter_consolidate_tokens();
let mut cursor = 0;
while let Some(sub_tokens) = token_iter.next() {
for sub_tokens in token_iter {
if sub_tokens.len() > 1 {
let (label_index, label) =
self.consolidate_labels(sub_tokens, label_aggregation_function);
@ -718,14 +718,15 @@ impl TokenClassificationModel {
Some(offset) => Some(offset.end),
None => None,
};
let offset = if offset_start.is_some() & offset_end.is_some() {
Some(Offset::new(offset_start.unwrap(), offset_end.unwrap()))
} else {
None
};
let offset =
if let (Some(offset_start), Some(offset_end)) = (offset_start, offset_end) {
Some(Offset::new(offset_start, offset_end))
} else {
None
};
let mut text = String::new();
let mut score = 1f64;
for current_sub_token in sub_tokens.into_iter() {
for current_sub_token in sub_tokens.iter() {
text.push_str(current_sub_token.text.as_str());
score *= if current_sub_token.label_index == label_index {
current_sub_token.score

View File

@ -91,311 +91,176 @@ pub enum Language {
GermanToFrench,
}
struct RemoteTranslationResources;
/// Bundle of remote resource identifiers and settings describing one pretrained
/// translation model.
///
/// One associated constant of this type exists per supported language pair;
/// `TranslationConfig::new` picks the constant matching the requested `Language`
/// and reads the fields below to build the configuration.
// NOTE(review): each `(&'static str, &'static str)` pair is handed unchanged to
// `RemoteResource::from_pretrained`; the meaning of the two strings is defined
// there (presumably a local name and a download URL — confirm).
struct RemoteTranslationResources {
/// Pair used to build the remote model resource.
model_resource: (&'static str, &'static str),
/// Pair used to build the remote configuration resource.
config_resource: (&'static str, &'static str),
/// Pair used to build the remote vocabulary resource.
vocab_resource: (&'static str, &'static str),
/// Pair used to build the remote merges resource. The Marian constants point this
/// at a `MarianSpmResources` entry, while the T5 constants reuse the vocabulary entry.
merges_resource: (&'static str, &'static str),
/// Optional prefix for this language pair; converted to an owned `String` when the
/// configuration is built. `None` means no prefix is set.
prefix: Option<&'static str>,
/// Model architecture backing this language pair (`ModelType::Marian` or
/// `ModelType::T5` in the constants defined for this type).
model_type: ModelType,
}
impl RemoteTranslationResources {
pub const ENGLISH2FRENCH: (
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
ModelType,
) = (
MarianModelResources::ENGLISH2ROMANCE,
MarianConfigResources::ENGLISH2ROMANCE,
MarianVocabResources::ENGLISH2ROMANCE,
MarianSpmResources::ENGLISH2ROMANCE,
MarianPrefix::ENGLISH2FRENCH,
ModelType::Marian,
);
pub const ENGLISH2FRENCH_V2: (
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
ModelType,
) = (
T5ModelResources::T5_BASE,
T5ConfigResources::T5_BASE,
T5VocabResources::T5_BASE,
T5VocabResources::T5_BASE,
T5Prefix::ENGLISH2FRENCH,
ModelType::T5,
);
pub const ENGLISH2GERMAN_V2: (
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
ModelType,
) = (
T5ModelResources::T5_BASE,
T5ConfigResources::T5_BASE,
T5VocabResources::T5_BASE,
T5VocabResources::T5_BASE,
T5Prefix::ENGLISH2GERMAN,
ModelType::T5,
);
pub const ENGLISH2CATALAN: (
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
ModelType,
) = (
MarianModelResources::ENGLISH2ROMANCE,
MarianConfigResources::ENGLISH2ROMANCE,
MarianVocabResources::ENGLISH2ROMANCE,
MarianSpmResources::ENGLISH2ROMANCE,
MarianPrefix::ENGLISH2CATALAN,
ModelType::Marian,
);
pub const ENGLISH2SPANISH: (
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
ModelType,
) = (
MarianModelResources::ENGLISH2ROMANCE,
MarianConfigResources::ENGLISH2ROMANCE,
MarianVocabResources::ENGLISH2ROMANCE,
MarianSpmResources::ENGLISH2ROMANCE,
MarianPrefix::ENGLISH2SPANISH,
ModelType::Marian,
);
pub const ENGLISH2PORTUGUESE: (
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
ModelType,
) = (
MarianModelResources::ENGLISH2ROMANCE,
MarianConfigResources::ENGLISH2ROMANCE,
MarianVocabResources::ENGLISH2ROMANCE,
MarianSpmResources::ENGLISH2ROMANCE,
MarianPrefix::ENGLISH2PORTUGUESE,
ModelType::Marian,
);
pub const ENGLISH2ITALIAN: (
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
ModelType,
) = (
MarianModelResources::ENGLISH2ROMANCE,
MarianConfigResources::ENGLISH2ROMANCE,
MarianVocabResources::ENGLISH2ROMANCE,
MarianSpmResources::ENGLISH2ROMANCE,
MarianPrefix::ENGLISH2ITALIAN,
ModelType::Marian,
);
pub const ENGLISH2ROMANIAN: (
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
ModelType,
) = (
MarianModelResources::ENGLISH2ROMANCE,
MarianConfigResources::ENGLISH2ROMANCE,
MarianVocabResources::ENGLISH2ROMANCE,
MarianSpmResources::ENGLISH2ROMANCE,
MarianPrefix::ENGLISH2ROMANIAN,
ModelType::Marian,
);
pub const ENGLISH2GERMAN: (
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
ModelType,
) = (
MarianModelResources::ENGLISH2GERMAN,
MarianConfigResources::ENGLISH2GERMAN,
MarianVocabResources::ENGLISH2GERMAN,
MarianSpmResources::ENGLISH2GERMAN,
MarianPrefix::ENGLISH2GERMAN,
ModelType::Marian,
);
pub const ENGLISH2RUSSIAN: (
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
ModelType,
) = (
MarianModelResources::ENGLISH2RUSSIAN,
MarianConfigResources::ENGLISH2RUSSIAN,
MarianVocabResources::ENGLISH2RUSSIAN,
MarianSpmResources::ENGLISH2RUSSIAN,
MarianPrefix::ENGLISH2RUSSIAN,
ModelType::Marian,
);
pub const FRENCH2ENGLISH: (
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
ModelType,
) = (
MarianModelResources::ROMANCE2ENGLISH,
MarianConfigResources::ROMANCE2ENGLISH,
MarianVocabResources::ROMANCE2ENGLISH,
MarianSpmResources::ROMANCE2ENGLISH,
MarianPrefix::FRENCH2ENGLISH,
ModelType::Marian,
);
pub const CATALAN2ENGLISH: (
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
ModelType,
) = (
MarianModelResources::ROMANCE2ENGLISH,
MarianConfigResources::ROMANCE2ENGLISH,
MarianVocabResources::ROMANCE2ENGLISH,
MarianSpmResources::ROMANCE2ENGLISH,
MarianPrefix::CATALAN2ENGLISH,
ModelType::Marian,
);
pub const SPANISH2ENGLISH: (
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
ModelType,
) = (
MarianModelResources::ROMANCE2ENGLISH,
MarianConfigResources::ROMANCE2ENGLISH,
MarianVocabResources::ROMANCE2ENGLISH,
MarianSpmResources::ROMANCE2ENGLISH,
MarianPrefix::SPANISH2ENGLISH,
ModelType::Marian,
);
pub const PORTUGUESE2ENGLISH: (
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
ModelType,
) = (
MarianModelResources::ROMANCE2ENGLISH,
MarianConfigResources::ROMANCE2ENGLISH,
MarianVocabResources::ROMANCE2ENGLISH,
MarianSpmResources::ROMANCE2ENGLISH,
MarianPrefix::PORTUGUESE2ENGLISH,
ModelType::Marian,
);
pub const ITALIAN2ENGLISH: (
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
ModelType,
) = (
MarianModelResources::ROMANCE2ENGLISH,
MarianConfigResources::ROMANCE2ENGLISH,
MarianVocabResources::ROMANCE2ENGLISH,
MarianSpmResources::ROMANCE2ENGLISH,
MarianPrefix::ITALIAN2ENGLISH,
ModelType::Marian,
);
pub const ROMANIAN2ENGLISH: (
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
ModelType,
) = (
MarianModelResources::ROMANCE2ENGLISH,
MarianConfigResources::ROMANCE2ENGLISH,
MarianVocabResources::ROMANCE2ENGLISH,
MarianSpmResources::ROMANCE2ENGLISH,
MarianPrefix::ROMANIAN2ENGLISH,
ModelType::Marian,
);
pub const GERMAN2ENGLISH: (
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
ModelType,
) = (
MarianModelResources::GERMAN2ENGLISH,
MarianConfigResources::GERMAN2ENGLISH,
MarianVocabResources::GERMAN2ENGLISH,
MarianSpmResources::GERMAN2ENGLISH,
MarianPrefix::GERMAN2ENGLISH,
ModelType::Marian,
);
pub const RUSSIAN2ENGLISH: (
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
ModelType,
) = (
MarianModelResources::RUSSIAN2ENGLISH,
MarianConfigResources::RUSSIAN2ENGLISH,
MarianVocabResources::RUSSIAN2ENGLISH,
MarianSpmResources::RUSSIAN2ENGLISH,
MarianPrefix::RUSSIAN2ENGLISH,
ModelType::Marian,
);
pub const FRENCH2GERMAN: (
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
ModelType,
) = (
MarianModelResources::FRENCH2GERMAN,
MarianConfigResources::FRENCH2GERMAN,
MarianVocabResources::FRENCH2GERMAN,
MarianSpmResources::FRENCH2GERMAN,
MarianPrefix::FRENCH2GERMAN,
ModelType::Marian,
);
pub const GERMAN2FRENCH: (
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
(&'static str, &'static str),
Option<&'static str>,
ModelType,
) = (
MarianModelResources::GERMAN2FRENCH,
MarianConfigResources::GERMAN2FRENCH,
MarianVocabResources::GERMAN2FRENCH,
MarianSpmResources::GERMAN2FRENCH,
MarianPrefix::GERMAN2FRENCH,
ModelType::Marian,
);
pub const ENGLISH2FRENCH: RemoteTranslationResources = Self {
model_resource: MarianModelResources::ENGLISH2ROMANCE,
config_resource: MarianConfigResources::ENGLISH2ROMANCE,
vocab_resource: MarianVocabResources::ENGLISH2ROMANCE,
merges_resource: MarianSpmResources::ENGLISH2ROMANCE,
prefix: MarianPrefix::ENGLISH2FRENCH,
model_type: ModelType::Marian,
};
pub const ENGLISH2FRENCH_V2: RemoteTranslationResources = Self {
model_resource: T5ModelResources::T5_BASE,
config_resource: T5ConfigResources::T5_BASE,
vocab_resource: T5VocabResources::T5_BASE,
merges_resource: T5VocabResources::T5_BASE,
prefix: T5Prefix::ENGLISH2FRENCH,
model_type: ModelType::T5,
};
pub const ENGLISH2GERMAN_V2: RemoteTranslationResources = Self {
model_resource: T5ModelResources::T5_BASE,
config_resource: T5ConfigResources::T5_BASE,
vocab_resource: T5VocabResources::T5_BASE,
merges_resource: T5VocabResources::T5_BASE,
prefix: T5Prefix::ENGLISH2GERMAN,
model_type: ModelType::T5,
};
pub const ENGLISH2CATALAN: RemoteTranslationResources = Self {
model_resource: MarianModelResources::ENGLISH2ROMANCE,
config_resource: MarianConfigResources::ENGLISH2ROMANCE,
vocab_resource: MarianVocabResources::ENGLISH2ROMANCE,
merges_resource: MarianSpmResources::ENGLISH2ROMANCE,
prefix: MarianPrefix::ENGLISH2CATALAN,
model_type: ModelType::Marian,
};
pub const ENGLISH2SPANISH: RemoteTranslationResources = Self {
model_resource: MarianModelResources::ENGLISH2ROMANCE,
config_resource: MarianConfigResources::ENGLISH2ROMANCE,
vocab_resource: MarianVocabResources::ENGLISH2ROMANCE,
merges_resource: MarianSpmResources::ENGLISH2ROMANCE,
prefix: MarianPrefix::ENGLISH2SPANISH,
model_type: ModelType::Marian,
};
pub const ENGLISH2PORTUGUESE: RemoteTranslationResources = Self {
model_resource: MarianModelResources::ENGLISH2ROMANCE,
config_resource: MarianConfigResources::ENGLISH2ROMANCE,
vocab_resource: MarianVocabResources::ENGLISH2ROMANCE,
merges_resource: MarianSpmResources::ENGLISH2ROMANCE,
prefix: MarianPrefix::ENGLISH2PORTUGUESE,
model_type: ModelType::Marian,
};
pub const ENGLISH2ITALIAN: RemoteTranslationResources = Self {
model_resource: MarianModelResources::ENGLISH2ROMANCE,
config_resource: MarianConfigResources::ENGLISH2ROMANCE,
vocab_resource: MarianVocabResources::ENGLISH2ROMANCE,
merges_resource: MarianSpmResources::ENGLISH2ROMANCE,
prefix: MarianPrefix::ENGLISH2ITALIAN,
model_type: ModelType::Marian,
};
pub const ENGLISH2ROMANIAN: RemoteTranslationResources = Self {
model_resource: MarianModelResources::ENGLISH2ROMANCE,
config_resource: MarianConfigResources::ENGLISH2ROMANCE,
vocab_resource: MarianVocabResources::ENGLISH2ROMANCE,
merges_resource: MarianSpmResources::ENGLISH2ROMANCE,
prefix: MarianPrefix::ENGLISH2ROMANIAN,
model_type: ModelType::Marian,
};
pub const ENGLISH2GERMAN: RemoteTranslationResources = Self {
model_resource: MarianModelResources::ENGLISH2GERMAN,
config_resource: MarianConfigResources::ENGLISH2GERMAN,
vocab_resource: MarianVocabResources::ENGLISH2GERMAN,
merges_resource: MarianSpmResources::ENGLISH2GERMAN,
prefix: MarianPrefix::ENGLISH2GERMAN,
model_type: ModelType::Marian,
};
pub const ENGLISH2RUSSIAN: RemoteTranslationResources = Self {
model_resource: MarianModelResources::ENGLISH2RUSSIAN,
config_resource: MarianConfigResources::ENGLISH2RUSSIAN,
vocab_resource: MarianVocabResources::ENGLISH2RUSSIAN,
merges_resource: MarianSpmResources::ENGLISH2RUSSIAN,
prefix: MarianPrefix::ENGLISH2RUSSIAN,
model_type: ModelType::Marian,
};
pub const FRENCH2ENGLISH: RemoteTranslationResources = Self {
model_resource: MarianModelResources::ROMANCE2ENGLISH,
config_resource: MarianConfigResources::ROMANCE2ENGLISH,
vocab_resource: MarianVocabResources::ROMANCE2ENGLISH,
merges_resource: MarianSpmResources::ROMANCE2ENGLISH,
prefix: MarianPrefix::FRENCH2ENGLISH,
model_type: ModelType::Marian,
};
pub const CATALAN2ENGLISH: RemoteTranslationResources = Self {
model_resource: MarianModelResources::ROMANCE2ENGLISH,
config_resource: MarianConfigResources::ROMANCE2ENGLISH,
vocab_resource: MarianVocabResources::ROMANCE2ENGLISH,
merges_resource: MarianSpmResources::ROMANCE2ENGLISH,
prefix: MarianPrefix::CATALAN2ENGLISH,
model_type: ModelType::Marian,
};
pub const SPANISH2ENGLISH: RemoteTranslationResources = Self {
model_resource: MarianModelResources::ROMANCE2ENGLISH,
config_resource: MarianConfigResources::ROMANCE2ENGLISH,
vocab_resource: MarianVocabResources::ROMANCE2ENGLISH,
merges_resource: MarianSpmResources::ROMANCE2ENGLISH,
prefix: MarianPrefix::SPANISH2ENGLISH,
model_type: ModelType::Marian,
};
pub const PORTUGUESE2ENGLISH: RemoteTranslationResources = Self {
model_resource: MarianModelResources::ROMANCE2ENGLISH,
config_resource: MarianConfigResources::ROMANCE2ENGLISH,
vocab_resource: MarianVocabResources::ROMANCE2ENGLISH,
merges_resource: MarianSpmResources::ROMANCE2ENGLISH,
prefix: MarianPrefix::PORTUGUESE2ENGLISH,
model_type: ModelType::Marian,
};
pub const ITALIAN2ENGLISH: RemoteTranslationResources = Self {
model_resource: MarianModelResources::ROMANCE2ENGLISH,
config_resource: MarianConfigResources::ROMANCE2ENGLISH,
vocab_resource: MarianVocabResources::ROMANCE2ENGLISH,
merges_resource: MarianSpmResources::ROMANCE2ENGLISH,
prefix: MarianPrefix::ITALIAN2ENGLISH,
model_type: ModelType::Marian,
};
pub const ROMANIAN2ENGLISH: RemoteTranslationResources = Self {
model_resource: MarianModelResources::ROMANCE2ENGLISH,
config_resource: MarianConfigResources::ROMANCE2ENGLISH,
vocab_resource: MarianVocabResources::ROMANCE2ENGLISH,
merges_resource: MarianSpmResources::ROMANCE2ENGLISH,
prefix: MarianPrefix::ROMANIAN2ENGLISH,
model_type: ModelType::Marian,
};
pub const GERMAN2ENGLISH: RemoteTranslationResources = Self {
model_resource: MarianModelResources::GERMAN2ENGLISH,
config_resource: MarianConfigResources::GERMAN2ENGLISH,
vocab_resource: MarianVocabResources::GERMAN2ENGLISH,
merges_resource: MarianSpmResources::GERMAN2ENGLISH,
prefix: MarianPrefix::GERMAN2ENGLISH,
model_type: ModelType::Marian,
};
pub const RUSSIAN2ENGLISH: RemoteTranslationResources = Self {
model_resource: MarianModelResources::RUSSIAN2ENGLISH,
config_resource: MarianConfigResources::RUSSIAN2ENGLISH,
vocab_resource: MarianVocabResources::RUSSIAN2ENGLISH,
merges_resource: MarianSpmResources::RUSSIAN2ENGLISH,
prefix: MarianPrefix::RUSSIAN2ENGLISH,
model_type: ModelType::Marian,
};
pub const FRENCH2GERMAN: RemoteTranslationResources = Self {
model_resource: MarianModelResources::FRENCH2GERMAN,
config_resource: MarianConfigResources::FRENCH2GERMAN,
vocab_resource: MarianVocabResources::FRENCH2GERMAN,
merges_resource: MarianSpmResources::FRENCH2GERMAN,
prefix: MarianPrefix::FRENCH2GERMAN,
model_type: ModelType::Marian,
};
pub const GERMAN2FRENCH: RemoteTranslationResources = Self {
model_resource: MarianModelResources::GERMAN2FRENCH,
config_resource: MarianConfigResources::GERMAN2FRENCH,
vocab_resource: MarianVocabResources::GERMAN2FRENCH,
merges_resource: MarianSpmResources::GERMAN2FRENCH,
prefix: MarianPrefix::GERMAN2FRENCH,
model_type: ModelType::Marian,
};
}
/// # Configuration for text translation
@ -463,37 +328,44 @@ impl TranslationConfig {
/// # }
/// ```
pub fn new(language: Language, device: Device) -> TranslationConfig {
let (model_resource, config_resource, vocab_resource, merges_resource, prefix, model_type) =
match language {
Language::EnglishToFrench => RemoteTranslationResources::ENGLISH2FRENCH,
Language::EnglishToCatalan => RemoteTranslationResources::ENGLISH2CATALAN,
Language::EnglishToSpanish => RemoteTranslationResources::ENGLISH2SPANISH,
Language::EnglishToPortuguese => RemoteTranslationResources::ENGLISH2PORTUGUESE,
Language::EnglishToItalian => RemoteTranslationResources::ENGLISH2ITALIAN,
Language::EnglishToRomanian => RemoteTranslationResources::ENGLISH2ROMANIAN,
Language::EnglishToGerman => RemoteTranslationResources::ENGLISH2GERMAN,
Language::EnglishToRussian => RemoteTranslationResources::ENGLISH2RUSSIAN,
let translation_resource = match language {
Language::EnglishToFrench => RemoteTranslationResources::ENGLISH2FRENCH,
Language::EnglishToCatalan => RemoteTranslationResources::ENGLISH2CATALAN,
Language::EnglishToSpanish => RemoteTranslationResources::ENGLISH2SPANISH,
Language::EnglishToPortuguese => RemoteTranslationResources::ENGLISH2PORTUGUESE,
Language::EnglishToItalian => RemoteTranslationResources::ENGLISH2ITALIAN,
Language::EnglishToRomanian => RemoteTranslationResources::ENGLISH2ROMANIAN,
Language::EnglishToGerman => RemoteTranslationResources::ENGLISH2GERMAN,
Language::EnglishToRussian => RemoteTranslationResources::ENGLISH2RUSSIAN,
Language::FrenchToEnglish => RemoteTranslationResources::FRENCH2ENGLISH,
Language::CatalanToEnglish => RemoteTranslationResources::CATALAN2ENGLISH,
Language::SpanishToEnglish => RemoteTranslationResources::SPANISH2ENGLISH,
Language::PortugueseToEnglish => RemoteTranslationResources::PORTUGUESE2ENGLISH,
Language::ItalianToEnglish => RemoteTranslationResources::ITALIAN2ENGLISH,
Language::RomanianToEnglish => RemoteTranslationResources::ROMANIAN2ENGLISH,
Language::GermanToEnglish => RemoteTranslationResources::GERMAN2ENGLISH,
Language::RussianToEnglish => RemoteTranslationResources::RUSSIAN2ENGLISH,
Language::FrenchToEnglish => RemoteTranslationResources::FRENCH2ENGLISH,
Language::CatalanToEnglish => RemoteTranslationResources::CATALAN2ENGLISH,
Language::SpanishToEnglish => RemoteTranslationResources::SPANISH2ENGLISH,
Language::PortugueseToEnglish => RemoteTranslationResources::PORTUGUESE2ENGLISH,
Language::ItalianToEnglish => RemoteTranslationResources::ITALIAN2ENGLISH,
Language::RomanianToEnglish => RemoteTranslationResources::ROMANIAN2ENGLISH,
Language::GermanToEnglish => RemoteTranslationResources::GERMAN2ENGLISH,
Language::RussianToEnglish => RemoteTranslationResources::RUSSIAN2ENGLISH,
Language::EnglishToFrenchV2 => RemoteTranslationResources::ENGLISH2FRENCH_V2,
Language::EnglishToGermanV2 => RemoteTranslationResources::ENGLISH2GERMAN_V2,
Language::EnglishToFrenchV2 => RemoteTranslationResources::ENGLISH2FRENCH_V2,
Language::EnglishToGermanV2 => RemoteTranslationResources::ENGLISH2GERMAN_V2,
Language::FrenchToGerman => RemoteTranslationResources::FRENCH2GERMAN,
Language::GermanToFrench => RemoteTranslationResources::GERMAN2FRENCH,
};
let model_resource = Resource::Remote(RemoteResource::from_pretrained(model_resource));
let config_resource = Resource::Remote(RemoteResource::from_pretrained(config_resource));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(vocab_resource));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(merges_resource));
let prefix = match prefix {
Language::FrenchToGerman => RemoteTranslationResources::FRENCH2GERMAN,
Language::GermanToFrench => RemoteTranslationResources::GERMAN2FRENCH,
};
let model_resource = Resource::Remote(RemoteResource::from_pretrained(
translation_resource.model_resource,
));
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
translation_resource.config_resource,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
translation_resource.vocab_resource,
));
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
translation_resource.merges_resource,
));
let prefix = match translation_resource.prefix {
Some(value) => Some(value.to_string()),
None => None,
};
@ -516,7 +388,7 @@ impl TranslationConfig {
num_return_sequences: 1,
device,
prefix,
model_type,
model_type: translation_resource.model_type,
}
}
@ -740,10 +612,10 @@ impl TranslationModel {
pub fn translate(&self, texts: &[&str]) -> Vec<String> {
match &self.prefix {
Some(value) => {
let texts: Vec<String> = texts
.into_iter()
let texts = texts
.iter()
.map(|&v| format!("{} {}", value, v))
.collect();
.collect::<Vec<String>>();
self.model
.generate(Some(texts.iter().map(AsRef::as_ref).collect()), None)
}

View File

@ -447,13 +447,13 @@ impl ZeroShotClassificationModel {
let label_sentences: Vec<String> = match template {
Some(function) => labels.iter().map(|label| function(label)).collect(),
None => labels
.into_iter()
.iter()
.map(|label| format!("This example is about {}.", label))
.collect(),
};
let text_pair_list = inputs
.into_iter()
.iter()
.cartesian_product(label_sentences.iter())
.map(|(&s, label)| (s, label.as_str()))
.collect();

View File

@ -198,18 +198,18 @@ impl T5Attention {
let length = temp_value.size()[2];
temp_value = temp_value.slice(2, length - 1, length, 1);
};
if attention_mask.is_some() {
temp_value = temp_value + attention_mask.unwrap();
if let Some(attention_mask) = attention_mask {
temp_value += attention_mask;
};
Some(temp_value)
} else {
None
};
let position_bias = if position_bias.is_none() {
calculated_position_bias.as_ref().unwrap()
let position_bias = if let Some(position_bias) = position_bias {
position_bias
} else {
position_bias.unwrap()
calculated_position_bias.as_ref().unwrap()
};
scores += position_bias;
@ -247,7 +247,7 @@ impl T5Attention {
let mut num_buckets = num_buckets;
let mut ret = n.zeros_like();
let n = if bidirectional {
num_buckets = num_buckets / 2;
num_buckets /= 2;
ret += n.lt(0).to_kind(Kind::Int64) * num_buckets;
n.abs()
} else {