From daa6dba2d2e2b598552ec0ff65d126dd91d00e0d Mon Sep 17 00:00:00 2001 From: Guillaume B Date: Sat, 12 Sep 2020 15:11:56 +0200 Subject: [PATCH] Updated Albert (clippy warnings) --- Cargo.toml | 2 +- examples/albert.rs | 19 +++- src/albert/albert_model.rs | 104 +++++++++++++--------- src/albert/encoder.rs | 39 ++++---- src/pipelines/question_answering.rs | 2 +- src/pipelines/sequence_classification.rs | 2 +- src/pipelines/token_classification.rs | 2 +- src/pipelines/zero_shot_classification.rs | 2 +- tests/albert.rs | 50 ++++++----- 9 files changed, 134 insertions(+), 88 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 1f8b8c8..429a509 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "rust-bert" -version = "0.9.0" +version = "0.10.0" authors = ["Guillaume Becquin "] edition = "2018" description = "Ready-to-use NLP pipelines and transformer-based models (BERT, DistilBERT, GPT2,...)" diff --git a/examples/albert.rs b/examples/albert.rs index 05783eb..5e88256 100644 --- a/examples/albert.rs +++ b/examples/albert.rs @@ -70,12 +70,23 @@ fn main() -> anyhow::Result<()> { let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device); // Forward pass - let (output, _, _) = + let model_output = no_grad(|| albert_model.forward_t(Some(input_tensor), None, None, None, None, false)); - println!("{:?}", output.double_value(&[0, 0, 0])); + println!( + "{:?}", + model_output.prediction_scores.double_value(&[0, 0, 0]) + ); // Print masked tokens - let index_1 = output.get(0).get(4).argmax(0, false); - let index_2 = output.get(1).get(7).argmax(0, false); + let index_1 = model_output + .prediction_scores + .get(0) + .get(4) + .argmax(0, false); + let index_2 = model_output + .prediction_scores + .get(1) + .get(7) + .argmax(0, false); let word_1 = tokenizer.vocab().id_to_token(&index_1.int64_value(&[])); let word_2 = tokenizer.vocab().id_to_token(&index_2.int64_value(&[])); diff --git a/src/albert/albert_model.rs b/src/albert/albert_model.rs index 52eaa0b..d49e771 100644 --- a/src/albert/albert_model.rs +++ b/src/albert/albert_model.rs @@ -209,7 +209,7 @@ impl AlbertModel { /// let position_ids = Tensor::arange(sequence_length, (Int64, device)) /// .expand(&[batch_size, sequence_length], true); /// - /// let (output, pooled_output, all_hidden_states, all_attentions) = no_grad(|| { + /// let model_output = no_grad(|| { /// albert_model /// .forward_t( /// Some(input_tensor), @@ -268,18 +268,20 @@ impl AlbertModel { } }; - let (hidden_state, all_hidden_states, all_attentions) = + let transformer_output = self.encoder .forward_t(&embedding_output, Some(extended_attention_mask), train); - let pooled_output = self.pooler.forward(&hidden_state.select(1, 0)); + let pooled_output = self + .pooler + .forward(&transformer_output.hidden_state.select(1, 0)); let pooled_output = (self.pooler_activation)(&pooled_output); Ok(AlbertOutput { - hidden_state, + hidden_state: transformer_output.hidden_state, pooled_output, - all_hidden_states, - all_attentions, + all_hidden_states: transformer_output.all_hidden_states, + all_attentions: transformer_output.all_attentions, }) } } @@ -429,7 +431,7 @@ impl AlbertForMaskedLM { /// let position_ids = Tensor::arange(sequence_length, (Int64, device)) /// .expand(&[batch_size, sequence_length], true); /// - /// let (output, all_hidden_states, all_attentions) = no_grad(|| { + /// let masked_lm_output = no_grad(|| { /// albert_model.forward_t( /// Some(input_tensor), /// Some(mask), @@ -448,7 +450,7 @@ impl AlbertForMaskedLM { position_ids: 
Option<Tensor>, input_embeds: Option<Tensor>, train: bool, - ) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Vec<Tensor>>>) { + ) -> AlbertMaskedLMOutput { let base_model_output = self .albert .forward_t( input_ids, mask, token_type_ids, position_ids, input_embeds, train, ) .unwrap(); let prediction_scores = self.predictions.forward(&base_model_output.hidden_state); - ( + AlbertMaskedLMOutput { prediction_scores, - base_model_output.all_hidden_states, - base_model_output.all_attentions, - ) + all_hidden_states: base_model_output.all_hidden_states, + all_attentions: base_model_output.all_attentions, + } } } @@ -571,7 +573,7 @@ impl AlbertForSequenceClassification { /// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device)); /// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true); /// - /// let (output, all_hidden_states, all_attentions) = no_grad(|| { + /// let classification_output = no_grad(|| { /// albert_model /// .forward_t(Some(input_tensor), /// Some(mask), @@ -589,7 +591,7 @@ impl AlbertForSequenceClassification { position_ids: Option<Tensor>, input_embeds: Option<Tensor>, train: bool, - ) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Vec<Tensor>>>) { + ) -> AlbertSequenceClassificationOutput { let base_model_output = self .albert .forward_t( input_ids, mask, token_type_ids, position_ids, input_embeds, train, ) .unwrap(); let logits = base_model_output .pooled_output .apply_t(&self.dropout, train) .apply(&self.classifier); - ( + AlbertSequenceClassificationOutput { logits, - base_model_output.all_hidden_states, - base_model_output.all_attentions, - ) + all_hidden_states: base_model_output.all_hidden_states, + all_attentions: base_model_output.all_attentions, + } } } @@ -712,7 +714,7 @@ impl AlbertForTokenClassification { /// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device)); /// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true); /// - /// let (output, all_hidden_states, all_attentions) = no_grad(|| { + /// let model_output = no_grad(|| { /// albert_model /// .forward_t(Some(input_tensor), /// Some(mask), @@ -730,7 +732,7 @@ impl AlbertForTokenClassification { position_ids: Option<Tensor>, input_embeds: Option<Tensor>, train: bool, - ) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Vec<Tensor>>>) { + ) -> AlbertTokenClassificationOutput { let base_model_output = self .albert .forward_t( input_ids, mask, token_type_ids, position_ids, input_embeds, train, ) .unwrap(); let logits = base_model_output .hidden_state .apply_t(&self.dropout, train) .apply(&self.classifier); - ( + AlbertTokenClassificationOutput { logits, - base_model_output.all_hidden_states, - base_model_output.all_attentions, - ) + all_hidden_states: base_model_output.all_hidden_states, + all_attentions: base_model_output.all_attentions, + } } } @@ -843,7 +845,7 @@ impl AlbertForQuestionAnswering { /// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device)); /// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true); /// - /// let (start_logits, end_logits, all_hidden_states, all_attentions) = no_grad(|| { + /// let model_output = no_grad(|| { /// albert_model /// .forward_t(Some(input_tensor), /// Some(mask), @@ -861,12 +863,7 @@ impl AlbertForQuestionAnswering { position_ids: Option<Tensor>, input_embeds: Option<Tensor>, train: bool, - ) -> ( - Tensor, - Tensor, - Option<Vec<Tensor>>, - Option<Vec<Vec<Tensor>>>, - ) { + ) -> AlbertQuestionAnsweringOutput { let base_model_output = self .albert .forward_t( @@ -886,12 +883,12 @@ impl AlbertForQuestionAnswering { let start_logits = start_logits.squeeze1(-1); let end_logits = end_logits.squeeze1(-1); - ( +
AlbertQuestionAnsweringOutput { start_logits, end_logits, - base_model_output.all_hidden_states, - base_model_output.all_attentions, - ) + all_hidden_states: base_model_output.all_hidden_states, + all_attentions: base_model_output.all_attentions, + } } } @@ -990,7 +987,7 @@ impl AlbertForMultipleChoice { /// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device)); /// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true); /// - /// let (output, all_hidden_states, all_attentions) = no_grad(|| { + /// let model_output = no_grad(|| { /// albert_model /// .forward_t(Some(input_tensor), /// Some(mask), @@ -1008,7 +1005,7 @@ impl AlbertForMultipleChoice { position_ids: Option<Tensor>, input_embeds: Option<Tensor>, train: bool, - ) -> Result<(Tensor, Option<Vec<Tensor>>, Option<Vec<Vec<Tensor>>>), &'static str> { + ) -> Result<AlbertSequenceClassificationOutput, &'static str> { let (input_ids, input_embeds, num_choices) = match &input_ids { Some(input_value) => match &input_embeds { Some(_) => { @@ -1062,10 +1059,35 @@ impl AlbertForMultipleChoice { .apply(&self.classifier) .view((-1, num_choices)); - Ok(( + Ok(AlbertSequenceClassificationOutput { logits, - base_model_output.all_hidden_states, - base_model_output.all_attentions, - )) + all_hidden_states: base_model_output.all_hidden_states, + all_attentions: base_model_output.all_attentions, + }) } } + +pub struct AlbertMaskedLMOutput { + pub prediction_scores: Tensor, + pub all_hidden_states: Option<Vec<Tensor>>, + pub all_attentions: Option<Vec<Vec<Tensor>>>, +} + +pub struct AlbertSequenceClassificationOutput { + pub logits: Tensor, + pub all_hidden_states: Option<Vec<Tensor>>, + pub all_attentions: Option<Vec<Vec<Tensor>>>, +} + +pub struct AlbertTokenClassificationOutput { + pub logits: Tensor, + pub all_hidden_states: Option<Vec<Tensor>>, + pub all_attentions: Option<Vec<Vec<Tensor>>>, +} + +pub struct AlbertQuestionAnsweringOutput { + pub start_logits: Tensor, + pub end_logits: Tensor, + pub all_hidden_states: Option<Vec<Tensor>>, + pub all_attentions: Option<Vec<Vec<Tensor>>>, +} diff --git a/src/albert/encoder.rs b/src/albert/encoder.rs index 2925cf2..7aaade0 100644 --- a/src/albert/encoder.rs +++ b/src/albert/encoder.rs @@ -149,22 +149,17 @@ impl AlbertLayerGroup { let mut hidden_state = hidden_states.copy(); let mut attention_weights: Option<Tensor>; - let mut layers = self.layers.iter(); - loop { - match layers.next() { - Some(layer) => { - if let Some(hidden_states) = all_hidden_states.borrow_mut() { - hidden_states.push(hidden_state.as_ref().copy()); - }; - let temp = layer.forward_t(&hidden_state, &mask, train); - hidden_state = temp.0; - attention_weights = temp.1; - if let Some(attentions) = all_attentions.borrow_mut() { - attentions.push(attention_weights.as_ref().unwrap().copy()); - }; - } - None => break, + for layer in &self.layers { + if let Some(hidden_states) = all_hidden_states.borrow_mut() { + hidden_states.push(hidden_state.as_ref().copy()); + }; + + let temp = layer.forward_t(&hidden_state, &mask, train); + hidden_state = temp.0; + attention_weights = temp.1; + if let Some(attentions) = all_attentions.borrow_mut() { + attentions.push(attention_weights.as_ref().unwrap().copy()); }; } @@ -226,7 +221,7 @@ impl AlbertTransformer { hidden_states: &Tensor, mask: Option<Tensor>, train: bool, - ) -> (Tensor, Option<Vec<Tensor>>, Option<Vec<Vec<Tensor>>>) { + ) -> AlbertTransformerOutput { let mut hidden_state = hidden_states.apply(&self.embedding_hidden_mapping_in); let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states { @@ -256,6 +251,16 @@ impl AlbertTransformer { }; } - (hidden_state, all_hidden_states, all_attentions) + AlbertTransformerOutput { hidden_state,
all_hidden_states, + all_attentions, + } } } + +pub struct AlbertTransformerOutput { + pub hidden_state: Tensor, + pub all_hidden_states: Option>, + pub all_attentions: Option>>, +} diff --git a/src/pipelines/question_answering.rs b/src/pipelines/question_answering.rs index 9fcad39..d036ca0 100644 --- a/src/pipelines/question_answering.rs +++ b/src/pipelines/question_answering.rs @@ -370,7 +370,7 @@ impl QuestionAnsweringOption { } Self::Albert(ref model) => { let outputs = model.forward_t(input_ids, mask, None, None, input_embeds, train); - (outputs.0, outputs.1) + (outputs.start_logits, outputs.end_logits) } } } diff --git a/src/pipelines/sequence_classification.rs b/src/pipelines/sequence_classification.rs index d677292..5aa6e61 100644 --- a/src/pipelines/sequence_classification.rs +++ b/src/pipelines/sequence_classification.rs @@ -343,7 +343,7 @@ impl SequenceClassificationOption { input_embeds, train, ) - .0 + .logits } } } diff --git a/src/pipelines/token_classification.rs b/src/pipelines/token_classification.rs index 7547754..2d0bcdb 100644 --- a/src/pipelines/token_classification.rs +++ b/src/pipelines/token_classification.rs @@ -450,7 +450,7 @@ impl TokenClassificationOption { input_embeds, train, ) - .0 + .logits } } } diff --git a/src/pipelines/zero_shot_classification.rs b/src/pipelines/zero_shot_classification.rs index 9a7c194..ae1c660 100644 --- a/src/pipelines/zero_shot_classification.rs +++ b/src/pipelines/zero_shot_classification.rs @@ -374,7 +374,7 @@ impl ZeroShotClassificationOption { input_embeds, train, ) - .0 + .logits } } } diff --git a/tests/albert.rs b/tests/albert.rs index 1b201fc..642c1bb 100644 --- a/tests/albert.rs +++ b/tests/albert.rs @@ -61,18 +61,26 @@ fn albert_masked_lm() -> anyhow::Result<()> { let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device); // Forward pass - let (output, _, _) = + let model_output = no_grad(|| albert_model.forward_t(Some(input_tensor), None, None, None, None, false)); // Print masked tokens - let index_1 = output.get(0).get(4).argmax(0, false); - let index_2 = output.get(1).get(6).argmax(0, false); + let index_1 = model_output + .prediction_scores + .get(0) + .get(4) + .argmax(0, false); + let index_2 = model_output + .prediction_scores + .get(1) + .get(6) + .argmax(0, false); let word_1 = tokenizer.vocab().id_to_token(&index_1.int64_value(&[])); let word_2 = tokenizer.vocab().id_to_token(&index_2.int64_value(&[])); assert_eq!("▁them", word_1); // Outputs "_them" : "Looks like one [them] is missing (? 
this is identical with the original implementation)" assert_eq!("▁grapes", word_2); // Outputs "grapes" : "It\'s like comparing [grapes] to apples" - assert!((output.double_value(&[0, 0, 0]) - 4.6143).abs() < 1e-4); + assert!((model_output.prediction_scores.double_value(&[0, 0, 0]) - 4.6143).abs() < 1e-4); Ok(()) } @@ -127,17 +135,17 @@ fn albert_for_sequence_classification() -> anyhow::Result<()> { let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device); // Forward pass - let (output, all_hidden_states, all_attentions) = + let model_output = no_grad(|| albert_model.forward_t(Some(input_tensor), None, None, None, None, false)); - assert_eq!(output.size(), &[2, 3]); + assert_eq!(model_output.logits.size(), &[2, 3]); assert_eq!( config.num_hidden_layers as usize, - all_hidden_states.unwrap().len() + model_output.all_hidden_states.unwrap().len() ); assert_eq!( config.num_hidden_layers as usize, - all_attentions.unwrap().len() + model_output.all_attentions.unwrap().len() ); Ok(()) @@ -191,20 +199,20 @@ fn albert_for_multiple_choice() -> anyhow::Result<()> { .unsqueeze(0); // Forward pass - let (output, all_hidden_states, all_attentions) = no_grad(|| { + let model_output = no_grad(|| { albert_model .forward_t(Some(input_tensor), None, None, None, None, false) .unwrap() }); - assert_eq!(output.size(), &[1, 2]); + assert_eq!(model_output.logits.size(), &[1, 2]); assert_eq!( config.num_hidden_layers as usize, - all_hidden_states.unwrap().len() + model_output.all_hidden_states.unwrap().len() ); assert_eq!( config.num_hidden_layers as usize, - all_attentions.unwrap().len() + model_output.all_attentions.unwrap().len() ); Ok(()) @@ -262,17 +270,17 @@ fn albert_for_token_classification() -> anyhow::Result<()> { let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device); // Forward pass - let (output, all_hidden_states, all_attentions) = + let model_output = no_grad(|| bert_model.forward_t(Some(input_tensor), None, None, None, None, false)); - assert_eq!(output.size(), &[2, 12, 4]); + assert_eq!(model_output.logits.size(), &[2, 12, 4]); assert_eq!( config.num_hidden_layers as usize, - all_hidden_states.unwrap().len() + model_output.all_hidden_states.unwrap().len() ); assert_eq!( config.num_hidden_layers as usize, - all_attentions.unwrap().len() + model_output.all_attentions.unwrap().len() ); Ok(()) @@ -324,18 +332,18 @@ fn albert_for_question_answering() -> anyhow::Result<()> { let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device); // Forward pass - let (start_scores, end_scores, all_hidden_states, all_attentions) = + let model_output = no_grad(|| albert_model.forward_t(Some(input_tensor), None, None, None, None, false)); - assert_eq!(start_scores.size(), &[2, 12]); - assert_eq!(end_scores.size(), &[2, 12]); + assert_eq!(model_output.start_logits.size(), &[2, 12]); + assert_eq!(model_output.end_logits.size(), &[2, 12]); assert_eq!( config.num_hidden_layers as usize, - all_hidden_states.unwrap().len() + model_output.all_hidden_states.unwrap().len() ); assert_eq!( config.num_hidden_layers as usize, - all_attentions.unwrap().len() + model_output.all_attentions.unwrap().len() ); Ok(())
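Usage sketch (not part of the diff; it assumes an already-constructed AlbertForQuestionAnswering `albert_model` and an `input_tensor` prepared as in examples/albert.rs and tests/albert.rs above): with this change the Albert heads return named output structs instead of positional tuples, so callers read fields rather than destructuring.

    // Forward pass now yields a typed output struct instead of the former 4-tuple
    let model_output =
        no_grad(|| albert_model.forward_t(Some(input_tensor), None, None, None, None, false));
    // Named fields replace the old tuple positions
    let start_logits = &model_output.start_logits; // Tensor of shape [batch_size, sequence_length]
    let end_logits = &model_output.end_logits;
    let hidden_states = model_output.all_hidden_states; // Option<Vec<Tensor>>
    let attentions = model_output.all_attentions; // Option<Vec<Vec<Tensor>>>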