Completed documentation for DistilBERT

Guillaume B 2020-03-24 19:59:00 +01:00
parent 6c6e5526ec
commit 496f7ba0bb
5 changed files with 414 additions and 14 deletions

View File

@ -19,7 +19,6 @@ use rust_tokenizers::{BertTokenizer, TruncationStrategy, Tokenizer, Vocab};
use failure::err_msg;
use rust_bert::Config;
use rust_bert::bert::{BertConfig, BertForMaskedLM};
use tch::kind::Kind::Int64;
fn main() -> failure::Fallible<()> {

View File

@ -812,7 +812,7 @@ impl BertForQuestionAnswering {
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
/// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true);
///
/// let (start_positions, end_positions, _, _) = no_grad(|| {
/// let (start_scores, end_scores, _, _) = no_grad(|| {
/// bert_model
/// .forward_t(Some(input_tensor),
/// Some(mask),

View File

@ -1,6 +1,6 @@
//! # BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding (Devlin et al.)
//!
//! Implementation of the BERT language model (https://arxiv.org/abs/1810.04805 Devlin, Chang, Lee, Toutanova, 2018).
//! Implementation of the BERT language model ([https://arxiv.org/abs/1810.04805](https://arxiv.org/abs/1810.04805) Devlin, Chang, Lee, Toutanova, 2018).
//! The base model is implemented in the `bert::BertModel` struct. Several language model heads have also been implemented, including:
//! - Masked language model: `bert::BertForMaskedLM`
//! - Multiple choices: `bert::BertForMultipleChoice`
@ -14,7 +14,7 @@
//! The example below illustrates a masked language model; the structure is similar for other models.
//! All models expect the following resources:
//! - Configuration file expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers)
//! - Model weights is expected to have a structure and parameter names following the [Transformers library](https://github.com/huggingface/transformers). A conversion using the Python utility scripts is required to convert the `.bin` weights to the `.ot` format.
//! - Model weights are expected to have a structure and parameter names following the [Transformers library](https://github.com/huggingface/transformers). A conversion using the Python utility scripts is required to convert the `.bin` weights to the `.ot` format.
//! - `BertTokenizer` using a `vocab.txt` vocabulary
//!
//! ```no_run

View File

@ -22,12 +22,17 @@ use crate::Config;
#[allow(non_camel_case_types)]
#[derive(Debug, Serialize, Deserialize)]
/// # Activation function used in the feed-forward layer in the transformer blocks
pub enum Activation {
/// Gaussian Error Linear Unit ([Hendrycks et al., 2016](https://arxiv.org/abs/1606.08415))
gelu,
/// Rectified Linear Unit
relu,
}
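The configuration only names the activation; a minimal sketch of how the variant could be mapped to the matching `tch` tensor operations is shown below. The helper name `apply_activation` is hypothetical and not part of the crate, and the snippet assumes the era's `Tensor::gelu`/`Tensor::relu` signatures (no extra arguments).

```rust
use tch::Tensor;
use rust_bert::distilbert::Activation;

// Hypothetical helper (not part of the crate): dispatch from the configured
// `Activation` variant to the matching tch tensor operation.
fn apply_activation(activation: &Activation, x: &Tensor) -> Tensor {
    match activation {
        Activation::gelu => x.gelu(),
        Activation::relu => x.relu(),
    }
}
```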
#[derive(Debug, Serialize, Deserialize)]
/// # DistilBERT model configuration
/// Defines the DistilBERT model architecture (e.g. number of layers, hidden layer size, label mapping...)
pub struct DistilBertConfig {
pub activation: Activation,
pub attention_dropout: f64,
@ -56,12 +61,40 @@ pub struct DistilBertConfig {
impl Config<DistilBertConfig> for DistilBertConfig {}
/// # DistilBERT Base model
/// Base architecture for DistilBERT models. Task-specific models will be built from this common base model.
/// It is made of the following blocks:
/// - `embeddings`: `token`, `position` embeddings
/// - `transformer`: Transformer made of a vector of layers. Each layer is made of a multi-head self-attention layer, layer norm and linear layers.
pub struct DistilBertModel {
embeddings: DistilBertEmbedding,
transformer: Transformer,
}
/// Defines the implementation of the DistilBertModel.
impl DistilBertModel {
/// Build a new `DistilBertModel`
///
/// # Arguments
///
/// * `p` - Variable store path for the root of the DistilBERT model
/// * `config` - `DistilBertConfig` object defining the model architecture
///
/// # Example
///
/// ```no_run
/// use tch::{nn, Device};
/// use rust_bert::Config;
/// use std::path::Path;
/// use rust_bert::distilbert::{DistilBertConfig, DistilBertModel};
///
/// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu;
/// let p = nn::VarStore::new(device);
/// let config = DistilBertConfig::from_file(config_path);
/// let distil_bert: DistilBertModel = DistilBertModel::new(&(&p.root() / "bert"), &config);
/// ```
///
pub fn new(p: &nn::Path, config: &DistilBertConfig) -> DistilBertModel {
let p = &(p / "distilbert");
let embeddings = DistilBertEmbedding::new(&(p / "embeddings"), config);
@ -69,14 +102,49 @@ impl DistilBertModel {
DistilBertModel { embeddings, transformer }
}
pub fn _get_embeddings(&self) -> &nn::Embedding {
self.embeddings._get_word_embeddings()
}
pub fn _set_embeddings(&mut self, new_embeddings: nn::Embedding) {
self.embeddings._set_word_embeddings(new_embeddings);
}
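A hedged sketch of how the embedding accessors above could be used to swap in a freshly initialised word-embedding table. The `new_embeddings` variable-store path is an illustrative assumption, and the snippet assumes the configuration exposes public `vocab_size` and `dim` fields:

```rust
use std::path::Path;
use tch::{nn, Device};
use rust_bert::Config;
use rust_bert::distilbert::{DistilBertConfig, DistilBertModel};

let device = Device::Cpu;
let vs = nn::VarStore::new(device);
let config = DistilBertConfig::from_file(Path::new("path/to/config.json"));
let mut distilbert_model = DistilBertModel::new(&vs.root(), &config);

// Build a fresh embedding table of the same shape (vocab_size x dim) under a
// hypothetical "new_embeddings" path, then swap it into the model.
let new_embeddings = nn::embedding(&(&vs.root() / "new_embeddings"),
                                   config.vocab_size,
                                   config.dim,
                                   Default::default());
distilbert_model._set_embeddings(new_embeddings);
```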
/// Forward pass through the model
///
/// # Arguments
///
/// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
/// * `mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked positions have value 0, non-masked positions have value 1. If None, set to 1 for all positions
/// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
/// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
///
/// # Returns
///
/// * `output` - `Tensor` of shape (*batch size*, *sequence_length*, *hidden_size*) representing the activations of the last hidden state
/// * `hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// * `attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
///
/// # Example
///
/// ```no_run
///# use tch::{nn, Device, Tensor, no_grad};
///# use rust_bert::Config;
///# use std::path::Path;
///# use tch::kind::Kind::Int64;
/// use rust_bert::distilbert::{DistilBertConfig, DistilBertModel};
///# let config_path = Path::new("path/to/config.json");
///# let vocab_path = Path::new("path/to/vocab.txt");
///# let device = Device::Cpu;
///# let vs = nn::VarStore::new(device);
///# let config = DistilBertConfig::from_file(config_path);
///# let distilbert_model: DistilBertModel = DistilBertModel::new(&vs.root(), &config);
/// let (batch_size, sequence_length) = (64, 128);
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
///
/// let (output, _, _) = no_grad(|| {
/// distilbert_model
/// .forward_t(Some(input_tensor),
/// Some(mask),
/// None,
/// false).unwrap()
/// });
///
/// ```
///
pub fn forward_t(&self, input: Option<Tensor>, mask: Option<Tensor>, input_embeds: Option<Tensor>, train: bool)
-> Result<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
let input_embeddings = match input {
@ -96,6 +164,12 @@ impl DistilBertModel {
}
}
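As documented above, the forward pass also accepts pre-computed embeddings in place of token ids. A minimal sketch under those assumptions, feeding random float embeddings of shape (*batch size*, *sequence_length*, *dim*) and no attention mask:

```rust
use std::path::Path;
use tch::{nn, no_grad, Device, Kind, Tensor};
use rust_bert::Config;
use rust_bert::distilbert::{DistilBertConfig, DistilBertModel};

let device = Device::Cpu;
let vs = nn::VarStore::new(device);
let config = DistilBertConfig::from_file(Path::new("path/to/config.json"));
let distilbert_model = DistilBertModel::new(&vs.root(), &config);

let (batch_size, sequence_length) = (2, 16);
// Pre-computed float embeddings of shape (batch, sequence_length, dim) replace
// the token ids; no attention mask is supplied here.
let input_embeds = Tensor::rand(&[batch_size, sequence_length, config.dim], (Kind::Float, device));

let (output, _, _) = no_grad(|| {
    distilbert_model
        .forward_t(None, None, Some(input_embeds), false)
        .unwrap()
});
```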
/// # DistilBERT for sequence classification
/// Base DistilBERT model with a pre-classifier and classifier heads to perform sentence or document-level classification
/// It is made of the following blocks:
/// - `distil_bert_model`: Base DistilBertModel
/// - `pre_classifier`: intermediate linear layer of size (*hidden_dim*, *hidden_dim*)
/// - `classifier`: output linear layer of size (*hidden_dim*, *num_labels*) producing the class logits
pub struct DistilBertModelClassifier {
distil_bert_model: DistilBertModel,
pre_classifier: nn::Linear,
@ -104,6 +178,28 @@ pub struct DistilBertModelClassifier {
}
impl DistilBertModelClassifier {
/// Build a new `DistilBertModelClassifier` for sequence classification
///
/// # Arguments
///
/// * `p` - Variable store path for the root of the DistilBERT model
/// * `config` - `DistilBertConfig` object defining the model architecture
///
/// # Example
///
/// ```no_run
/// use tch::{nn, Device};
/// use rust_bert::Config;
/// use std::path::Path;
/// use rust_bert::distilbert::{DistilBertConfig, DistilBertModelClassifier};
///
/// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu;
/// let p = nn::VarStore::new(device);
/// let config = DistilBertConfig::from_file(config_path);
/// let distil_bert: DistilBertModelClassifier = DistilBertModelClassifier::new(&(&p.root() / "bert"), &config);
/// ```
///
pub fn new(p: &nn::Path, config: &DistilBertConfig) -> DistilBertModelClassifier {
let distil_bert_model = DistilBertModel::new(&p, config);
let pre_classifier = nn::linear(&(p / "pre_classifier"), config.dim, config.dim, Default::default());
@ -113,6 +209,49 @@ impl DistilBertModelClassifier {
DistilBertModelClassifier { distil_bert_model, pre_classifier, classifier, dropout }
}
/// Forward pass through the model
///
/// # Arguments
///
/// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
/// * `mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked positions have value 0, non-masked positions have value 1. If None, set to 1 for all positions
/// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
/// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
///
/// # Returns
///
/// * `output` - `Tensor` of shape (*batch size*, *num_labels*) representing the logits for each class to predict
/// * `hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// * `attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
///
/// # Example
///
/// ```no_run
///# use tch::{nn, Device, Tensor, no_grad};
///# use rust_bert::Config;
///# use std::path::Path;
///# use tch::kind::Kind::Int64;
/// use rust_bert::distilbert::{DistilBertConfig, DistilBertModelClassifier};
///# let config_path = Path::new("path/to/config.json");
///# let vocab_path = Path::new("path/to/vocab.txt");
///# let device = Device::Cpu;
///# let vs = nn::VarStore::new(device);
///# let config = DistilBertConfig::from_file(config_path);
///# let distilbert_model: DistilBertModelClassifier = DistilBertModelClassifier::new(&vs.root(), &config);
/// let (batch_size, sequence_length) = (64, 128);
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
///
/// let (output, _, _) = no_grad(|| {
/// distilbert_model
/// .forward_t(Some(input_tensor),
/// Some(mask),
/// None,
/// false).unwrap()
/// });
///
/// ```
///
pub fn forward_t(&self, input: Option<Tensor>, mask: Option<Tensor>, input_embeds: Option<Tensor>, train: bool)
-> Result<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
let (output, all_hidden_states, all_attentions) = match self.distil_bert_model.forward_t(input, mask, input_embeds, train) {
@ -131,6 +270,13 @@ impl DistilBertModelClassifier {
}
}
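The classifier head returns raw logits of shape (*batch size*, *num_labels*). A small, hypothetical post-processing sketch turning them into probabilities and predicted label indices with standard `tch` tensor ops:

```rust
use tch::{Kind, Tensor};

// `output` is the (batch_size, num_labels) logits tensor returned by forward_t.
fn predict_labels(output: &Tensor) -> Tensor {
    let probabilities = output.softmax(-1, Kind::Float); // class probabilities per example
    probabilities.argmax(-1, false) // index of the most likely label for each example
}
```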
/// # DistilBERT for masked language model
/// Base DistilBERT model with a masked language model head to predict missing tokens, for example `"Looks like one [MASK] is missing" -> "person"`
/// It is made of the following blocks:
/// - `distil_bert_model`: Base DistilBertModel
/// - `vocab_transform`: linear layer for classification of size (*hidden_dim*, *hidden_dim*)
/// - `vocab_layer_norm`: layer normalization
/// - `vocab_projector`: linear layer for classification of size (*hidden_dim*, *vocab_size*) with weights tied to the token embeddings
pub struct DistilBertModelMaskedLM {
distil_bert_model: DistilBertModel,
vocab_transform: nn::Linear,
@ -138,7 +284,30 @@ pub struct DistilBertModelMaskedLM {
vocab_projector: nn::Linear,
}
impl DistilBertModelMaskedLM {
/// Build a new `DistilBertModelMaskedLM` for masked language modeling
///
/// # Arguments
///
/// * `p` - Variable store path for the root of the DistilBERT model
/// * `config` - `DistilBertConfig` object defining the model architecture
///
/// # Example
///
/// ```no_run
/// use tch::{nn, Device};
/// use rust_bert::Config;
/// use std::path::Path;
/// use rust_bert::distilbert::{DistilBertConfig, DistilBertModelMaskedLM};
///
/// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu;
/// let p = nn::VarStore::new(device);
/// let config = DistilBertConfig::from_file(config_path);
/// let distil_bert = DistilBertModelMaskedLM::new(&(&p.root() / "bert"), &config);
/// ```
///
pub fn new(p: &nn::Path, config: &DistilBertConfig) -> DistilBertModelMaskedLM {
let distil_bert_model = DistilBertModel::new(&p, config);
let vocab_transform = nn::linear(&(p / "vocab_transform"), config.dim, config.dim, Default::default());
@ -149,6 +318,49 @@ impl DistilBertModelMaskedLM {
DistilBertModelMaskedLM { distil_bert_model, vocab_transform, vocab_layer_norm, vocab_projector }
}
/// Forward pass through the model
///
/// # Arguments
///
/// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
/// * `mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked positions have value 0, non-masked positions have value 1. If None, set to 1 for all positions
/// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
/// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
///
/// # Returns
///
/// * `output` - `Tensor` of shape (*batch size*, *sequence_length*, *vocab_size*) representing the logits for position and vocabulary index
/// * `hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// * `attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
///
/// # Example
///
/// ```no_run
///# use tch::{nn, Device, Tensor, no_grad};
///# use rust_bert::Config;
///# use std::path::Path;
///# use tch::kind::Kind::Int64;
/// use rust_bert::distilbert::{DistilBertConfig, DistilBertModelMaskedLM};
///# let config_path = Path::new("path/to/config.json");
///# let vocab_path = Path::new("path/to/vocab.txt");
///# let device = Device::Cpu;
///# let vs = nn::VarStore::new(device);
///# let config = DistilBertConfig::from_file(config_path);
///# let distilbert_model = DistilBertModelMaskedLM::new(&vs.root(), &config);
/// let (batch_size, sequence_length) = (64, 128);
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
///
/// let (output, _, _) = no_grad(|| {
/// distilbert_model
/// .forward_t(Some(input_tensor),
/// Some(mask),
/// None,
/// false).unwrap()
/// });
///
/// ```
///
pub fn forward_t(&self, input: Option<Tensor>, mask: Option<Tensor>, input_embeds: Option<Tensor>, train: bool)
-> Result<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
let (output, all_hidden_states, all_attentions) = match self.distil_bert_model.forward_t(input, mask, input_embeds, train) {
@ -166,7 +378,13 @@ impl DistilBertModelMaskedLM {
}
}
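To turn the masked LM output into a predicted token, the logits at the masked position can be reduced with an argmax and looked up in the tokenizer vocabulary. A sketch under the assumption that the masked token index is known (`mask_position`) and that the era's `rust_tokenizers` `vocab()`/`id_to_token` API is available; the helper name is hypothetical:

```rust
use rust_tokenizers::{BertTokenizer, Tokenizer, Vocab};
use tch::Tensor;

// `output` has shape (batch_size, sequence_length, vocab_size); `mask_position`
// is the index of the [MASK] token in the first sequence of the batch.
fn predicted_token(output: &Tensor, tokenizer: &BertTokenizer, mask_position: i64) -> String {
    let token_id = output
        .get(0)              // first sequence in the batch
        .get(mask_position)  // logits over the vocabulary at the masked position
        .argmax(0, false)    // most likely vocabulary index
        .int64_value(&[]);
    tokenizer.vocab().id_to_token(&token_id)
}
```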
/// # DistilBERT for question answering
/// Extractive question-answering model based on a DistilBERT language model. Identifies the segment of a context that answers a provided question.
/// Please note that a significant amount of pre- and post-processing is required to perform end-to-end question answering.
/// See the question answering pipeline (also provided in this crate) for more details.
/// It is made of the following blocks:
/// - `distil_bert_model`: Base DistilBertModel
/// - `qa_outputs`: Linear layer for question answering
pub struct DistilBertForQuestionAnswering {
distil_bert_model: DistilBertModel,
qa_outputs: nn::Linear,
@ -174,6 +392,28 @@ pub struct DistilBertForQuestionAnswering {
}
impl DistilBertForQuestionAnswering {
/// Build a new `DistilBertForQuestionAnswering` for extractive question answering
///
/// # Arguments
///
/// * `p` - Variable store path for the root of the DistilBERT model
/// * `config` - `DistilBertConfig` object defining the model architecture
///
/// # Example
///
/// ```no_run
/// use tch::{nn, Device};
/// use rust_bert::Config;
/// use std::path::Path;
/// use rust_bert::distilbert::{DistilBertConfig, DistilBertForQuestionAnswering};
///
/// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu;
/// let p = nn::VarStore::new(device);
/// let config = DistilBertConfig::from_file(config_path);
/// let distil_bert = DistilBertForQuestionAnswering::new(&(&p.root() / "bert"), &config);
/// ```
///
pub fn new(p: &nn::Path, config: &DistilBertConfig) -> DistilBertForQuestionAnswering {
let distil_bert_model = DistilBertModel::new(&p, config);
let qa_outputs = nn::linear(&(p / "qa_outputs"), config.dim, config.num_labels, Default::default());
@ -183,6 +423,50 @@ impl DistilBertForQuestionAnswering {
DistilBertForQuestionAnswering { distil_bert_model, qa_outputs, dropout }
}
/// Forward pass through the model
///
/// # Arguments
///
/// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
/// * `mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked positions have value 0, non-masked positions have value 1. If None, set to 1 for all positions
/// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
/// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
///
/// # Returns
///
/// * `start_scores` - `Tensor` of shape (*batch size*, *sequence_length*) containing the logits for the start position of the answer
/// * `end_scores` - `Tensor` of shape (*batch size*, *sequence_length*) containing the logits for the end position of the answer
/// * `hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// * `attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
///
/// # Example
///
/// ```no_run
///# use tch::{nn, Device, Tensor, no_grad};
///# use rust_bert::Config;
///# use std::path::Path;
///# use tch::kind::Kind::Int64;
/// use rust_bert::distilbert::{DistilBertConfig, DistilBertForQuestionAnswering};
///# let config_path = Path::new("path/to/config.json");
///# let vocab_path = Path::new("path/to/vocab.txt");
///# let device = Device::Cpu;
///# let vs = nn::VarStore::new(device);
///# let config = DistilBertConfig::from_file(config_path);
///# let distilbert_model = DistilBertForQuestionAnswering::new(&vs.root(), &config);
/// let (batch_size, sequence_length) = (64, 128);
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
///
/// let (start_scores, end_scores, _, _) = no_grad(|| {
/// distilbert_model
/// .forward_t(Some(input_tensor),
/// Some(mask),
/// None,
/// false).unwrap()
/// });
///
/// ```
///
pub fn forward_t(&self,
input: Option<Tensor>,
mask: Option<Tensor>,
@ -208,6 +492,12 @@ impl DistilBertForQuestionAnswering {
}
}
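A naive way to reduce the start and end logits to a predicted answer span is to take the argmax of each; the sketch below (hypothetical helper, no validity checks on the span or question/context boundaries) illustrates this for the first example in the batch:

```rust
use tch::Tensor;

// `start_scores` and `end_scores` both have shape (batch_size, sequence_length).
fn best_span(start_scores: &Tensor, end_scores: &Tensor) -> (i64, i64) {
    let start = start_scores.get(0).argmax(0, false).int64_value(&[]);
    let end = end_scores.get(0).argmax(0, false).int64_value(&[]);
    (start, end) // token indices of the predicted span in the first example
}
```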
/// # DistilBERT for token classification (e.g. NER, POS)
/// Token-level classifier predicting a label for each token provided. Note that because of wordpiece tokenization, the labels predicted are
/// not necessarily aligned with words in the sentence.
/// It is made of the following blocks:
/// - `distil_bert_model`: Base DistilBertModel
/// - `classifier`: Linear layer for token classification
pub struct DistilBertForTokenClassification {
distil_bert_model: DistilBertModel,
classifier: nn::Linear,
@ -215,6 +505,28 @@ pub struct DistilBertForTokenClassification {
}
impl DistilBertForTokenClassification {
/// Build a new `DistilBertForTokenClassification` for token classification
///
/// # Arguments
///
/// * `p` - Variable store path for the root of the DistilBERT model
/// * `config` - `DistilBertConfig` object defining the model architecture
///
/// # Example
///
/// ```no_run
/// use tch::{nn, Device};
/// use rust_bert::Config;
/// use std::path::Path;
/// use rust_bert::distilbert::{DistilBertConfig, DistilBertForTokenClassification};
///
/// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu;
/// let p = nn::VarStore::new(device);
/// let config = DistilBertConfig::from_file(config_path);
/// let distil_bert = DistilBertForTokenClassification::new(&(&p.root() / "bert"), &config);
/// ```
///
pub fn new(p: &nn::Path, config: &DistilBertConfig) -> DistilBertForTokenClassification {
let distil_bert_model = DistilBertModel::new(&p, config);
let classifier = nn::linear(&(p / "classifier"), config.dim, config.num_labels, Default::default());
@ -223,6 +535,49 @@ impl DistilBertForTokenClassification {
DistilBertForTokenClassification { distil_bert_model, classifier, dropout }
}
/// Forward pass through the model
///
/// # Arguments
///
/// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
/// * `mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked positions have value 0, non-masked positions have value 1. If None, set to 1 for all positions
/// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
/// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
///
/// # Returns
///
/// * `output` - `Tensor` of shape (*batch size*, *sequence_length*, *num_labels*) representing the logits for position and class
/// * `hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// * `attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
///
/// # Example
///
/// ```no_run
///# use tch::{nn, Device, Tensor, no_grad};
///# use rust_bert::Config;
///# use std::path::Path;
///# use tch::kind::Kind::Int64;
/// use rust_bert::distilbert::{DistilBertConfig, DistilBertForTokenClassification};
///# let config_path = Path::new("path/to/config.json");
///# let vocab_path = Path::new("path/to/vocab.txt");
///# let device = Device::Cpu;
///# let vs = nn::VarStore::new(device);
///# let config = DistilBertConfig::from_file(config_path);
///# let distilbert_model = DistilBertForTokenClassification::new(&vs.root(), &config);
/// let (batch_size, sequence_length) = (64, 128);
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
///
/// let (output, _, _) = no_grad(|| {
/// distilbert_model
/// .forward_t(Some(input_tensor),
/// Some(mask),
/// None,
/// false).unwrap()
/// });
///
/// ```
///
pub fn forward_t(&self, input: Option<Tensor>, mask: Option<Tensor>, input_embeds: Option<Tensor>, train: bool)
-> Result<(Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>), &'static str> {
let (output, all_hidden_states, all_attentions) = match self.distil_bert_model.forward_t(input, mask, input_embeds, train) {

View File

@ -1,6 +1,52 @@
//! # DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter (Sanh et al.)
//!
//! Implementation of the DistilBERT language model ([https://arxiv.org/abs/1910.01108](https://arxiv.org/abs/1910.01108) Sanh, Debut, Chaumond, Wolf, 2019).
//! The base model is implemented in the `distilbert::DistilBertModel` struct. Several language model heads have also been implemented, including:
//! - Masked language model: `distilbert::DistilBertForMaskedLM`
//! - Question answering: `distilbert::DistilBertForQuestionAnswering`
//! - Sequence classification: `distilbert::DistilBertForSequenceClassification`
//! - Token classification (e.g. NER, POS tagging): `distilbert::DistilBertForTokenClassification`
//!
//! # Model set-up and pre-trained weights loading
//!
//! A full working example is provided in `examples/distilbert_masked_lm.rs`, run with `cargo run --example distilbert_masked_lm`.
//! The example below illustrates a DistilBERT masked language model; the structure is similar for other models.
//! All models expect the following resources:
//! - Configuration file expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers)
//! - Model weights are expected to have a structure and parameter names following the [Transformers library](https://github.com/huggingface/transformers). A conversion using the Python utility scripts is required to convert the `.bin` weights to the `.ot` format.
//! - `BertTokenizer` using a `vocab.txt` vocabulary
//!
//! ```no_run
//!# fn main() -> failure::Fallible<()> {
//!#
//!# let mut home: PathBuf = dirs::home_dir().unwrap();
//!# home.push("rustbert");
//!# home.push("distilbert");
//!# let config_path = &home.as_path().join("config.json");
//!# let vocab_path = &home.as_path().join("vocab.txt");
//!# let weights_path = &home.as_path().join("model.ot");
//! use rust_tokenizers::BertTokenizer;
//! use tch::{nn, Device};
//!# use std::path::PathBuf;
//! use rust_bert::Config;
//! use rust_bert::distilbert::{DistilBertModelMaskedLM, DistilBertConfig};
//!
//! let device = Device::cuda_if_available();
//! let mut vs = nn::VarStore::new(device);
//! let tokenizer: BertTokenizer = BertTokenizer::from_file(vocab_path.to_str().unwrap(), true);
//! let config = DistilBertConfig::from_file(config_path);
//! let bert_model = DistilBertModelMaskedLM::new(&vs.root(), &config);
//! vs.load(weights_path)?;
//!
//!# Ok(())
//!# }
//! ```
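The module example above stops after loading the weights. The hedged sketch below shows how the loaded masked LM could then be fed a tokenized sentence; the helper name is hypothetical, its arguments correspond to the `bert_model`, `tokenizer` and `device` created above, and the `encode` call assumes the era's `rust_tokenizers` signature:

```rust
use rust_tokenizers::{BertTokenizer, Tokenizer, TruncationStrategy};
use tch::{no_grad, Device, Tensor};
use rust_bert::distilbert::DistilBertModelMaskedLM;

// Continuation of the set-up above: `model`, `tokenizer` and `device` correspond
// to the `bert_model`, `tokenizer` and `device` created in the module example.
fn run_masked_lm(model: &DistilBertModelMaskedLM, tokenizer: &BertTokenizer, device: Device) -> Tensor {
    let input = tokenizer.encode("Looks like one [MASK] is missing",
                                 None, 128, &TruncationStrategy::LongestFirst, 0);
    let input_tensor = Tensor::of_slice(&input.token_ids).unsqueeze(0).to(device);
    let (output, _, _) = no_grad(|| {
        model.forward_t(Some(input_tensor), None, None, false).unwrap()
    });
    output // logits of shape (1, sequence_length, vocab_size)
}
```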
mod distilbert;
mod embeddings;
mod attention;
mod transformer;
pub use distilbert::{DistilBertConfig, DistilBertModel, DistilBertForQuestionAnswering, DistilBertForTokenClassification, DistilBertModelMaskedLM, DistilBertModelClassifier};
pub use distilbert::{DistilBertConfig, Activation, DistilBertModel, DistilBertForQuestionAnswering, DistilBertForTokenClassification, DistilBertModelMaskedLM, DistilBertModelClassifier};