mirror of
https://github.com/guillaume-be/rust-bert.git
synced 2024-09-11 12:55:34 +03:00
Updated translation (clippy)
This commit is contained in:
parent
6cef45787d
commit
b42cc60409
@ -160,7 +160,7 @@ impl QaExample {
|
||||
}
|
||||
|
||||
if !current_word.is_empty() {
|
||||
doc_tokens.push(current_word.clone());
|
||||
doc_tokens.push(current_word);
|
||||
}
|
||||
(doc_tokens, char_to_word_offset)
|
||||
}
|
||||
@ -437,12 +437,9 @@ impl QuestionAnsweringModel {
|
||||
let mut var_store = VarStore::new(device);
|
||||
let mut model_config =
|
||||
ConfigOption::from_file(question_answering_config.model_type, config_path);
|
||||
match model_config {
|
||||
// The config for the current pre-trained question answering model indicates position embeddings which does not seem accurate
|
||||
ConfigOption::DistilBert(ref mut config) => {
|
||||
config.sinusoidal_pos_embds = false;
|
||||
}
|
||||
_ => (),
|
||||
|
||||
if let ConfigOption::DistilBert(ref mut config) = model_config {
|
||||
config.sinusoidal_pos_embds = false;
|
||||
};
|
||||
|
||||
let qa_model = QuestionAnsweringOption::new(
|
||||
@ -605,7 +602,7 @@ impl QuestionAnsweringModel {
|
||||
feature_id_start = max_feature_id;
|
||||
let example_answers = example_top_k_answers_map
|
||||
.entry(example_id)
|
||||
.or_insert(vec![]);
|
||||
.or_insert_with(Vec::new);
|
||||
example_answers.extend(answers);
|
||||
}
|
||||
});
|
||||
@ -792,8 +789,8 @@ impl QuestionAnsweringModel {
|
||||
|
||||
fn encode_qa_pair(
|
||||
&self,
|
||||
truncated_query: &Vec<i64>,
|
||||
spans_token_ids: &Vec<i64>,
|
||||
truncated_query: &[i64],
|
||||
spans_token_ids: &[i64],
|
||||
max_seq_length: usize,
|
||||
doc_stride: usize,
|
||||
sequence_pair_added_tokens: usize,
|
||||
@ -809,8 +806,8 @@ impl QuestionAnsweringModel {
|
||||
|
||||
let (truncated_query, truncated_context, _, _, _, _, _, _, overflowing_tokens, _) =
|
||||
truncate_sequences(
|
||||
truncated_query.clone(),
|
||||
Some(spans_token_ids.clone()),
|
||||
truncated_query.into(),
|
||||
Some(spans_token_ids.into()),
|
||||
vec![],
|
||||
None,
|
||||
vec![],
|
||||
|
@ -569,7 +569,7 @@ impl SequenceClassificationModel {
|
||||
};
|
||||
sequence_labels.push(label);
|
||||
}
|
||||
if sequence_labels.len() > 0 {
|
||||
if !sequence_labels.is_empty() {
|
||||
labels.push(sequence_labels);
|
||||
}
|
||||
Ok(labels)
|
||||
|
@ -632,7 +632,7 @@ impl TokenClassificationModel {
|
||||
|
||||
fn decode_token(
|
||||
&self,
|
||||
original_sentence_chars: &Vec<char>,
|
||||
original_sentence_chars: &[char],
|
||||
sentence_tokens: &TokenizedInput,
|
||||
input_tensor: &Tensor,
|
||||
labels: &Tensor,
|
||||
@ -700,10 +700,10 @@ impl TokenClassificationModel {
|
||||
label_aggregation_function: &LabelAggregationOption,
|
||||
) {
|
||||
let mut tokens_to_replace = vec![];
|
||||
let mut token_iter = tokens.iter_consolidate_tokens();
|
||||
let token_iter = tokens.iter_consolidate_tokens();
|
||||
let mut cursor = 0;
|
||||
|
||||
while let Some(sub_tokens) = token_iter.next() {
|
||||
for sub_tokens in token_iter {
|
||||
if sub_tokens.len() > 1 {
|
||||
let (label_index, label) =
|
||||
self.consolidate_labels(sub_tokens, label_aggregation_function);
|
||||
@ -718,14 +718,15 @@ impl TokenClassificationModel {
|
||||
Some(offset) => Some(offset.end),
|
||||
None => None,
|
||||
};
|
||||
let offset = if offset_start.is_some() & offset_end.is_some() {
|
||||
Some(Offset::new(offset_start.unwrap(), offset_end.unwrap()))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let offset =
|
||||
if let (Some(offset_start), Some(offset_end)) = (offset_start, offset_end) {
|
||||
Some(Offset::new(offset_start, offset_end))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let mut text = String::new();
|
||||
let mut score = 1f64;
|
||||
for current_sub_token in sub_tokens.into_iter() {
|
||||
for current_sub_token in sub_tokens.iter() {
|
||||
text.push_str(current_sub_token.text.as_str());
|
||||
score *= if current_sub_token.label_index == label_index {
|
||||
current_sub_token.score
|
||||
|
@ -91,311 +91,176 @@ pub enum Language {
|
||||
GermanToFrench,
|
||||
}
|
||||
|
||||
struct RemoteTranslationResources;
|
||||
struct RemoteTranslationResources {
|
||||
model_resource: (&'static str, &'static str),
|
||||
config_resource: (&'static str, &'static str),
|
||||
vocab_resource: (&'static str, &'static str),
|
||||
merges_resource: (&'static str, &'static str),
|
||||
prefix: Option<&'static str>,
|
||||
model_type: ModelType,
|
||||
}
|
||||
|
||||
impl RemoteTranslationResources {
|
||||
pub const ENGLISH2FRENCH: (
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
Option<&'static str>,
|
||||
ModelType,
|
||||
) = (
|
||||
MarianModelResources::ENGLISH2ROMANCE,
|
||||
MarianConfigResources::ENGLISH2ROMANCE,
|
||||
MarianVocabResources::ENGLISH2ROMANCE,
|
||||
MarianSpmResources::ENGLISH2ROMANCE,
|
||||
MarianPrefix::ENGLISH2FRENCH,
|
||||
ModelType::Marian,
|
||||
);
|
||||
pub const ENGLISH2FRENCH_V2: (
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
Option<&'static str>,
|
||||
ModelType,
|
||||
) = (
|
||||
T5ModelResources::T5_BASE,
|
||||
T5ConfigResources::T5_BASE,
|
||||
T5VocabResources::T5_BASE,
|
||||
T5VocabResources::T5_BASE,
|
||||
T5Prefix::ENGLISH2FRENCH,
|
||||
ModelType::T5,
|
||||
);
|
||||
pub const ENGLISH2GERMAN_V2: (
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
Option<&'static str>,
|
||||
ModelType,
|
||||
) = (
|
||||
T5ModelResources::T5_BASE,
|
||||
T5ConfigResources::T5_BASE,
|
||||
T5VocabResources::T5_BASE,
|
||||
T5VocabResources::T5_BASE,
|
||||
T5Prefix::ENGLISH2GERMAN,
|
||||
ModelType::T5,
|
||||
);
|
||||
pub const ENGLISH2CATALAN: (
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
Option<&'static str>,
|
||||
ModelType,
|
||||
) = (
|
||||
MarianModelResources::ENGLISH2ROMANCE,
|
||||
MarianConfigResources::ENGLISH2ROMANCE,
|
||||
MarianVocabResources::ENGLISH2ROMANCE,
|
||||
MarianSpmResources::ENGLISH2ROMANCE,
|
||||
MarianPrefix::ENGLISH2CATALAN,
|
||||
ModelType::Marian,
|
||||
);
|
||||
pub const ENGLISH2SPANISH: (
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
Option<&'static str>,
|
||||
ModelType,
|
||||
) = (
|
||||
MarianModelResources::ENGLISH2ROMANCE,
|
||||
MarianConfigResources::ENGLISH2ROMANCE,
|
||||
MarianVocabResources::ENGLISH2ROMANCE,
|
||||
MarianSpmResources::ENGLISH2ROMANCE,
|
||||
MarianPrefix::ENGLISH2SPANISH,
|
||||
ModelType::Marian,
|
||||
);
|
||||
pub const ENGLISH2PORTUGUESE: (
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
Option<&'static str>,
|
||||
ModelType,
|
||||
) = (
|
||||
MarianModelResources::ENGLISH2ROMANCE,
|
||||
MarianConfigResources::ENGLISH2ROMANCE,
|
||||
MarianVocabResources::ENGLISH2ROMANCE,
|
||||
MarianSpmResources::ENGLISH2ROMANCE,
|
||||
MarianPrefix::ENGLISH2PORTUGUESE,
|
||||
ModelType::Marian,
|
||||
);
|
||||
pub const ENGLISH2ITALIAN: (
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
Option<&'static str>,
|
||||
ModelType,
|
||||
) = (
|
||||
MarianModelResources::ENGLISH2ROMANCE,
|
||||
MarianConfigResources::ENGLISH2ROMANCE,
|
||||
MarianVocabResources::ENGLISH2ROMANCE,
|
||||
MarianSpmResources::ENGLISH2ROMANCE,
|
||||
MarianPrefix::ENGLISH2ITALIAN,
|
||||
ModelType::Marian,
|
||||
);
|
||||
pub const ENGLISH2ROMANIAN: (
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
Option<&'static str>,
|
||||
ModelType,
|
||||
) = (
|
||||
MarianModelResources::ENGLISH2ROMANCE,
|
||||
MarianConfigResources::ENGLISH2ROMANCE,
|
||||
MarianVocabResources::ENGLISH2ROMANCE,
|
||||
MarianSpmResources::ENGLISH2ROMANCE,
|
||||
MarianPrefix::ENGLISH2ROMANIAN,
|
||||
ModelType::Marian,
|
||||
);
|
||||
pub const ENGLISH2GERMAN: (
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
Option<&'static str>,
|
||||
ModelType,
|
||||
) = (
|
||||
MarianModelResources::ENGLISH2GERMAN,
|
||||
MarianConfigResources::ENGLISH2GERMAN,
|
||||
MarianVocabResources::ENGLISH2GERMAN,
|
||||
MarianSpmResources::ENGLISH2GERMAN,
|
||||
MarianPrefix::ENGLISH2GERMAN,
|
||||
ModelType::Marian,
|
||||
);
|
||||
pub const ENGLISH2RUSSIAN: (
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
Option<&'static str>,
|
||||
ModelType,
|
||||
) = (
|
||||
MarianModelResources::ENGLISH2RUSSIAN,
|
||||
MarianConfigResources::ENGLISH2RUSSIAN,
|
||||
MarianVocabResources::ENGLISH2RUSSIAN,
|
||||
MarianSpmResources::ENGLISH2RUSSIAN,
|
||||
MarianPrefix::ENGLISH2RUSSIAN,
|
||||
ModelType::Marian,
|
||||
);
|
||||
|
||||
pub const FRENCH2ENGLISH: (
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
Option<&'static str>,
|
||||
ModelType,
|
||||
) = (
|
||||
MarianModelResources::ROMANCE2ENGLISH,
|
||||
MarianConfigResources::ROMANCE2ENGLISH,
|
||||
MarianVocabResources::ROMANCE2ENGLISH,
|
||||
MarianSpmResources::ROMANCE2ENGLISH,
|
||||
MarianPrefix::FRENCH2ENGLISH,
|
||||
ModelType::Marian,
|
||||
);
|
||||
pub const CATALAN2ENGLISH: (
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
Option<&'static str>,
|
||||
ModelType,
|
||||
) = (
|
||||
MarianModelResources::ROMANCE2ENGLISH,
|
||||
MarianConfigResources::ROMANCE2ENGLISH,
|
||||
MarianVocabResources::ROMANCE2ENGLISH,
|
||||
MarianSpmResources::ROMANCE2ENGLISH,
|
||||
MarianPrefix::CATALAN2ENGLISH,
|
||||
ModelType::Marian,
|
||||
);
|
||||
pub const SPANISH2ENGLISH: (
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
Option<&'static str>,
|
||||
ModelType,
|
||||
) = (
|
||||
MarianModelResources::ROMANCE2ENGLISH,
|
||||
MarianConfigResources::ROMANCE2ENGLISH,
|
||||
MarianVocabResources::ROMANCE2ENGLISH,
|
||||
MarianSpmResources::ROMANCE2ENGLISH,
|
||||
MarianPrefix::SPANISH2ENGLISH,
|
||||
ModelType::Marian,
|
||||
);
|
||||
pub const PORTUGUESE2ENGLISH: (
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
Option<&'static str>,
|
||||
ModelType,
|
||||
) = (
|
||||
MarianModelResources::ROMANCE2ENGLISH,
|
||||
MarianConfigResources::ROMANCE2ENGLISH,
|
||||
MarianVocabResources::ROMANCE2ENGLISH,
|
||||
MarianSpmResources::ROMANCE2ENGLISH,
|
||||
MarianPrefix::PORTUGUESE2ENGLISH,
|
||||
ModelType::Marian,
|
||||
);
|
||||
pub const ITALIAN2ENGLISH: (
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
Option<&'static str>,
|
||||
ModelType,
|
||||
) = (
|
||||
MarianModelResources::ROMANCE2ENGLISH,
|
||||
MarianConfigResources::ROMANCE2ENGLISH,
|
||||
MarianVocabResources::ROMANCE2ENGLISH,
|
||||
MarianSpmResources::ROMANCE2ENGLISH,
|
||||
MarianPrefix::ITALIAN2ENGLISH,
|
||||
ModelType::Marian,
|
||||
);
|
||||
pub const ROMANIAN2ENGLISH: (
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
Option<&'static str>,
|
||||
ModelType,
|
||||
) = (
|
||||
MarianModelResources::ROMANCE2ENGLISH,
|
||||
MarianConfigResources::ROMANCE2ENGLISH,
|
||||
MarianVocabResources::ROMANCE2ENGLISH,
|
||||
MarianSpmResources::ROMANCE2ENGLISH,
|
||||
MarianPrefix::ROMANIAN2ENGLISH,
|
||||
ModelType::Marian,
|
||||
);
|
||||
pub const GERMAN2ENGLISH: (
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
Option<&'static str>,
|
||||
ModelType,
|
||||
) = (
|
||||
MarianModelResources::GERMAN2ENGLISH,
|
||||
MarianConfigResources::GERMAN2ENGLISH,
|
||||
MarianVocabResources::GERMAN2ENGLISH,
|
||||
MarianSpmResources::GERMAN2ENGLISH,
|
||||
MarianPrefix::GERMAN2ENGLISH,
|
||||
ModelType::Marian,
|
||||
);
|
||||
pub const RUSSIAN2ENGLISH: (
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
Option<&'static str>,
|
||||
ModelType,
|
||||
) = (
|
||||
MarianModelResources::RUSSIAN2ENGLISH,
|
||||
MarianConfigResources::RUSSIAN2ENGLISH,
|
||||
MarianVocabResources::RUSSIAN2ENGLISH,
|
||||
MarianSpmResources::RUSSIAN2ENGLISH,
|
||||
MarianPrefix::RUSSIAN2ENGLISH,
|
||||
ModelType::Marian,
|
||||
);
|
||||
|
||||
pub const FRENCH2GERMAN: (
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
Option<&'static str>,
|
||||
ModelType,
|
||||
) = (
|
||||
MarianModelResources::FRENCH2GERMAN,
|
||||
MarianConfigResources::FRENCH2GERMAN,
|
||||
MarianVocabResources::FRENCH2GERMAN,
|
||||
MarianSpmResources::FRENCH2GERMAN,
|
||||
MarianPrefix::FRENCH2GERMAN,
|
||||
ModelType::Marian,
|
||||
);
|
||||
pub const GERMAN2FRENCH: (
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
(&'static str, &'static str),
|
||||
Option<&'static str>,
|
||||
ModelType,
|
||||
) = (
|
||||
MarianModelResources::GERMAN2FRENCH,
|
||||
MarianConfigResources::GERMAN2FRENCH,
|
||||
MarianVocabResources::GERMAN2FRENCH,
|
||||
MarianSpmResources::GERMAN2FRENCH,
|
||||
MarianPrefix::GERMAN2FRENCH,
|
||||
ModelType::Marian,
|
||||
);
|
||||
pub const ENGLISH2FRENCH: RemoteTranslationResources = Self {
|
||||
model_resource: MarianModelResources::ENGLISH2ROMANCE,
|
||||
config_resource: MarianConfigResources::ENGLISH2ROMANCE,
|
||||
vocab_resource: MarianVocabResources::ENGLISH2ROMANCE,
|
||||
merges_resource: MarianSpmResources::ENGLISH2ROMANCE,
|
||||
prefix: MarianPrefix::ENGLISH2FRENCH,
|
||||
model_type: ModelType::Marian,
|
||||
};
|
||||
pub const ENGLISH2FRENCH_V2: RemoteTranslationResources = Self {
|
||||
model_resource: T5ModelResources::T5_BASE,
|
||||
config_resource: T5ConfigResources::T5_BASE,
|
||||
vocab_resource: T5VocabResources::T5_BASE,
|
||||
merges_resource: T5VocabResources::T5_BASE,
|
||||
prefix: T5Prefix::ENGLISH2FRENCH,
|
||||
model_type: ModelType::T5,
|
||||
};
|
||||
pub const ENGLISH2GERMAN_V2: RemoteTranslationResources = Self {
|
||||
model_resource: T5ModelResources::T5_BASE,
|
||||
config_resource: T5ConfigResources::T5_BASE,
|
||||
vocab_resource: T5VocabResources::T5_BASE,
|
||||
merges_resource: T5VocabResources::T5_BASE,
|
||||
prefix: T5Prefix::ENGLISH2GERMAN,
|
||||
model_type: ModelType::T5,
|
||||
};
|
||||
pub const ENGLISH2CATALAN: RemoteTranslationResources = Self {
|
||||
model_resource: MarianModelResources::ENGLISH2ROMANCE,
|
||||
config_resource: MarianConfigResources::ENGLISH2ROMANCE,
|
||||
vocab_resource: MarianVocabResources::ENGLISH2ROMANCE,
|
||||
merges_resource: MarianSpmResources::ENGLISH2ROMANCE,
|
||||
prefix: MarianPrefix::ENGLISH2CATALAN,
|
||||
model_type: ModelType::Marian,
|
||||
};
|
||||
pub const ENGLISH2SPANISH: RemoteTranslationResources = Self {
|
||||
model_resource: MarianModelResources::ENGLISH2ROMANCE,
|
||||
config_resource: MarianConfigResources::ENGLISH2ROMANCE,
|
||||
vocab_resource: MarianVocabResources::ENGLISH2ROMANCE,
|
||||
merges_resource: MarianSpmResources::ENGLISH2ROMANCE,
|
||||
prefix: MarianPrefix::ENGLISH2SPANISH,
|
||||
model_type: ModelType::Marian,
|
||||
};
|
||||
pub const ENGLISH2PORTUGUESE: RemoteTranslationResources = Self {
|
||||
model_resource: MarianModelResources::ENGLISH2ROMANCE,
|
||||
config_resource: MarianConfigResources::ENGLISH2ROMANCE,
|
||||
vocab_resource: MarianVocabResources::ENGLISH2ROMANCE,
|
||||
merges_resource: MarianSpmResources::ENGLISH2ROMANCE,
|
||||
prefix: MarianPrefix::ENGLISH2PORTUGUESE,
|
||||
model_type: ModelType::Marian,
|
||||
};
|
||||
pub const ENGLISH2ITALIAN: RemoteTranslationResources = Self {
|
||||
model_resource: MarianModelResources::ENGLISH2ROMANCE,
|
||||
config_resource: MarianConfigResources::ENGLISH2ROMANCE,
|
||||
vocab_resource: MarianVocabResources::ENGLISH2ROMANCE,
|
||||
merges_resource: MarianSpmResources::ENGLISH2ROMANCE,
|
||||
prefix: MarianPrefix::ENGLISH2ITALIAN,
|
||||
model_type: ModelType::Marian,
|
||||
};
|
||||
pub const ENGLISH2ROMANIAN: RemoteTranslationResources = Self {
|
||||
model_resource: MarianModelResources::ENGLISH2ROMANCE,
|
||||
config_resource: MarianConfigResources::ENGLISH2ROMANCE,
|
||||
vocab_resource: MarianVocabResources::ENGLISH2ROMANCE,
|
||||
merges_resource: MarianSpmResources::ENGLISH2ROMANCE,
|
||||
prefix: MarianPrefix::ENGLISH2ROMANIAN,
|
||||
model_type: ModelType::Marian,
|
||||
};
|
||||
pub const ENGLISH2GERMAN: RemoteTranslationResources = Self {
|
||||
model_resource: MarianModelResources::ENGLISH2GERMAN,
|
||||
config_resource: MarianConfigResources::ENGLISH2GERMAN,
|
||||
vocab_resource: MarianVocabResources::ENGLISH2GERMAN,
|
||||
merges_resource: MarianSpmResources::ENGLISH2GERMAN,
|
||||
prefix: MarianPrefix::ENGLISH2GERMAN,
|
||||
model_type: ModelType::Marian,
|
||||
};
|
||||
pub const ENGLISH2RUSSIAN: RemoteTranslationResources = Self {
|
||||
model_resource: MarianModelResources::ENGLISH2RUSSIAN,
|
||||
config_resource: MarianConfigResources::ENGLISH2RUSSIAN,
|
||||
vocab_resource: MarianVocabResources::ENGLISH2RUSSIAN,
|
||||
merges_resource: MarianSpmResources::ENGLISH2RUSSIAN,
|
||||
prefix: MarianPrefix::ENGLISH2RUSSIAN,
|
||||
model_type: ModelType::Marian,
|
||||
};
|
||||
pub const FRENCH2ENGLISH: RemoteTranslationResources = Self {
|
||||
model_resource: MarianModelResources::ROMANCE2ENGLISH,
|
||||
config_resource: MarianConfigResources::ROMANCE2ENGLISH,
|
||||
vocab_resource: MarianVocabResources::ROMANCE2ENGLISH,
|
||||
merges_resource: MarianSpmResources::ROMANCE2ENGLISH,
|
||||
prefix: MarianPrefix::FRENCH2ENGLISH,
|
||||
model_type: ModelType::Marian,
|
||||
};
|
||||
pub const CATALAN2ENGLISH: RemoteTranslationResources = Self {
|
||||
model_resource: MarianModelResources::ROMANCE2ENGLISH,
|
||||
config_resource: MarianConfigResources::ROMANCE2ENGLISH,
|
||||
vocab_resource: MarianVocabResources::ROMANCE2ENGLISH,
|
||||
merges_resource: MarianSpmResources::ROMANCE2ENGLISH,
|
||||
prefix: MarianPrefix::CATALAN2ENGLISH,
|
||||
model_type: ModelType::Marian,
|
||||
};
|
||||
pub const SPANISH2ENGLISH: RemoteTranslationResources = Self {
|
||||
model_resource: MarianModelResources::ROMANCE2ENGLISH,
|
||||
config_resource: MarianConfigResources::ROMANCE2ENGLISH,
|
||||
vocab_resource: MarianVocabResources::ROMANCE2ENGLISH,
|
||||
merges_resource: MarianSpmResources::ROMANCE2ENGLISH,
|
||||
prefix: MarianPrefix::SPANISH2ENGLISH,
|
||||
model_type: ModelType::Marian,
|
||||
};
|
||||
pub const PORTUGUESE2ENGLISH: RemoteTranslationResources = Self {
|
||||
model_resource: MarianModelResources::ROMANCE2ENGLISH,
|
||||
config_resource: MarianConfigResources::ROMANCE2ENGLISH,
|
||||
vocab_resource: MarianVocabResources::ROMANCE2ENGLISH,
|
||||
merges_resource: MarianSpmResources::ROMANCE2ENGLISH,
|
||||
prefix: MarianPrefix::PORTUGUESE2ENGLISH,
|
||||
model_type: ModelType::Marian,
|
||||
};
|
||||
pub const ITALIAN2ENGLISH: RemoteTranslationResources = Self {
|
||||
model_resource: MarianModelResources::ROMANCE2ENGLISH,
|
||||
config_resource: MarianConfigResources::ROMANCE2ENGLISH,
|
||||
vocab_resource: MarianVocabResources::ROMANCE2ENGLISH,
|
||||
merges_resource: MarianSpmResources::ROMANCE2ENGLISH,
|
||||
prefix: MarianPrefix::ITALIAN2ENGLISH,
|
||||
model_type: ModelType::Marian,
|
||||
};
|
||||
pub const ROMANIAN2ENGLISH: RemoteTranslationResources = Self {
|
||||
model_resource: MarianModelResources::ROMANCE2ENGLISH,
|
||||
config_resource: MarianConfigResources::ROMANCE2ENGLISH,
|
||||
vocab_resource: MarianVocabResources::ROMANCE2ENGLISH,
|
||||
merges_resource: MarianSpmResources::ROMANCE2ENGLISH,
|
||||
prefix: MarianPrefix::ROMANIAN2ENGLISH,
|
||||
model_type: ModelType::Marian,
|
||||
};
|
||||
pub const GERMAN2ENGLISH: RemoteTranslationResources = Self {
|
||||
model_resource: MarianModelResources::GERMAN2ENGLISH,
|
||||
config_resource: MarianConfigResources::GERMAN2ENGLISH,
|
||||
vocab_resource: MarianVocabResources::GERMAN2ENGLISH,
|
||||
merges_resource: MarianSpmResources::GERMAN2ENGLISH,
|
||||
prefix: MarianPrefix::GERMAN2ENGLISH,
|
||||
model_type: ModelType::Marian,
|
||||
};
|
||||
pub const RUSSIAN2ENGLISH: RemoteTranslationResources = Self {
|
||||
model_resource: MarianModelResources::RUSSIAN2ENGLISH,
|
||||
config_resource: MarianConfigResources::RUSSIAN2ENGLISH,
|
||||
vocab_resource: MarianVocabResources::RUSSIAN2ENGLISH,
|
||||
merges_resource: MarianSpmResources::RUSSIAN2ENGLISH,
|
||||
prefix: MarianPrefix::RUSSIAN2ENGLISH,
|
||||
model_type: ModelType::Marian,
|
||||
};
|
||||
pub const FRENCH2GERMAN: RemoteTranslationResources = Self {
|
||||
model_resource: MarianModelResources::FRENCH2GERMAN,
|
||||
config_resource: MarianConfigResources::FRENCH2GERMAN,
|
||||
vocab_resource: MarianVocabResources::FRENCH2GERMAN,
|
||||
merges_resource: MarianSpmResources::FRENCH2GERMAN,
|
||||
prefix: MarianPrefix::FRENCH2GERMAN,
|
||||
model_type: ModelType::Marian,
|
||||
};
|
||||
pub const GERMAN2FRENCH: RemoteTranslationResources = Self {
|
||||
model_resource: MarianModelResources::GERMAN2FRENCH,
|
||||
config_resource: MarianConfigResources::GERMAN2FRENCH,
|
||||
vocab_resource: MarianVocabResources::GERMAN2FRENCH,
|
||||
merges_resource: MarianSpmResources::GERMAN2FRENCH,
|
||||
prefix: MarianPrefix::GERMAN2FRENCH,
|
||||
model_type: ModelType::Marian,
|
||||
};
|
||||
}
|
||||
|
||||
/// # Configuration for text translation
|
||||
@ -463,37 +328,44 @@ impl TranslationConfig {
|
||||
/// # }
|
||||
/// ```
|
||||
pub fn new(language: Language, device: Device) -> TranslationConfig {
|
||||
let (model_resource, config_resource, vocab_resource, merges_resource, prefix, model_type) =
|
||||
match language {
|
||||
Language::EnglishToFrench => RemoteTranslationResources::ENGLISH2FRENCH,
|
||||
Language::EnglishToCatalan => RemoteTranslationResources::ENGLISH2CATALAN,
|
||||
Language::EnglishToSpanish => RemoteTranslationResources::ENGLISH2SPANISH,
|
||||
Language::EnglishToPortuguese => RemoteTranslationResources::ENGLISH2PORTUGUESE,
|
||||
Language::EnglishToItalian => RemoteTranslationResources::ENGLISH2ITALIAN,
|
||||
Language::EnglishToRomanian => RemoteTranslationResources::ENGLISH2ROMANIAN,
|
||||
Language::EnglishToGerman => RemoteTranslationResources::ENGLISH2GERMAN,
|
||||
Language::EnglishToRussian => RemoteTranslationResources::ENGLISH2RUSSIAN,
|
||||
let translation_resource = match language {
|
||||
Language::EnglishToFrench => RemoteTranslationResources::ENGLISH2FRENCH,
|
||||
Language::EnglishToCatalan => RemoteTranslationResources::ENGLISH2CATALAN,
|
||||
Language::EnglishToSpanish => RemoteTranslationResources::ENGLISH2SPANISH,
|
||||
Language::EnglishToPortuguese => RemoteTranslationResources::ENGLISH2PORTUGUESE,
|
||||
Language::EnglishToItalian => RemoteTranslationResources::ENGLISH2ITALIAN,
|
||||
Language::EnglishToRomanian => RemoteTranslationResources::ENGLISH2ROMANIAN,
|
||||
Language::EnglishToGerman => RemoteTranslationResources::ENGLISH2GERMAN,
|
||||
Language::EnglishToRussian => RemoteTranslationResources::ENGLISH2RUSSIAN,
|
||||
|
||||
Language::FrenchToEnglish => RemoteTranslationResources::FRENCH2ENGLISH,
|
||||
Language::CatalanToEnglish => RemoteTranslationResources::CATALAN2ENGLISH,
|
||||
Language::SpanishToEnglish => RemoteTranslationResources::SPANISH2ENGLISH,
|
||||
Language::PortugueseToEnglish => RemoteTranslationResources::PORTUGUESE2ENGLISH,
|
||||
Language::ItalianToEnglish => RemoteTranslationResources::ITALIAN2ENGLISH,
|
||||
Language::RomanianToEnglish => RemoteTranslationResources::ROMANIAN2ENGLISH,
|
||||
Language::GermanToEnglish => RemoteTranslationResources::GERMAN2ENGLISH,
|
||||
Language::RussianToEnglish => RemoteTranslationResources::RUSSIAN2ENGLISH,
|
||||
Language::FrenchToEnglish => RemoteTranslationResources::FRENCH2ENGLISH,
|
||||
Language::CatalanToEnglish => RemoteTranslationResources::CATALAN2ENGLISH,
|
||||
Language::SpanishToEnglish => RemoteTranslationResources::SPANISH2ENGLISH,
|
||||
Language::PortugueseToEnglish => RemoteTranslationResources::PORTUGUESE2ENGLISH,
|
||||
Language::ItalianToEnglish => RemoteTranslationResources::ITALIAN2ENGLISH,
|
||||
Language::RomanianToEnglish => RemoteTranslationResources::ROMANIAN2ENGLISH,
|
||||
Language::GermanToEnglish => RemoteTranslationResources::GERMAN2ENGLISH,
|
||||
Language::RussianToEnglish => RemoteTranslationResources::RUSSIAN2ENGLISH,
|
||||
|
||||
Language::EnglishToFrenchV2 => RemoteTranslationResources::ENGLISH2FRENCH_V2,
|
||||
Language::EnglishToGermanV2 => RemoteTranslationResources::ENGLISH2GERMAN_V2,
|
||||
Language::EnglishToFrenchV2 => RemoteTranslationResources::ENGLISH2FRENCH_V2,
|
||||
Language::EnglishToGermanV2 => RemoteTranslationResources::ENGLISH2GERMAN_V2,
|
||||
|
||||
Language::FrenchToGerman => RemoteTranslationResources::FRENCH2GERMAN,
|
||||
Language::GermanToFrench => RemoteTranslationResources::GERMAN2FRENCH,
|
||||
};
|
||||
let model_resource = Resource::Remote(RemoteResource::from_pretrained(model_resource));
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(config_resource));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(vocab_resource));
|
||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(merges_resource));
|
||||
let prefix = match prefix {
|
||||
Language::FrenchToGerman => RemoteTranslationResources::FRENCH2GERMAN,
|
||||
Language::GermanToFrench => RemoteTranslationResources::GERMAN2FRENCH,
|
||||
};
|
||||
let model_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
translation_resource.model_resource,
|
||||
));
|
||||
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
translation_resource.config_resource,
|
||||
));
|
||||
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
translation_resource.vocab_resource,
|
||||
));
|
||||
let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
|
||||
translation_resource.merges_resource,
|
||||
));
|
||||
let prefix = match translation_resource.prefix {
|
||||
Some(value) => Some(value.to_string()),
|
||||
None => None,
|
||||
};
|
||||
@ -516,7 +388,7 @@ impl TranslationConfig {
|
||||
num_return_sequences: 1,
|
||||
device,
|
||||
prefix,
|
||||
model_type,
|
||||
model_type: translation_resource.model_type,
|
||||
}
|
||||
}
|
||||
|
||||
@ -740,10 +612,10 @@ impl TranslationModel {
|
||||
pub fn translate(&self, texts: &[&str]) -> Vec<String> {
|
||||
match &self.prefix {
|
||||
Some(value) => {
|
||||
let texts: Vec<String> = texts
|
||||
.into_iter()
|
||||
let texts = texts
|
||||
.iter()
|
||||
.map(|&v| format!("{} {}", value, v))
|
||||
.collect();
|
||||
.collect::<Vec<String>>();
|
||||
self.model
|
||||
.generate(Some(texts.iter().map(AsRef::as_ref).collect()), None)
|
||||
}
|
||||
|
@ -447,13 +447,13 @@ impl ZeroShotClassificationModel {
|
||||
let label_sentences: Vec<String> = match template {
|
||||
Some(function) => labels.iter().map(|label| function(label)).collect(),
|
||||
None => labels
|
||||
.into_iter()
|
||||
.iter()
|
||||
.map(|label| format!("This example is about {}.", label))
|
||||
.collect(),
|
||||
};
|
||||
|
||||
let text_pair_list = inputs
|
||||
.into_iter()
|
||||
.iter()
|
||||
.cartesian_product(label_sentences.iter())
|
||||
.map(|(&s, label)| (s, label.as_str()))
|
||||
.collect();
|
||||
|
@ -198,18 +198,18 @@ impl T5Attention {
|
||||
let length = temp_value.size()[2];
|
||||
temp_value = temp_value.slice(2, length - 1, length, 1);
|
||||
};
|
||||
if attention_mask.is_some() {
|
||||
temp_value = temp_value + attention_mask.unwrap();
|
||||
if let Some(attention_mask) = attention_mask {
|
||||
temp_value += attention_mask;
|
||||
};
|
||||
Some(temp_value)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let position_bias = if position_bias.is_none() {
|
||||
calculated_position_bias.as_ref().unwrap()
|
||||
let position_bias = if let Some(position_bias) = position_bias {
|
||||
position_bias
|
||||
} else {
|
||||
position_bias.unwrap()
|
||||
calculated_position_bias.as_ref().unwrap()
|
||||
};
|
||||
|
||||
scores += position_bias;
|
||||
@ -247,7 +247,7 @@ impl T5Attention {
|
||||
let mut num_buckets = num_buckets;
|
||||
let mut ret = n.zeros_like();
|
||||
let n = if bidirectional {
|
||||
num_buckets = num_buckets / 2;
|
||||
num_buckets /= 2;
|
||||
ret += n.lt(0).to_kind(Kind::Int64) * num_buckets;
|
||||
n.abs()
|
||||
} else {
|
||||
|
Loading…
Reference in New Issue
Block a user