BERT documentation (ongoing)

This commit is contained in:
Guillaume B 2020-03-22 11:56:58 +01:00
parent 54b10c54a7
commit e1f3b743da
4 changed files with 332 additions and 19 deletions

View File

@ -46,23 +46,16 @@ fn main() -> failure::Fallible<()> {
// Define input
let input = ["Looks like one thing is missing", "It\'s like comparing oranges to apples"];
let input = ["Looks like one [MASK] is missing", "It was a very nice and [MASK] day"];
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
let mut tokenized_input = tokenized_input.
let tokenized_input = tokenized_input.
map(|input| input.token_ids.clone()).
map(|mut input| {
input.extend(vec![0; max_len - input.len()]);
// Masking the token [thing] of sentence 1 and [oranges] of sentence 2
tokenized_input[0][4] = 103;
tokenized_input[1][6] = 103;
let tokenized_input = tokenized_input.
@ -83,12 +76,12 @@ fn main() -> failure::Fallible<()> {
// Print masked tokens
let index_1 = output.get(0).get(4).argmax(0, false);
let index_2 = output.get(1).get(6).argmax(0, false);
let index_2 = output.get(1).get(7).argmax(0, false);
let word_1 = tokenizer.vocab().id_to_token(&index_1.int64_value(&[]));
let word_2 = tokenizer.vocab().id_to_token(&index_2.int64_value(&[]));
println!("{}", word_1); // Outputs "person" : "Looks like one [person] is missing"
println!("{}", word_2);// Outputs "pear" : "It\'s like comparing [pear] to apples"
println!("{}", word_2);// Outputs "pleasant" : "It was a very nice and [pleasant] day"

View File

@ -25,13 +25,19 @@ use crate::Config;
#[derive(Debug, Serialize, Deserialize)]
/// # Activation function used in the attention layer and masked language model head
pub enum Activation {
/// Gaussian Error Linear Unit ([Hendrycks et al., 2016](https://arxiv.org/abs/1606.08415))
/// Rectified Linear Unit
/// Mish ([Misra, 2019](https://arxiv.org/abs/1908.08681))
#[derive(Debug, Serialize, Deserialize)]
/// # BERT model configuration
/// Defines the BERT model architecture (e.g. number of layers, hidden layer size, label mapping...)
pub struct BertConfig {
pub hidden_act: Activation,
pub attention_probs_dropout_prob: f64,
@ -54,6 +60,13 @@ pub struct BertConfig {
impl Config<BertConfig> for BertConfig {}
/// # BERT Base model
/// Base architecture for BERT models. Task-specific models will be built from this common base model
/// It is made of the following blocks:
/// - `embeddings`: `token`, `position` and `segment_id` embeddings
/// - `encoder`: Encoder (transformer) made of a vector of layers. Each layer is made of a self-attention layer, an intermediate (linear) and output (linear + layer norm) layers
/// - `pooler`: linear layer applied to the first element of the sequence (*[CLS]* token)
/// - `is_decoder`: Flag indicating if the model is used as a decoder. If set to true, a causal mask will be applied to hide future positions that should not be attended to.
pub struct BertModel<T: BertEmbedding> {
embeddings: T,
encoder: BertEncoder,
@ -61,7 +74,32 @@ pub struct BertModel<T: BertEmbedding> {
is_decoder: bool,
impl <T: BertEmbedding> BertModel<T> {
/// Defines the implementation of the BertModel. The BERT model shares many similarities with RoBERTa, main difference being the embeddings.
/// Therefore the forward pass of the model is shared and the type of embedding used is abstracted away. This allows creating
/// `BertModel<RobertaEmbeddings>` or `BertModel<BertEmbeddings>` for each model type.
impl<T: BertEmbedding> BertModel<T> {
/// Build a new `BertModel`
/// # Arguments
/// * `p` - Variable store path for the root of the BERT model
/// * `config` - `BertConfig` object defining the model architecture and decoder status
/// # Example
/// ```no_run
/// use rust_bert::bert::{BertModel, BertConfig, BertEmbeddings};
/// use tch::{nn, Device};
/// use rust_bert::Config;
/// use std::path::Path;
/// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu;
/// let p = nn::VarStore::new(device);
/// let config = BertConfig::from_file(config_path);
/// let bert: BertModel<BertEmbeddings> = BertModel::new(&(&p.root() / "bert"), &config);
/// ```
pub fn new(p: &nn::Path, config: &BertConfig) -> BertModel<T> {
let is_decoder = match config.is_decoder {
Some(value) => value,
@ -74,6 +112,73 @@ impl <T: BertEmbedding> BertModel<T> {
BertModel { embeddings, encoder, pooler, is_decoder }
/// Forward pass through the model
/// # Arguments
/// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
/// * `mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked positions have value 0, non-masked positions have value 1. If None, set to 1.
/// * `token_type_ids` - Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *[SEP]*) and 1 for the second sentence. If None, set to 0.
/// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
/// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
/// * `encoder_hidden_states` - Optional encoder hidden state of shape (*batch size*, *encoder_sequence_length*, *hidden_size*). If the model is defined as a decoder and the `encoder_hidden_states` is not None, used in the cross-attention layer as keys and values (query from the decoder).
/// * `encoder_mask` - Optional encoder attention mask of shape (*batch size*, *encoder_sequence_length*). If the model is defined as a decoder and the `encoder_hidden_states` is not None, used to mask encoder values. Positions with value 0 will be masked.
/// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
/// # Returns
/// * `output` - `Tensor` of shape (*batch size*, *sequence_length*, *hidden_size*)
/// * `pooled_output` - `Tensor` of shape (*batch size*, *hidden_size*)
/// * `hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// * `attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// # Example
/// ```no_run
///# use rust_bert::bert::{BertModel, BertConfig, BertEmbeddings};
///# use tch::{nn, Device, Tensor, no_grad};
///# use rust_bert::Config;
///# use std::path::Path;
///# use rust_tokenizers::{BertTokenizer, TruncationStrategy, Tokenizer};
///# let config_path = Path::new("path/to/config.json");
///# let vocab_path = Path::new("path/to/vocab.txt");
///# let device = Device::Cpu;
///# let vs = nn::VarStore::new(device);
///# let tokenizer: BertTokenizer = BertTokenizer::from_file(vocab_path.to_str().unwrap(), true);
///# let config = BertConfig::from_file(config_path);
///# let bert_model: BertModel<BertEmbeddings> = BertModel::new(&vs.root(), &config);
/// let input = ["One sentence", "Another sentence"];
/// let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
/// let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
/// let mut tokenized_input = tokenized_input
/// .iter()
/// // retrieve input ids from TokenizedInput
/// .map(|input| input.token_ids.clone())
/// // Padding inputs to same length
/// .map(|mut input| {
/// input.extend(vec![0; max_len - input.len()]);
/// input
/// })
/// // Map to Tensor
/// .map(|input|
/// Tensor::of_slice(&(input)))
/// .collect::<Vec<_>>();
/// let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
/// let (output, _, _, _) = no_grad(|| {
/// bert_model
/// .forward_t(Some(input_tensor),
/// None,
/// None,
/// None,
/// None,
/// &None,
/// &None,
/// false).unwrap()
/// });
/// ```
pub fn forward_t(&self,
input_ids: Option<Tensor>,
mask: Option<Tensor>,
@ -196,12 +301,39 @@ impl BertLMPredictionHead {
/// # BERT for masked language model
/// Base BERT model with a masked language model head to predict missing tokens, for example `"Looks like one [MASK] is missing" -> "person"`
/// It is made of the following blocks:
/// - `bert`: Base BertModel
/// - `cls`: BERT LM prediction head
pub struct BertForMaskedLM {
bert: BertModel<BertEmbeddings>,
cls: BertLMPredictionHead,
impl BertForMaskedLM {
/// Build a new `BertForMaskedLM`
/// # Arguments
/// * `p` - Variable store path for the root of the BertForMaskedLM model
/// * `config` - `BertConfig` object defining the model architecture and vocab size
/// # Example
/// ```no_run
/// use rust_bert::bert::{BertConfig, BertForMaskedLM};
/// use tch::{nn, Device};
/// use rust_bert::Config;
/// use std::path::Path;
/// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu;
/// let p = nn::VarStore::new(device);
/// let config = BertConfig::from_file(config_path);
/// let bert = BertForMaskedLM::new(&(&p.root() / "bert"), &config);
/// ```
pub fn new(p: &nn::Path, config: &BertConfig) -> BertForMaskedLM {
let bert = BertModel::new(&(p / "bert"), config);
let cls = BertLMPredictionHead::new(&(p / "cls"), config);
@ -209,6 +341,74 @@ impl BertForMaskedLM {
BertForMaskedLM { bert, cls }
/// Forward pass through the model
/// # Arguments
/// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see *input_embeds*)
/// * `mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked positions have value 0, non-masked positions have value 1. If None, set to 1.
/// * `token_type_ids` - Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *[SEP]*) and 1 for the second sentence. If None, set to 0.
/// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
/// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see *input_ids*)
/// * `encoder_hidden_states` - Optional encoder hidden state of shape (*batch size*, *encoder_sequence_length*, *hidden_size*). If the model is defined as a decoder and the *encoder_hidden_states* is not None, used in the cross-attention layer as keys and values (query from the decoder).
/// * `encoder_mask` - Optional encoder attention mask of shape (*batch size*, *encoder_sequence_length*). If the model is defined as a decoder and the *encoder_hidden_states* is not None, used to mask encoder values. Positions with value 0 will be masked.
/// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
/// # Returns
/// * `output` - `Tensor` of shape (*batch size*, *sequence_length*, *vocab_size*)
/// * `hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// * `attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// # Example
/// ```no_run
///# use rust_bert::bert::{BertModel, BertConfig, BertForMaskedLM};
///# use tch::{nn, Device, Tensor, no_grad};
///# use rust_bert::Config;
///# use std::path::Path;
///# use rust_tokenizers::{BertTokenizer, TruncationStrategy, Tokenizer};
///# let config_path = Path::new("path/to/config.json");
///# let vocab_path = Path::new("path/to/vocab.txt");
///# let device = Device::Cpu;
///# let vs = nn::VarStore::new(device);
///# let tokenizer: BertTokenizer = BertTokenizer::from_file(vocab_path.to_str().unwrap(), true);
///# let config = BertConfig::from_file(config_path);
///# let bert_model = BertForMaskedLM::new(&vs.root(), &config);
/// let input = ["Looks like one [MASK] is missing", "It was a very nice and [MASK] day"];
/// let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
/// let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
/// let mut tokenized_input = tokenized_input
/// .iter()
/// // retrieve input ids from TokenizedInput
/// .map(|input| input.token_ids.clone())
/// // Padding inputs to same length
/// .map(|mut input| {
/// input.extend(vec![0; max_len - input.len()]);
/// input
/// })
/// // Map to Tensor
/// .map(|input|
/// Tensor::of_slice(&(input)))
/// .collect::<Vec<_>>();
/// let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
/// let (output, _, _) = no_grad(|| {
/// bert_model
/// .forward_t(Some(input_tensor),
/// None,
/// None,
/// None,
/// None,
/// &None,
/// &None,
/// false)
/// });
/// ```
/// Outputs: `Looks like one [person] is missing` and `It was a very nice and [pleasant] day`.
pub fn forward_t(&self,
input_ids: Option<Tensor>,
mask: Option<Tensor>,
@ -226,6 +426,11 @@ impl BertForMaskedLM {
/// # BERT for sequence classification
/// Base BERT model with a classifier head to perform sentence or document-level classification
/// It is made of the following blocks:
/// - `bert`: Base BertModel
/// - `classifier`: BERT linear layer for classification
pub struct BertForSequenceClassification {
bert: BertModel<BertEmbeddings>,
dropout: Dropout,
@ -233,6 +438,27 @@ pub struct BertForSequenceClassification {
impl BertForSequenceClassification {
/// Build a new `BertForSequenceClassification`
/// # Arguments
/// * `p` - Variable store path for the root of the BertForSequenceClassification model
/// * `config` - `BertConfig` object defining the model architecture and number of classes
/// # Example
/// ```no_run
/// use rust_bert::bert::{BertConfig, BertForSequenceClassification};
/// use tch::{nn, Device};
/// use rust_bert::Config;
/// use std::path::Path;
/// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu;
/// let p = nn::VarStore::new(device);
/// let config = BertConfig::from_file(config_path);
/// let bert = BertForSequenceClassification::new(&(&p.root() / "bert"), &config);
/// ```
pub fn new(p: &nn::Path, config: &BertConfig) -> BertForSequenceClassification {
let bert = BertModel::new(&(p / "bert"), config);
let dropout = Dropout::new(config.hidden_dropout_prob);
@ -242,6 +468,68 @@ impl BertForSequenceClassification {
BertForSequenceClassification { bert, dropout, classifier }
/// Forward pass through the model
/// # Arguments
/// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
/// * `mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked positions have value 0, non-masked positions have value 1. If None, set to 1.
/// * `token_type_ids` - Optional segment id of shape (*batch size*, *sequence_length*). Convention is value of 0 for the first sentence (incl. *[SEP]*) and 1 for the second sentence. If None, set to 0.
/// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented from 0.
/// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
/// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
/// # Returns
/// * `labels` - `Tensor` of shape (*batch size*, *num_labels*)
/// * `hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// * `attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
/// # Example
/// ```no_run
///# use rust_bert::bert::{BertModel, BertConfig, BertForSequenceClassification};
///# use tch::{nn, Device, Tensor, no_grad};
///# use rust_bert::Config;
///# use std::path::Path;
///# use rust_tokenizers::{BertTokenizer, TruncationStrategy, Tokenizer};
///# let config_path = Path::new("path/to/config.json");
///# let vocab_path = Path::new("path/to/vocab.txt");
///# let device = Device::Cpu;
///# let vs = nn::VarStore::new(device);
///# let tokenizer: BertTokenizer = BertTokenizer::from_file(vocab_path.to_str().unwrap(), true);
///# let config = BertConfig::from_file(config_path);
///# let bert_model = BertForSequenceClassification::new(&vs.root(), &config);
/// let input = ["First sentence to classify", "Second sentence to classify"];
/// let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
/// let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
/// let mut tokenized_input = tokenized_input
/// .iter()
/// // retrieve input ids from TokenizedInput
/// .map(|input| input.token_ids.clone())
/// // Padding inputs to same length
/// .map(|mut input| {
/// input.extend(vec![0; max_len - input.len()]);
/// input
/// })
/// // Map to Tensor
/// .map(|input|
/// Tensor::of_slice(&(input)))
/// .collect::<Vec<_>>();
/// let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
/// let (output, _, _) = no_grad(|| {
/// bert_model
/// .forward_t(Some(input_tensor),
/// None,
/// None,
/// None,
/// None,
/// false)
/// });
/// ```
pub fn forward_t(&self,
input_ids: Option<Tensor>,
mask: Option<Tensor>,

View File

@ -1,16 +1,47 @@
//! # BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding
//! # BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding (Devlin et al.)
//! Implementation of the BERT language model (Devlin, Chang, Lee, Toutanova, 2018).
//! The base model is implemented in the `bert::BertModel` struct. Several language model heads have also been implemented, including:
//! - Masked language model: `bert::BertForMaskedLM`
//! - Multiple choices: `bert::BertForMultipleChoice`
//! - Question answering: ``bert::BertForQuestionAnswering`
//! -
//! - Question answering: `bert::BertForQuestionAnswering`
//! - Sequence classification: `bert::BertForSequenceClassification`
//! - Token classification (e.g. NER, POS tagging): `bert::BertForTokenClassification`
//! # Model set-up and pre-trained weights loading
//! A full working example is provided in `examples/`, run with `cargo run --example bert`.
//! The example below illustrates a masked language model; the structure is similar for other models.
//! All models expect the following resources:
//! - Configuration file expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers)
//! - Model weights are expected to have a structure and parameter names following the [Transformers library](https://github.com/huggingface/transformers). A conversion using the Python utility scripts is required to convert the `.bin` weights to the `.ot` format.
//! - `BertTokenizer` using a `vocab.txt` vocabulary
//! # Quick Start
//! ```no_run
//!# fn main() -> failure::Fallible<()> {
//!# let mut home: PathBuf = dirs::home_dir().unwrap();
//!# home.push("rustbert");
//!# home.push("bert");
//!# let config_path = &home.as_path().join("config.json");
//!# let vocab_path = &home.as_path().join("vocab.txt");
//!# let weights_path = &home.as_path().join("model.ot");
//! use rust_tokenizers::BertTokenizer;
//! use tch::{nn, Device};
//!# use std::path::PathBuf;
//! use rust_bert::bert::{BertForMaskedLM, BertConfig};
//! use rust_bert::Config;
//! let device = Device::cuda_if_available();
//! let mut vs = nn::VarStore::new(device);
//! let tokenizer: BertTokenizer = BertTokenizer::from_file(vocab_path.to_str().unwrap(), true);
//! let config = BertConfig::from_file(config_path);
//! let bert_model = BertForMaskedLM::new(&vs.root(), &config);
//! vs.load(weights_path)?;
//!# Ok(())
//!# }
//! ```
mod bert;
@ -18,5 +49,5 @@ mod embeddings;
mod attention;
mod encoder;
pub use bert::{BertConfig, BertModel, BertForTokenClassification, BertForMultipleChoice, BertForMaskedLM, BertForSequenceClassification, BertForQuestionAnswering};
pub(crate) use embeddings::BertEmbedding;
pub use bert::{BertConfig, Activation, BertModel, BertForTokenClassification, BertForMultipleChoice, BertForMaskedLM, BertForSequenceClassification, BertForQuestionAnswering};
pub use embeddings::{BertEmbedding, BertEmbeddings};

View File

@ -1,4 +1,5 @@
mod embeddings;
mod roberta;
pub use roberta::{RobertaForMaskedLM, RobertaForMultipleChoice, RobertaForTokenClassification, RobertaForQuestionAnswering, RobertaForSequenceClassification};
pub use roberta::{RobertaForMaskedLM, RobertaForMultipleChoice, RobertaForTokenClassification, RobertaForQuestionAnswering, RobertaForSequenceClassification};
pub use embeddings::RobertaEmbeddings;