Addition of integration tests for zero-shot classification

This commit is contained in:
Guillaume B 2020-09-05 15:02:34 +02:00
parent 0ea9148e3f
commit b52b0cb005
4 changed files with 134 additions and 4 deletions

View File

@ -0,0 +1,37 @@
// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
// Copyright 2019 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
extern crate anyhow;
use rust_bert::pipelines::zero_shot_classification::ZeroShotClassificationModel;
/// Example entry point: builds a `ZeroShotClassificationModel` from its default
/// configuration, calls `predict_multilabel` on two sentences with four candidate
/// labels, and prints the returned value with its `Debug` formatting.
///
/// An error raised while constructing the model is propagated to the caller.
fn main() -> anyhow::Result<()> {
    // Build the pipeline from its default configuration
    let zero_shot_model = ZeroShotClassificationModel::new(Default::default())?;

    // Sentences to classify
    let sentences = [
        "Who are you voting for in 2020?",
        "The prime minister has announced a stimulus package which was widely criticized by the opposition.",
    ];
    let labels = &["politics", "public health", "economics", "sports"];

    // The closure formats each candidate label into a full sentence.
    // NOTE(review): 128 is presumably the maximum sequence length - confirm against the
    // `predict_multilabel` signature.
    let predictions = zero_shot_model.predict_multilabel(
        &sentences,
        labels,
        Some(Box::new(|label: &str| {
            format!("This example is about {}.", label)
        })),
        128,
    );

    println!("{:?}", predictions);
    Ok(())
}

View File

@ -77,7 +77,7 @@ use std::collections::HashMap;
use tch::nn::VarStore;
use tch::{nn, no_grad, Device, Kind, Tensor};
#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Serialize, Deserialize, Clone)]
/// # Label generated by a `SequenceClassificationModel`
pub struct Label {
/// Label String representation

View File

@ -361,7 +361,7 @@ impl ZeroShotClassificationModel {
Some(function) => labels.iter().map(|label| function(label)).collect(),
None => labels
.into_iter()
.map(|label| format!("This example is {}.", label))
.map(|label| format!("This example is about {}.", label))
.collect(),
};
@ -442,6 +442,26 @@ impl ZeroShotClassificationModel {
/// # Ok(())
/// # }
/// ```
///
/// outputs:
/// ```no_run
/// # use rust_bert::pipelines::sequence_classification::Label;
/// let output = [
/// Label {
/// text: "politics".to_string(),
/// score: 0.959,
/// id: 0,
/// sentence: 0,
/// },
/// Label {
/// text: "economics".to_string(),
/// score: 0.655,
/// id: 2,
/// sentence: 1,
/// },
/// ]
/// .to_vec();
/// ```
pub fn predict(
&self,
inputs: &[&str],
@ -465,7 +485,6 @@ impl ZeroShotClassificationModel {
let scores = output.softmax(1, Float).select(-1, -1);
let label_indices = scores.as_ref().argmax(-1, true).squeeze1(1);
label_indices.print();
let scores = scores
.gather(1, &label_indices.unsqueeze(-1), false)
.squeeze1(1);
@ -576,7 +595,8 @@ impl ZeroShotClassificationModel {
/// sentence: 1,
/// },
/// ],
/// ];
/// ]
/// .to_vec();
/// ```
pub fn predict_multilabel(
&self,

View File

@ -3,6 +3,7 @@ use rust_bert::bart::{
BartVocabResources,
};
use rust_bert::pipelines::summarization::{SummarizationConfig, SummarizationModel};
use rust_bert::pipelines::zero_shot_classification::ZeroShotClassificationModel;
use rust_bert::resources::{download_resource, RemoteResource, Resource};
use rust_bert::Config;
use rust_tokenizers::{RobertaTokenizer, Tokenizer, TruncationStrategy};
@ -156,3 +157,75 @@ about exoplanets like K2-18b."];
Ok(())
}
/// Integration test for `ZeroShotClassificationModel::predict`: two sentences are
/// classified against four candidate labels, and the single label returned for each
/// sentence is compared with reference text and score values.
#[test]
#[cfg_attr(not(feature = "all-tests"), ignore)]
fn bart_zero_shot_classification() -> anyhow::Result<()> {
    // Set up the model from its default configuration
    let zero_shot_model = ZeroShotClassificationModel::new(Default::default())?;

    let sentences = [
        "Who are you voting for in 2020?",
        "The prime minister has announced a stimulus package.",
    ];
    let labels = &["politics", "public health", "economics", "sports"];

    let predictions = zero_shot_model.predict(
        &sentences,
        labels,
        Some(Box::new(|label: &str| {
            format!("This example is about {}.", label)
        })),
        128,
    );

    // Scores are compared with the reference values using an absolute tolerance of 1e-4
    let is_close = |actual: f64, expected: f64| (actual - expected).abs() < 1e-4;

    // One prediction is expected per input sentence
    assert_eq!(predictions.len(), 2);

    // First sentence
    let first = &predictions[0];
    assert_eq!(first.text, "politics");
    assert!(is_close(first.score, 0.9679));

    // Second sentence
    let second = &predictions[1];
    assert_eq!(second.text, "economics");
    assert!(is_close(second.score, 0.5208));

    Ok(())
}
/// Integration test for `ZeroShotClassificationModel::predict_multilabel`: two sentences
/// are scored against four candidate labels, and every returned label is compared with
/// reference text and score values.
#[test]
#[cfg_attr(not(feature = "all-tests"), ignore)]
fn bart_zero_shot_classification_multilabel() -> anyhow::Result<()> {
    // Set up the model from its default configuration
    let zero_shot_model = ZeroShotClassificationModel::new(Default::default())?;

    let sentences = [
        "Who are you voting for in 2020?",
        "The prime minister has announced a stimulus package which was widely criticized by the opposition.",
    ];
    let labels = &["politics", "public health", "economics", "sports"];

    let predictions = zero_shot_model.predict_multilabel(
        &sentences,
        labels,
        Some(Box::new(|label: &str| {
            format!("This example is about {}.", label)
        })),
        128,
    );

    // Scores are compared with the reference values using an absolute tolerance of 1e-4
    let is_close = |actual: f64, expected: f64| (actual - expected).abs() < 1e-4;

    // One list of labels per sentence; the first list holds one entry per candidate label
    assert_eq!(predictions.len(), 2);
    assert_eq!(predictions[0].len(), labels.len());

    // First sentence label scores
    let first = &predictions[0];
    assert_eq!(first[0].text, "politics");
    assert!(is_close(first[0].score, 0.9805));
    assert_eq!(first[1].text, "public health");
    assert!(is_close(first[1].score, 0.0130));
    assert_eq!(first[2].text, "economics");
    assert!(is_close(first[2].score, 0.0041));
    assert_eq!(first[3].text, "sports");
    assert!(is_close(first[3].score, 0.0013));

    // Second sentence label scores
    let second = &predictions[1];
    assert_eq!(second[0].text, "politics");
    assert!(is_close(second[0].score, 0.9432));
    assert_eq!(second[1].text, "public health");
    assert!(is_close(second[1].score, 0.0045));
    assert_eq!(second[2].text, "economics");
    assert!(is_close(second[2].score, 0.9001));
    assert_eq!(second[3].text, "sports");
    assert!(is_close(second[3].score, 0.0004));

    Ok(())
}