Completed documentation for DistilBERT
parent 6c6e5526ec
commit 496f7ba0bb
@@ -19,7 +19,6 @@ use rust_tokenizers::{BertTokenizer, TruncationStrategy, Tokenizer, Vocab};
use failure::err_msg;
use rust_bert::Config;
use rust_bert::bert::{BertConfig, BertForMaskedLM};
use tch::kind::Kind::Int64;

fn main() -> failure::Fallible<()> {
@@ -812,7 +812,7 @@ impl BertForQuestionAnswering {
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true);
///
-/// let (start_positions, end_positions, _, _) = no_grad(|| {
+/// let (start_scores, end_scores, _, _) = no_grad(|| {
/// bert_model
/// .forward_t(Some(input_tensor),
/// Some(mask),
@@ -1,6 +1,6 @@
//! # BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding (Devlin et al.)
//!
-//! Implementation of the BERT language model (https://arxiv.org/abs/1810.04805 Devlin, Chang, Lee, Toutanova, 2018).
+//! Implementation of the BERT language model ([https://arxiv.org/abs/1810.04805](https://arxiv.org/abs/1810.04805) Devlin, Chang, Lee, Toutanova, 2018).
//! The base model is implemented in the `bert::BertModel` struct. Several language model heads have also been implemented, including:
//! - Masked language model: `bert::BertForMaskedLM`
//! - Multiple choices: `bert::BertForMultipleChoice`
@@ -14,7 +14,7 @@
//! The example below illustrates a Masked language model; the structure is similar for other models.
//! All models expect the following resources:
//! - Configuration file expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers)
-//! - Model weights is expected to have a structure and parameter names following the [Transformers library](https://github.com/huggingface/transformers). A conversion using the Python utility scripts is required to convert the `.bin` weights to the `.ot` format.
+//! - Model weights are expected to have a structure and parameter names following the [Transformers library](https://github.com/huggingface/transformers). A conversion using the Python utility scripts is required to convert the `.bin` weights to the `.ot` format.
//! - `BertTokenizer` using a `vocab.txt` vocabulary
//!
//! ```no_run
@@ -22,12 +22,17 @@ use crate::Config;

#[allow(non_camel_case_types)]
#[derive(Debug, Serialize, Deserialize)]
/// # Activation function used in the feed-forward layer in the transformer blocks
pub enum Activation {
/// Gaussian Error Linear Unit ([Hendrycks et al., 2016](https://arxiv.org/abs/1606.08415))
gelu,
/// Rectified Linear Unit
relu,
}

#[derive(Debug, Serialize, Deserialize)]
/// # DistilBERT model configuration
/// Defines the DistilBERT model architecture (e.g. number of layers, hidden layer size, label mapping...)
pub struct DistilBertConfig {
pub activation: Activation,
pub attention_dropout: f64,
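Editorial note (not part of the diff): one way the `Activation` variants could be mapped onto tensor operations. GELU is written out from its definition, 0.5 * x * (1 + erf(x / sqrt(2))); `erf` and `relu` are tch-rs tensor ops, but the exact wiring inside rust-bert may differ from this sketch.

    use tch::Tensor;

    // GELU(x) = 0.5 * x * (1 + erf(x / sqrt(2))) (Hendrycks et al., 2016)
    fn gelu(x: &Tensor) -> Tensor {
        x * ((x / 2.0_f64.sqrt()).erf() + 1.0) * 0.5
    }

    // hypothetical helper: resolve a config's Activation to a function pointer
    fn activation_fn(activation: &Activation) -> fn(&Tensor) -> Tensor {
        match activation {
            Activation::gelu => gelu,
            Activation::relu => |x| x.relu(),
        }
    }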
@@ -56,12 +61,40 @@ pub struct DistilBertConfig {

impl Config<DistilBertConfig> for DistilBertConfig {}

/// # DistilBERT Base model
/// Base architecture for DistilBERT models. Task-specific models are built from this common base model.
/// It is made of the following blocks:
/// - `embeddings`: `token`, `position` embeddings
/// - `transformer`: Transformer made of a vector of layers. Each layer is made of a multi-head self-attention layer, layer norm and linear layers.
pub struct DistilBertModel {
embeddings: DistilBertEmbedding,
transformer: Transformer,
}

/// Defines the implementation of the DistilBertModel.
impl DistilBertModel {
/// Build a new `DistilBertModel`
///
/// # Arguments
///
/// * `p` - Variable store path for the root of the DistilBERT model
/// * `config` - `DistilBertConfig` object defining the model architecture and decoder status
///
/// # Example
///
/// ```no_run
/// use tch::{nn, Device};
/// use rust_bert::Config;
/// use std::path::Path;
/// use rust_bert::distilbert::{DistilBertConfig, DistilBertModel};
///
/// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu;
/// let p = nn::VarStore::new(device);
/// let config = DistilBertConfig::from_file(config_path);
/// let distil_bert: DistilBertModel = DistilBertModel::new(&(&p.root() / "bert"), &config);
/// ```
///
pub fn new(p: &nn::Path, config: &DistilBertConfig) -> DistilBertModel {
let p = &(p / "distilbert");
let embeddings = DistilBertEmbedding::new(&(p / "embeddings"), config);
@@ -69,14 +102,49 @@ impl DistilBertModel {
DistilBertModel { embeddings, transformer }
}

pub fn _get_embeddings(&self) -> &nn::Embedding {
self.embeddings._get_word_embeddings()
}

pub fn _set_embeddings(&mut self, new_embeddings: nn::Embedding) {
self.embeddings._set_word_embeddings(new_embeddings);
}

/// Forward pass through the model
///
/// # Arguments
///
/// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
/// * `mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked positions have value 0, non-masked positions value 1. If None, set to 1
/// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
/// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
///
/// # Returns
///
/// * `output` - `Tensor` of shape (*batch size*, *sequence_length*, *hidden_size*) representing the activations of the last hidden state
/// * `hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// * `attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
///
/// # Example
///
/// ```no_run
///# use tch::{nn, Device, Tensor, no_grad};
///# use rust_bert::Config;
///# use std::path::Path;
///# use tch::kind::Kind::Int64;
/// use rust_bert::distilbert::{DistilBertConfig, DistilBertModel};
///# let config_path = Path::new("path/to/config.json");
///# let vocab_path = Path::new("path/to/vocab.txt");
///# let device = Device::Cpu;
///# let vs = nn::VarStore::new(device);
///# let config = DistilBertConfig::from_file(config_path);
///# let distilbert_model: DistilBertModel = DistilBertModel::new(&vs.root(), &config);
/// let (batch_size, sequence_length) = (64, 128);
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
///
/// let (output, _, _) = no_grad(|| {
/// distilbert_model
/// .forward_t(Some(input_tensor),
/// Some(mask),
/// None,
/// false).unwrap()
/// });
///
/// ```
///
pub fn forward_t(&self, input: Option<Tensor>, mask: Option<Tensor>, input_embeds: Option<Tensor>, train: bool)
-> Result<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
let input_embeddings = match input {
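Editorial note (not part of the diff): DistilBERT has no pooler module, so a common sentence-level representation is simply the first token's hidden state sliced from `output`. A minimal sketch, assuming an `output` tensor of shape (*batch size*, *sequence_length*, *hidden_size*) as documented above:

    // select(dim, index) narrows along the sequence dimension;
    // the result has shape (batch_size, hidden_size)
    let first_token_representation = output.select(1, 0);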
@@ -96,6 +164,12 @@ impl DistilBertModel {
}
}

/// # DistilBERT for sequence classification
/// Base DistilBERT model with a pre-classifier and a classifier head to perform sentence or document-level classification.
/// It is made of the following blocks:
/// - `distil_bert_model`: Base DistilBertModel
/// - `pre_classifier`: intermediate linear layer of size (*hidden_dim*, *hidden_dim*)
/// - `classifier`: output linear layer mapping the hidden state to the number of target classes
pub struct DistilBertModelClassifier {
distil_bert_model: DistilBertModel,
pre_classifier: nn::Linear,
@@ -104,6 +178,28 @@ pub struct DistilBertModelClassifier {
}

impl DistilBertModelClassifier {
/// Build a new `DistilBertModelClassifier` for sequence classification
///
/// # Arguments
///
/// * `p` - Variable store path for the root of the DistilBERT model
/// * `config` - `DistilBertConfig` object defining the model architecture and decoder status
///
/// # Example
///
/// ```no_run
/// use tch::{nn, Device};
/// use rust_bert::Config;
/// use std::path::Path;
/// use rust_bert::distilbert::{DistilBertConfig, DistilBertModelClassifier};
///
/// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu;
/// let p = nn::VarStore::new(device);
/// let config = DistilBertConfig::from_file(config_path);
/// let distil_bert: DistilBertModelClassifier = DistilBertModelClassifier::new(&(&p.root() / "bert"), &config);
/// ```
///
pub fn new(p: &nn::Path, config: &DistilBertConfig) -> DistilBertModelClassifier {
let distil_bert_model = DistilBertModel::new(&p, config);
let pre_classifier = nn::linear(&(p / "pre_classifier"), config.dim, config.dim, Default::default());
@@ -113,6 +209,49 @@ impl DistilBertModelClassifier {
DistilBertModelClassifier { distil_bert_model, pre_classifier, classifier, dropout }
}

/// Forward pass through the model
///
/// # Arguments
///
/// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
/// * `mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked positions have value 0, non-masked positions value 1. If None, set to 1
/// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
/// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
///
/// # Returns
///
/// * `output` - `Tensor` of shape (*batch size*, *num_labels*) representing the logits for each class to predict
/// * `hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// * `attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
///
/// # Example
///
/// ```no_run
///# use tch::{nn, Device, Tensor, no_grad};
///# use rust_bert::Config;
///# use std::path::Path;
///# use tch::kind::Kind::Int64;
/// use rust_bert::distilbert::{DistilBertConfig, DistilBertModelClassifier};
///# let config_path = Path::new("path/to/config.json");
///# let vocab_path = Path::new("path/to/vocab.txt");
///# let device = Device::Cpu;
///# let vs = nn::VarStore::new(device);
///# let config = DistilBertConfig::from_file(config_path);
///# let distilbert_model: DistilBertModelClassifier = DistilBertModelClassifier::new(&vs.root(), &config);
/// let (batch_size, sequence_length) = (64, 128);
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
///
/// let (output, _, _) = no_grad(|| {
/// distilbert_model
/// .forward_t(Some(input_tensor),
/// Some(mask),
/// None,
/// false).unwrap()
/// });
///
/// ```
///
pub fn forward_t(&self, input: Option<Tensor>, mask: Option<Tensor>, input_embeds: Option<Tensor>, train: bool)
-> Result<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
let (output, all_hidden_states, all_attentions) = match self.distil_bert_model.forward_t(input, mask, input_embeds, train) {
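Editorial note (not part of the diff): the classifier's `output` logits of shape (*batch size*, *num_labels*) are unnormalized. A sketch of turning them into probabilities and a predicted label, assuming current tch-rs signatures for `softmax` and `argmax`:

    let probabilities = output.softmax(-1, tch::Kind::Float); // normalize over the label dimension
    let predicted_labels = probabilities.argmax(-1, false);   // (batch_size,) index of the best label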
@@ -131,6 +270,13 @@ impl DistilBertModelClassifier {
}
}

/// # DistilBERT for masked language model
/// Base DistilBERT model with a masked language model head to predict missing tokens, for example `"Looks like one [MASK] is missing" -> "person"`
/// It is made of the following blocks:
/// - `distil_bert_model`: Base DistilBertModel
/// - `vocab_transform`: linear layer for classification of size (*hidden_dim*, *hidden_dim*)
/// - `vocab_layer_norm`: layer normalization
/// - `vocab_projector`: linear layer for classification of size (*hidden_dim*, *vocab_size*) with weights tied to the token embeddings
pub struct DistilBertModelMaskedLM {
distil_bert_model: DistilBertModel,
vocab_transform: nn::Linear,
@@ -138,7 +284,30 @@ pub struct DistilBertModelMaskedLM {
vocab_projector: nn::Linear,
}

impl DistilBertModelMaskedLM {
/// Build a new `DistilBertModelMaskedLM` for masked language modeling
///
/// # Arguments
///
/// * `p` - Variable store path for the root of the DistilBERT model
/// * `config` - `DistilBertConfig` object defining the model architecture and decoder status
///
/// # Example
///
/// ```no_run
/// use tch::{nn, Device};
/// use rust_bert::Config;
/// use std::path::Path;
/// use rust_bert::distilbert::{DistilBertConfig, DistilBertModelMaskedLM};
///
/// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu;
/// let p = nn::VarStore::new(device);
/// let config = DistilBertConfig::from_file(config_path);
/// let distil_bert = DistilBertModelMaskedLM::new(&(&p.root() / "bert"), &config);
/// ```
///
pub fn new(p: &nn::Path, config: &DistilBertConfig) -> DistilBertModelMaskedLM {
let distil_bert_model = DistilBertModel::new(&p, config);
let vocab_transform = nn::linear(&(p / "vocab_transform"), config.dim, config.dim, Default::default());
@@ -149,6 +318,49 @@ impl DistilBertModelMaskedLM {
DistilBertModelMaskedLM { distil_bert_model, vocab_transform, vocab_layer_norm, vocab_projector }
}

/// Forward pass through the model
///
/// # Arguments
///
/// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
/// * `mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked positions have value 0, non-masked positions value 1. If None, set to 1
/// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
/// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
///
/// # Returns
///
/// * `output` - `Tensor` of shape (*batch size*, *sequence_length*, *vocab_size*) representing the logits for position and vocabulary index
/// * `hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// * `attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
///
/// # Example
///
/// ```no_run
///# use tch::{nn, Device, Tensor, no_grad};
///# use rust_bert::Config;
///# use std::path::Path;
///# use tch::kind::Kind::Int64;
/// use rust_bert::distilbert::{DistilBertConfig, DistilBertModelMaskedLM};
///# let config_path = Path::new("path/to/config.json");
///# let vocab_path = Path::new("path/to/vocab.txt");
///# let device = Device::Cpu;
///# let vs = nn::VarStore::new(device);
///# let config = DistilBertConfig::from_file(config_path);
///# let distilbert_model = DistilBertModelMaskedLM::new(&vs.root(), &config);
/// let (batch_size, sequence_length) = (64, 128);
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
///
/// let (output, _, _) = no_grad(|| {
/// distilbert_model
/// .forward_t(Some(input_tensor),
/// Some(mask),
/// None,
/// false).unwrap()
/// });
///
/// ```
///
pub fn forward_t(&self, input: Option<Tensor>, mask: Option<Tensor>, input_embeds: Option<Tensor>, train: bool)
-> Result<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
let (output, all_hidden_states, all_attentions) = match self.distil_bert_model.forward_t(input, mask, input_embeds, train) {
@@ -166,7 +378,13 @@ impl DistilBertModelMaskedLM {
}
}

/// # DistilBERT for question answering
/// Extractive question-answering model based on a DistilBERT language model. Identifies the segment of a context that answers a provided question.
/// Please note that a significant amount of pre- and post-processing is required to perform end-to-end question answering.
/// See the question answering pipeline (also provided in this crate) for more details.
/// It is made of the following blocks:
/// - `distil_bert_model`: Base DistilBertModel
/// - `qa_outputs`: Linear layer for question answering
pub struct DistilBertForQuestionAnswering {
distil_bert_model: DistilBertModel,
qa_outputs: nn::Linear,
@@ -174,6 +392,28 @@ pub struct DistilBertForQuestionAnswering {
}

impl DistilBertForQuestionAnswering {
/// Build a new `DistilBertForQuestionAnswering` for extractive question answering
///
/// # Arguments
///
/// * `p` - Variable store path for the root of the DistilBERT model
/// * `config` - `DistilBertConfig` object defining the model architecture and decoder status
///
/// # Example
///
/// ```no_run
/// use tch::{nn, Device};
/// use rust_bert::Config;
/// use std::path::Path;
/// use rust_bert::distilbert::{DistilBertConfig, DistilBertForQuestionAnswering};
///
/// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu;
/// let p = nn::VarStore::new(device);
/// let config = DistilBertConfig::from_file(config_path);
/// let distil_bert = DistilBertForQuestionAnswering::new(&(&p.root() / "bert"), &config);
/// ```
///
pub fn new(p: &nn::Path, config: &DistilBertConfig) -> DistilBertForQuestionAnswering {
let distil_bert_model = DistilBertModel::new(&p, config);
let qa_outputs = nn::linear(&(p / "qa_outputs"), config.dim, config.num_labels, Default::default());
@@ -183,6 +423,50 @@ impl DistilBertForQuestionAnswering {
DistilBertForQuestionAnswering { distil_bert_model, qa_outputs, dropout }
}

/// Forward pass through the model
///
/// # Arguments
///
/// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
/// * `mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked positions have value 0, non-masked positions value 1. If None, set to 1
/// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
/// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
///
/// # Returns
///
/// * `start_scores` - `Tensor` of shape (*batch size*, *sequence_length*) containing the logits for start of the answer
/// * `end_scores` - `Tensor` of shape (*batch size*, *sequence_length*) containing the logits for end of the answer
/// * `hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// * `attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
///
/// # Example
///
/// ```no_run
///# use tch::{nn, Device, Tensor, no_grad};
///# use rust_bert::Config;
///# use std::path::Path;
///# use tch::kind::Kind::Int64;
/// use rust_bert::distilbert::{DistilBertConfig, DistilBertForQuestionAnswering};
///# let config_path = Path::new("path/to/config.json");
///# let vocab_path = Path::new("path/to/vocab.txt");
///# let device = Device::Cpu;
///# let vs = nn::VarStore::new(device);
///# let config = DistilBertConfig::from_file(config_path);
///# let distilbert_model = DistilBertForQuestionAnswering::new(&vs.root(), &config);
/// let (batch_size, sequence_length) = (64, 128);
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
///
/// let (start_scores, end_scores, _, _) = no_grad(|| {
/// distilbert_model
/// .forward_t(Some(input_tensor),
/// Some(mask),
/// None,
/// false).unwrap()
/// });
///
/// ```
///
pub fn forward_t(&self,
input: Option<Tensor>,
mask: Option<Tensor>,
@@ -208,6 +492,12 @@ impl DistilBertForQuestionAnswering {
}
}
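Editorial note (not part of the diff): a sketch of the post-processing hinted at above, turning `start_scores` and `end_scores` of shape (*batch size*, *sequence_length*) into an answer span. A real pipeline also rules out spans that fall outside the context or whose end precedes their start.

    let best_start = start_scores.argmax(-1, false); // (batch_size,) best start index per example
    let best_end = end_scores.argmax(-1, false);     // (batch_size,) best end index per example
    // for the first example in the batch, the answer covers the input ids
    // in [start_idx ..= end_idx], to be decoded back to text by the tokenizer
    let start_idx = best_start.int64_value(&[0]);
    let end_idx = best_end.int64_value(&[0]);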

/// # DistilBERT for token classification (e.g. NER, POS)
/// Token-level classifier predicting a label for each token provided. Note that because of wordpiece tokenization, the labels predicted are
/// not necessarily aligned with words in the sentence.
/// It is made of the following blocks:
/// - `distil_bert_model`: Base DistilBertModel
/// - `classifier`: Linear layer for token classification
pub struct DistilBertForTokenClassification {
distil_bert_model: DistilBertModel,
classifier: nn::Linear,
@@ -215,6 +505,28 @@ pub struct DistilBertForTokenClassification {
}

impl DistilBertForTokenClassification {
/// Build a new `DistilBertForTokenClassification` for token classification
///
/// # Arguments
///
/// * `p` - Variable store path for the root of the DistilBERT model
/// * `config` - `DistilBertConfig` object defining the model architecture and decoder status
///
/// # Example
///
/// ```no_run
/// use tch::{nn, Device};
/// use rust_bert::Config;
/// use std::path::Path;
/// use rust_bert::distilbert::{DistilBertConfig, DistilBertForTokenClassification};
///
/// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu;
/// let p = nn::VarStore::new(device);
/// let config = DistilBertConfig::from_file(config_path);
/// let distil_bert = DistilBertForTokenClassification::new(&(&p.root() / "bert"), &config);
/// ```
///
pub fn new(p: &nn::Path, config: &DistilBertConfig) -> DistilBertForTokenClassification {
let distil_bert_model = DistilBertModel::new(&p, config);
let classifier = nn::linear(&(p / "classifier"), config.dim, config.num_labels, Default::default());
@@ -223,6 +535,49 @@ impl DistilBertForTokenClassification {
DistilBertForTokenClassification { distil_bert_model, classifier, dropout }
}

/// Forward pass through the model
///
/// # Arguments
///
/// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
/// * `mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked positions have value 0, non-masked positions value 1. If None, set to 1
/// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
/// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
///
/// # Returns
///
/// * `output` - `Tensor` of shape (*batch size*, *sequence_length*, *num_labels*) representing the logits for position and class
/// * `hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// * `attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
///
/// # Example
///
/// ```no_run
///# use tch::{nn, Device, Tensor, no_grad};
///# use rust_bert::Config;
///# use std::path::Path;
///# use tch::kind::Kind::Int64;
/// use rust_bert::distilbert::{DistilBertConfig, DistilBertForTokenClassification};
///# let config_path = Path::new("path/to/config.json");
///# let vocab_path = Path::new("path/to/vocab.txt");
///# let device = Device::Cpu;
///# let vs = nn::VarStore::new(device);
///# let config = DistilBertConfig::from_file(config_path);
///# let distilbert_model = DistilBertForTokenClassification::new(&vs.root(), &config);
/// let (batch_size, sequence_length) = (64, 128);
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
///
/// let (output, _, _) = no_grad(|| {
/// distilbert_model
/// .forward_t(Some(input_tensor),
/// Some(mask),
/// None,
/// false).unwrap()
/// });
///
/// ```
///
pub fn forward_t(&self, input: Option<Tensor>, mask: Option<Tensor>, input_embeds: Option<Tensor>, train: bool)
-> Result<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
let (output, all_hidden_states, all_attentions) = match self.distil_bert_model.forward_t(input, mask, input_embeds, train) {
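Editorial note (not part of the diff): the token classifier's `output` logits carry one score vector per token, so a prediction is one `argmax` over the last dimension (tch-rs signature assumed as above):

    let predicted_labels = output.argmax(-1, false); // (batch_size, sequence_length): one label id per token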
@@ -1,6 +1,52 @@
//! # DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter (Sanh et al.)
//!
//! Implementation of the DistilBERT language model ([https://arxiv.org/abs/1910.01108](https://arxiv.org/abs/1910.01108) Sanh, Debut, Chaumond, Wolf, 2019).
//! The base model is implemented in the `distilbert::DistilBertModel` struct. Several language model heads have also been implemented, including:
//! - Masked language model: `distilbert::DistilBertForMaskedLM`
//! - Question answering: `distilbert::DistilBertForQuestionAnswering`
//! - Sequence classification: `distilbert::DistilBertForSequenceClassification`
//! - Token classification (e.g. NER, POS tagging): `distilbert::DistilBertForTokenClassification`
//!
//! # Model set-up and pre-trained weights loading
//!
//! A full working example is provided in `examples/distilbert_masked_lm.rs`, run with `cargo run --example distilbert_masked_lm`.
//! The example below illustrates a DistilBERT Masked language model; the structure is similar for other models.
//! All models expect the following resources:
//! - Configuration file expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers)
//! - Model weights are expected to have a structure and parameter names following the [Transformers library](https://github.com/huggingface/transformers). A conversion using the Python utility scripts is required to convert the `.bin` weights to the `.ot` format.
//! - `BertTokenizer` using a `vocab.txt` vocabulary
//!
//! ```no_run
//!# fn main() -> failure::Fallible<()> {
//!#
//!# let mut home: PathBuf = dirs::home_dir().unwrap();
//!# home.push("rustbert");
//!# home.push("distilbert");
//!# let config_path = &home.as_path().join("config.json");
//!# let vocab_path = &home.as_path().join("vocab.txt");
//!# let weights_path = &home.as_path().join("model.ot");
//! use rust_tokenizers::BertTokenizer;
//! use tch::{nn, Device};
//!# use std::path::PathBuf;
//! use rust_bert::Config;
//! use rust_bert::distilbert::{DistilBertModelMaskedLM, DistilBertConfig};
//!
//! let device = Device::cuda_if_available();
//! let mut vs = nn::VarStore::new(device);
//! let tokenizer: BertTokenizer = BertTokenizer::from_file(vocab_path.to_str().unwrap(), true);
//! let config = DistilBertConfig::from_file(config_path);
//! let bert_model = DistilBertModelMaskedLM::new(&vs.root(), &config);
//! vs.load(weights_path)?;
//!
//!# Ok(())
//!# }
//! ```

mod distilbert;
mod embeddings;
mod attention;
mod transformer;

-pub use distilbert::{DistilBertConfig, DistilBertModel, DistilBertForQuestionAnswering, DistilBertForTokenClassification, DistilBertModelMaskedLM, DistilBertModelClassifier};
+pub use distilbert::{DistilBertConfig, Activation, DistilBertModel, DistilBertForQuestionAnswering, DistilBertForTokenClassification, DistilBertModelMaskedLM, DistilBertModelClassifier};
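Editorial note (not part of the diff): a sketch of what inference could look like once `vs.load(weights_path)?` has succeeded in the module example above, reusing its `bert_model` and `device`. The hard-coded token ids and the tch-rs calls (`Tensor::of_slice`, `get`, `argmax`) are illustrative assumptions, not rust-bert API; in practice the `tokenizer` produces the ids and decodes the prediction. Additionally requires `use tch::{Tensor, no_grad};`.

    // hypothetical input: ids for "[CLS] looks like one [MASK] is missing [SEP]"
    let token_ids: Vec<i64> = vec![101, 3504, 2066, 2028, 103, 2003, 4394, 102];
    let mask_position: i64 = 4; // index of the [MASK] token above
    let input_tensor = Tensor::of_slice(&token_ids).unsqueeze(0).to(device); // (1, seq_len)

    let (output, _, _) = no_grad(|| {
        bert_model.forward_t(Some(input_tensor), None, None, false).unwrap()
    });
    // output: (1, seq_len, vocab_size); take the most likely vocabulary entry at the mask
    let predicted_id = output.get(0).get(mask_position).argmax(-1, false);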