mirror of
https://github.com/guillaume-be/rust-bert.git
synced 2024-09-11 12:55:34 +03:00
Initial commit for zero-shot classification
This commit is contained in:
parent
c6886096d2
commit
919cbff3c9
@ -30,7 +30,7 @@ all-tests = []
|
||||
features = ["doc-only"]
|
||||
|
||||
[dependencies]
|
||||
rust_tokenizers = {version = "~5.0.0", path = "E:/Coding/backup-rust/rust-tokenizers/main"}
|
||||
rust_tokenizers = "~5.0.0"
|
||||
tch = "~0.2.0"
|
||||
serde_json = "1.0.56"
|
||||
serde = { version = "1.0.114", features = ["derive"] }
|
||||
|
@ -18,6 +18,9 @@ pub enum RustBertError {
|
||||
|
||||
#[error("Invalid configuration error: {0}")]
|
||||
InvalidConfigurationError(String),
|
||||
|
||||
#[error("Value error: {0}")]
|
||||
ValueError(String),
|
||||
}
|
||||
|
||||
impl From<reqwest::Error> for RustBertError {
|
||||
|
@ -45,6 +45,7 @@ use std::path::Path;
|
||||
#[derive(Clone, Copy, Serialize, Deserialize, Debug)]
|
||||
/// # Identifies the type of model
|
||||
pub enum ModelType {
|
||||
Bart,
|
||||
Bert,
|
||||
DistilBert,
|
||||
Roberta,
|
||||
@ -57,6 +58,8 @@ pub enum ModelType {
|
||||
|
||||
/// # Abstraction that holds a model configuration, can be of any of the supported models
|
||||
pub enum ConfigOption {
|
||||
/// Bart configuration
|
||||
Bart(BartConfig),
|
||||
/// Bert configuration
|
||||
Bert(BertConfig),
|
||||
/// DistilBert configuration
|
||||
@ -91,6 +94,7 @@ impl ConfigOption {
|
||||
/// Interface method to load a configuration from file
|
||||
pub fn from_file(model_type: ModelType, path: &Path) -> Self {
|
||||
match model_type {
|
||||
ModelType::Bart => ConfigOption::Bart(BartConfig::from_file(path)),
|
||||
ModelType::Bert | ModelType::Roberta | ModelType::XLMRoberta => {
|
||||
ConfigOption::Bert(BertConfig::from_file(path))
|
||||
}
|
||||
@ -104,6 +108,9 @@ impl ConfigOption {
|
||||
|
||||
pub fn get_label_mapping(self) -> HashMap<i64, String> {
|
||||
match self {
|
||||
Self::Bart(config) => config
|
||||
.id2label
|
||||
.expect("No label dictionary (id2label) provided in configuration file"),
|
||||
Self::Bert(config) => config
|
||||
.id2label
|
||||
.expect("No label dictionary (id2label) provided in configuration file"),
|
||||
@ -148,7 +155,7 @@ impl TokenizerOption {
|
||||
strip_accents.unwrap_or(lower_case),
|
||||
)?)
|
||||
}
|
||||
ModelType::Roberta => {
|
||||
ModelType::Roberta | ModelType::Bart => {
|
||||
if strip_accents.is_some() {
|
||||
return Err(RustBertError::InvalidConfigurationError(format!(
|
||||
"Optional input `strip_accents` set to value {} but cannot be used by {:?}",
|
||||
|
@ -270,3 +270,4 @@ pub mod sequence_classification;
|
||||
pub mod summarization;
|
||||
pub mod token_classification;
|
||||
pub mod translation;
|
||||
pub mod zero_shot_classification;
|
||||
|
@ -171,7 +171,7 @@ impl Default for SequenceClassificationConfig {
|
||||
}
|
||||
}
|
||||
|
||||
/// # Abstraction that holds one particular token sequence classifier model, for any of the supported models
|
||||
/// # Abstraction that holds one particular sequence classification model, for any of the supported models
|
||||
pub enum SequenceClassificationOption {
|
||||
/// Bert for Sequence Classification
|
||||
Bert(BertForSequenceClassification),
|
||||
|
281
src/pipelines/zero_shot_classification.rs
Normal file
281
src/pipelines/zero_shot_classification.rs
Normal file
@ -0,0 +1,281 @@
|
||||
// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
|
||||
// Copyright 2019-2020 Guillaume Becquin
|
||||
// Copyright 2020 Maarten van Gompel
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use crate::albert::AlbertForSequenceClassification;
|
||||
use crate::bart::{
|
||||
BartConfigResources, BartForSequenceClassification, BartMergesResources, BartModelResources,
|
||||
BartVocabResources,
|
||||
};
|
||||
use crate::bert::BertForSequenceClassification;
|
||||
use crate::distilbert::DistilBertModelClassifier;
|
||||
use crate::pipelines::common::{ConfigOption, ModelType};
|
||||
use crate::resources::{RemoteResource, Resource};
|
||||
use crate::roberta::RobertaForSequenceClassification;
|
||||
use crate::RustBertError;
|
||||
use std::borrow::Borrow;
|
||||
use tch::{nn, Device, Tensor};
|
||||
|
||||
/// # Configuration for ZeroShotClassificationModel
|
||||
/// Contains information regarding the model to load, the tokenizer options and the device to place the model on.
|
||||
pub struct ZeroShotClassificationConfig {
|
||||
/// Model type (must match the model, configuration and vocabulary resources that are loaded)
|
||||
pub model_type: ModelType,
|
||||
/// Model weights resource (the `Default` implementation uses `BartModelResources::BART_MNLI`)
|
||||
pub model_resource: Resource,
|
||||
/// Config resource (the `Default` implementation uses `BartConfigResources::BART_MNLI`)
|
||||
pub config_resource: Resource,
|
||||
/// Vocab resource (the `Default` implementation uses `BartVocabResources::BART_MNLI`)
|
||||
pub vocab_resource: Resource,
|
||||
/// Optional merges resource (the `Default` implementation uses `BartMergesResources::BART_MNLI`)
|
||||
pub merges_resource: Option<Resource>,
|
||||
/// Automatically lower case all input upon tokenization (assumes a lower-cased model)
|
||||
pub lower_case: bool,
|
||||
/// Flag indicating if the tokenizer should strip accents (normalization). Only used for BERT / ALBERT models
|
||||
pub strip_accents: Option<bool>,
|
||||
/// Flag indicating if the tokenizer should add a white space before each tokenized input (needed for some Roberta models)
|
||||
pub add_prefix_space: Option<bool>,
|
||||
/// Device to place the model on (`new` and `Default` both use `Device::cuda_if_available()`)
|
||||
pub device: Device,
|
||||
}
|
||||
|
||||
impl ZeroShotClassificationConfig {
|
||||
/// Instantiate a new zero shot classification configuration of the supplied type.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `model_type` - `ModelType` indicating the model type to load (must match with the actual data to be loaded!)
|
||||
/// * model - The `Resource` pointing to the model to load (e.g. model.ot)
|
||||
/// * config - The `Resource' pointing to the model configuration to load (e.g. config.json)
|
||||
/// * vocab - The `Resource' pointing to the tokenizer's vocabulary to load (e.g. vocab.txt/vocab.json)
|
||||
/// * vocab - An optional `Resource` tuple (`Option<Resource>`) pointing to the tokenizer's merge file to load (e.g. merges.txt), needed only for Roberta.
|
||||
/// * lower_case - A `bool' indicating whether the tokeniser should lower case all input (in case of a lower-cased model)
|
||||
pub fn new(
|
||||
model_type: ModelType,
|
||||
model_resource: Resource,
|
||||
config_resource: Resource,
|
||||
vocab_resource: Resource,
|
||||
merges_resource: Option<Resource>,
|
||||
lower_case: bool,
|
||||
strip_accents: impl Into<Option<bool>>,
|
||||
add_prefix_space: impl Into<Option<bool>>,
|
||||
) -> ZeroShotClassificationConfig {
|
||||
ZeroShotClassificationConfig {
|
||||
model_type,
|
||||
model_resource,
|
||||
config_resource,
|
||||
vocab_resource,
|
||||
merges_resource,
|
||||
lower_case,
|
||||
strip_accents: strip_accents.into(),
|
||||
add_prefix_space: add_prefix_space.into(),
|
||||
device: Device::cuda_if_available(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for ZeroShotClassificationConfig {
|
||||
/// Provides a defaultSST-2 sentiment analysis model (English)
|
||||
fn default() -> ZeroShotClassificationConfig {
|
||||
ZeroShotClassificationConfig {
|
||||
model_type: ModelType::DistilBert,
|
||||
model_resource: Resource::Remote(RemoteResource::from_pretrained(
|
||||
BartModelResources::BART_MNLI,
|
||||
)),
|
||||
config_resource: Resource::Remote(RemoteResource::from_pretrained(
|
||||
BartConfigResources::BART_MNLI,
|
||||
)),
|
||||
vocab_resource: Resource::Remote(RemoteResource::from_pretrained(
|
||||
BartVocabResources::BART_MNLI,
|
||||
)),
|
||||
merges_resource: Some(Resource::Remote(RemoteResource::from_pretrained(
|
||||
BartMergesResources::BART_MNLI,
|
||||
))),
|
||||
lower_case: false,
|
||||
strip_accents: None,
|
||||
add_prefix_space: None,
|
||||
device: Device::cuda_if_available(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// # Abstraction that holds one particular zero shot classification model, for any of the supported models
|
||||
/// The wrapped models use a sequence classification architecture and should be trained on Natural Language Inference (NLI).
|
||||
/// The models should output a Tensor of size > 2 in the label dimension, with the first logit corresponding
|
||||
/// to contradiction and the last logit corresponding to entailment.
|
||||
pub enum ZeroShotClassificationOption {
|
||||
/// Bart for Sequence Classification
|
||||
Bart(BartForSequenceClassification),
|
||||
/// Bert for Sequence Classification
|
||||
Bert(BertForSequenceClassification),
|
||||
/// DistilBert for Sequence Classification
|
||||
DistilBert(DistilBertModelClassifier),
|
||||
/// Roberta for Sequence Classification
|
||||
Roberta(RobertaForSequenceClassification),
|
||||
/// XLMRoberta for Sequence Classification (wraps the same model type as the `Roberta` variant)
|
||||
XLMRoberta(RobertaForSequenceClassification),
|
||||
/// Albert for Sequence Classification
|
||||
Albert(AlbertForSequenceClassification),
|
||||
}
|
||||
|
||||
impl ZeroShotClassificationOption {
|
||||
/// Instantiate a new zero shot classification model of the supplied type.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `model_type` - `ModelType` indicating the model type to load (must match with the actual data to be loaded)
|
||||
/// * `p` - `tch::nn::Path` path to the model file to load (e.g. model.ot)
|
||||
/// * `config` - A configuration (the model type of the configuration must be compatible with the value for
|
||||
/// `model_type`)
|
||||
pub fn new<'p, P>(model_type: ModelType, p: P, config: &ConfigOption) -> Self
|
||||
where
|
||||
P: Borrow<nn::Path<'p>>,
|
||||
{
|
||||
match model_type {
|
||||
ModelType::Bart => {
|
||||
if let ConfigOption::Bart(config) = config {
|
||||
ZeroShotClassificationOption::Bart(BartForSequenceClassification::new(
|
||||
p, config,
|
||||
))
|
||||
} else {
|
||||
panic!("You can only supply a BartConfig for Bart!");
|
||||
}
|
||||
}
|
||||
ModelType::Bert => {
|
||||
if let ConfigOption::Bert(config) = config {
|
||||
ZeroShotClassificationOption::Bert(BertForSequenceClassification::new(
|
||||
p, config,
|
||||
))
|
||||
} else {
|
||||
panic!("You can only supply a BertConfig for Bert!");
|
||||
}
|
||||
}
|
||||
ModelType::DistilBert => {
|
||||
if let ConfigOption::DistilBert(config) = config {
|
||||
ZeroShotClassificationOption::DistilBert(DistilBertModelClassifier::new(
|
||||
p, config,
|
||||
))
|
||||
} else {
|
||||
panic!("You can only supply a DistilBertConfig for DistilBert!");
|
||||
}
|
||||
}
|
||||
ModelType::Roberta => {
|
||||
if let ConfigOption::Bert(config) = config {
|
||||
ZeroShotClassificationOption::Roberta(RobertaForSequenceClassification::new(
|
||||
p, config,
|
||||
))
|
||||
} else {
|
||||
panic!("You can only supply a BertConfig for Roberta!");
|
||||
}
|
||||
}
|
||||
ModelType::XLMRoberta => {
|
||||
if let ConfigOption::Bert(config) = config {
|
||||
ZeroShotClassificationOption::XLMRoberta(RobertaForSequenceClassification::new(
|
||||
p, config,
|
||||
))
|
||||
} else {
|
||||
panic!("You can only supply a BertConfig for Roberta!");
|
||||
}
|
||||
}
|
||||
ModelType::Albert => {
|
||||
if let ConfigOption::Albert(config) = config {
|
||||
ZeroShotClassificationOption::Albert(AlbertForSequenceClassification::new(
|
||||
p, config,
|
||||
))
|
||||
} else {
|
||||
panic!("You can only supply an AlbertConfig for Albert!");
|
||||
}
|
||||
}
|
||||
ModelType::Electra => {
|
||||
panic!("SequenceClassification not implemented for Electra!");
|
||||
}
|
||||
ModelType::Marian => {
|
||||
panic!("SequenceClassification not implemented for Marian!");
|
||||
}
|
||||
ModelType::T5 => {
|
||||
panic!("SequenceClassification not implemented for T5!");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the `ModelType` for this SequenceClassificationOption
|
||||
pub fn model_type(&self) -> ModelType {
|
||||
match *self {
|
||||
Self::Bart(_) => ModelType::Bart,
|
||||
Self::Bert(_) => ModelType::Bert,
|
||||
Self::Roberta(_) => ModelType::Roberta,
|
||||
Self::XLMRoberta(_) => ModelType::Roberta,
|
||||
Self::DistilBert(_) => ModelType::DistilBert,
|
||||
Self::Albert(_) => ModelType::Albert,
|
||||
}
|
||||
}
|
||||
|
||||
/// Interface method to forward_t() of the particular models.
|
||||
pub fn forward_t(
|
||||
&self,
|
||||
input_ids: Option<Tensor>,
|
||||
mask: Option<Tensor>,
|
||||
token_type_ids: Option<Tensor>,
|
||||
position_ids: Option<Tensor>,
|
||||
input_embeds: Option<Tensor>,
|
||||
train: bool,
|
||||
) -> Result<Tensor, RustBertError> {
|
||||
match *self {
|
||||
Self::Bart(ref model) => match input_ids {
|
||||
Some(input_ids) => Ok(model
|
||||
.forward_t(&input_ids, mask.as_ref(), None, None, None, train)
|
||||
.0),
|
||||
None => {
|
||||
return {
|
||||
Err(RustBertError::ValueError(
|
||||
"`input_ids` must be provided when using a BART model".to_string(),
|
||||
))
|
||||
}
|
||||
}
|
||||
},
|
||||
Self::Bert(ref model) => Ok(model
|
||||
.forward_t(
|
||||
input_ids,
|
||||
mask,
|
||||
token_type_ids,
|
||||
position_ids,
|
||||
input_embeds,
|
||||
train,
|
||||
)
|
||||
.0),
|
||||
Self::DistilBert(ref model) => Ok(model
|
||||
.forward_t(input_ids, mask, input_embeds, train)
|
||||
.expect("Error in distilbert forward_t")
|
||||
.0),
|
||||
Self::Roberta(ref model) | Self::XLMRoberta(ref model) => Ok(model
|
||||
.forward_t(
|
||||
input_ids,
|
||||
mask,
|
||||
token_type_ids,
|
||||
position_ids,
|
||||
input_embeds,
|
||||
train,
|
||||
)
|
||||
.0),
|
||||
Self::Albert(ref model) => Ok(model
|
||||
.forward_t(
|
||||
input_ids,
|
||||
mask,
|
||||
token_type_ids,
|
||||
position_ids,
|
||||
input_embeds,
|
||||
train,
|
||||
)
|
||||
.0),
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user