initial docs for XLNet

This commit is contained in:
Guillaume B 2020-10-01 19:13:04 +02:00
parent b0d04da8f5
commit 6dde3241a2
4 changed files with 141 additions and 2 deletions

View File

@@ -39,7 +39,7 @@ fn main() -> anyhow::Result<()> {
config_resource,
vocab_resource,
merges_resource,
- max_length: 56,
+ max_length: 32,
do_sample: true,
num_beams: 3,
temperature: 1.0,

View File

@@ -246,7 +246,7 @@ impl T5Model {
/// * `decoder_attention_mask` - Optional attention mask of shape (*batch size*, *target_sequence_length*) for the decoder positions. Positions with a mask with value 0 will be masked.
/// * `input_embeds` - Optional input tensor of shape (*batch size*, *source_sequence_length*, *embeddings dimension*). This or `input_ids` must be provided.
/// * `decoder_input_embeds` - Optional input tensor of shape (*batch size*, *target_sequence_length*, *embeddings dimension*). This or `decoder_input_ids` must be provided.
- /// * `old_layer_states` - Optional vector of length `num_layers` containing tuples of optional `LayerStates` containing th elast calculated key and value pairs for the decoder. This avoids recomputing attention weights at past positions and speeds up decoding.
+ /// * `old_layer_states` - Optional vector of length `num_layers` containing tuples of optional `LayerStates` containing the last calculated key and value pairs for the decoder. This avoids recomputing attention weights at past positions and speeds up decoding.
/// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
///
/// # Returns

View File

@@ -1,3 +1,60 @@
//! # XLNet (Generalized Autoregressive Pretraining for Language Understanding)
//!
//! Implementation of the XLNet language model ([Generalized Autoregressive Pretraining for Language Understanding](https://arxiv.org/abs/1906.08237) Yang, Dai, Yang, Carbonell, Salakhutdinov, Le, 2019).
//! The base model is implemented in the `xlnet::XLNetModel` struct. Several language model heads have also been implemented, including:
//! - Language generation: `xlnet::XLNetLMHeadModel` implementing the common `generation::LMHeadModel` trait shared between the models used for generation (see `pipelines` for more information)
//! - Multiple choice: `xlnet::XLNetForMultipleChoice`
//! - Question answering: `xlnet::XLNetForQuestionAnswering`
//! - Sequence classification: `xlnet::XLNetForSequenceClassification`
//! - Token classification (e.g. NER, POS tagging): `xlnet::XLNetForTokenClassification`
//!
//! # Model set-up and pre-trained weights loading
//!
//! A full working example (generation) is provided in `examples/generation_xlnet`, run with `cargo run --example generation_xlnet`.
//! All models expect the following resources:
//! - Configuration file expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers)
//! - Model weights are expected to have a structure and parameter names following the [Transformers library](https://github.com/huggingface/transformers). A conversion using the Python utility scripts is required to convert the `.bin` weights to the `.ot` format.
//! - `XLNetTokenizer` using a `spiece.model` sentence piece model
//!
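//! These resources can also be provided as local files. A minimal sketch (the paths are placeholders) using the library's `LocalResource`:
//!
//! ```no_run
//! use std::path::PathBuf;
//! use rust_bert::resources::{LocalResource, Resource};
//!
//! let config_resource = Resource::Local(LocalResource {
//!     local_path: PathBuf::from("path/to/config.json"),
//! });
//! let vocab_resource = Resource::Local(LocalResource {
//!     local_path: PathBuf::from("path/to/spiece.model"),
//! });
//! let weights_resource = Resource::Local(LocalResource {
//!     local_path: PathBuf::from("path/to/model.ot"),
//! });
//! ```
//!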
//! Pretrained models are available and can be downloaded using RemoteResources.
//!
//! ```no_run
//! # fn main() -> anyhow::Result<()> {
//! use rust_bert::resources::{Resource, RemoteResource};
//! use rust_bert::xlnet::{XLNetConfigResources, XLNetVocabResources, XLNetModelResources};
//! use rust_bert::pipelines::generation::{GenerateConfig, XLNetGenerator, LanguageGenerator};
//! let config_resource = Resource::Remote(RemoteResource::from_pretrained(
//! XLNetConfigResources::XLNET_BASE_CASED,
//! ));
//! let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
//! XLNetVocabResources::XLNET_BASE_CASED,
//! ));
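//! // XLNet uses a sentencepiece model rather than BPE merges; the merges resource below
//! // reuses the vocab resource, presumably as a placeholder to satisfy the `GenerateConfig` fields.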
//! let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
//! XLNetVocabResources::XLNET_BASE_CASED,
//! ));
//! let model_resource = Resource::Remote(RemoteResource::from_pretrained(
//! XLNetModelResources::XLNET_BASE_CASED,
//! ));
//! let generate_config = GenerateConfig {
//! model_resource,
//! config_resource,
//! vocab_resource,
//! merges_resource,
//! max_length: 56,
//! do_sample: true,
//! num_beams: 3,
//! temperature: 1.0,
//! num_return_sequences: 1,
//! ..Default::default()
//! };
//! let model = XLNetGenerator::new(generate_config)?;
//! let input_context = "Once upon a time,";
//! let output = model.generate(Some(vec![input_context]), None);
//!
//! # Ok(())
//! # }
//! ```
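//!
//! The task-specific heads listed above are constructed in the same fashion; a minimal sketch for
//! the question answering head, assuming it follows the same `new(p, config)` pattern as `XLNetModel`:
//!
//! ```no_run
//! use std::path::Path;
//! use tch::{nn, Device};
//! use rust_bert::Config;
//! use rust_bert::xlnet::{XLNetConfig, XLNetForQuestionAnswering};
//!
//! let device = Device::Cpu;
//! let vs = nn::VarStore::new(device);
//! let config = XLNetConfig::from_file(Path::new("path/to/config.json"));
//! let qa_model = XLNetForQuestionAnswering::new(&vs.root(), &config);
//! ```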
mod attention;
mod encoder;
mod xlnet;

View File

@@ -110,6 +110,12 @@ pub struct XLNetConfig {
impl Config<XLNetConfig> for XLNetConfig {}
/// # XLNet Base model
/// Base architecture for XLNet models. Task-specific models will be built from this common base model.
/// It is made of the following blocks:
/// - `word_embeddings`: Word embeddings
/// - `mask_emb`: Embedding for the masked tokens (`g` states)
/// - `layers`: Vector of `XLNetLayer`. Each layer is made of a self-attention layer operating on the visible and hidden states, and a post-attention layer
pub struct XLNetModel {
mem_len: Option<i64>,
reuse_len: Option<i64>,
@@ -128,6 +134,27 @@ pub struct XLNetModel {
}
impl XLNetModel {
/// Build a new `XLNetModel`
///
/// # Arguments
///
/// * `p` - Variable store path for the root of the XLNet model
/// * `config` - `XLNetConfig` object defining the model architecture
///
/// # Example
///
/// ```no_run
/// use rust_bert::Config;
/// use std::path::Path;
/// use tch::{nn, Device};
/// use rust_bert::xlnet::{XLNetConfig, XLNetModel};
///
/// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu;
/// let p = nn::VarStore::new(device);
/// let config = XLNetConfig::from_file(config_path);
/// let xlnet_model = XLNetModel::new(&p.root(), &config);
/// ```
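///
/// The constructor only creates the variables; pre-trained weights are typically loaded into the
/// variable store afterwards (e.g. with `tch`'s `VarStore::load` pointing at the converted
/// `model.ot` file), which requires the variable paths to match the converted parameter names.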
pub fn new<'p, P>(p: P, config: &XLNetConfig) -> XLNetModel
where
P: Borrow<nn::Path<'p>>,
@@ -292,6 +319,61 @@ impl XLNetModel {
}
}
/// Forward pass through the model
///
/// # Arguments
///
/// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). This or `input_embeds` must be provided.
/// * `attention_mask` - Optional attention mask of shape (*batch size*, *sequence_length*). Positions with a mask with value 0 will be masked.
/// * `perm_mask` - Optional tensor of shape (*batch size*, *sequence_length*, *sequence_length*). Mask to indicate the attention pattern for each input token (only used for pre-training over permutations, rather than simple token masking).
/// * `target_mapping` - Optional tensor of shape (*batch size*, *num_tokens*, *sequence_length*) indicating the position of the masked words to predict.
/// * `token_type_ids` - Optional tensor (*batch size*, *sequence_length*) indicating the sentence ID of the token (0: first sentence, 1: second sentence).
/// * `input_embeds` - Optional input tensor of shape (*batch size*, *sequence_length*, *embeddings dimension*). This or `input_ids` must be provided.
/// * `old_layer_states` - Optional vector of length `num_layers` containing optional `LayerStates` containing the last calculated key and value pairs for the attention layers. This avoids recomputing attention weights at past positions and speeds up decoding.
/// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
///
/// # Returns
///
/// * `XLNetModelOutput` containing:
/// - `hidden_state` - `Tensor` of shape (*batch size*, *sequence_length*, *hidden_size*) representing the activations of the last hidden state
/// - `next_cache` - `Option<Vec<Option<LayerState>>>` of length *n_layer* containing the past keys and values for the attention layers with shape (*past_sequence_length*, *batch size*, *hidden_size*)
/// - `all_hidden_states` - `Option<Vec<(Tensor, Option<Tensor>)>>` of length *n_layer* containing the hidden states (and optional query stream states) for all intermediate layers, with shape (*batch size*, *sequence_length*, *hidden_size*)
/// - `all_attentions` - `Option<Vec<(Tensor, Option<Tensor>)>>` of length *n_layer* containing the attention weights (and optional query stream attention weights) for all intermediate layers
///
/// # Example
///
/// ```no_run
/// # use tch::{nn, Device, Tensor, no_grad, Kind};
/// # use rust_bert::Config;
/// # use std::path::Path;
/// # use tch::kind::Kind::{Int64, Double};
/// use rust_bert::xlnet::{XLNetConfig, XLNetModel};
/// # let config_path = Path::new("path/to/config.json");
/// # let vocab_path = Path::new("path/to/vocab.txt");
/// # let device = Device::Cpu;
/// # let vs = nn::VarStore::new(device);
/// # let config = XLNetConfig::from_file(config_path);
/// # let xlnet_model: XLNetModel = XLNetModel::new(&vs.root(), &config);
/// let (batch_size, sequence_length) = (64, 128);
/// // `Tensor::rand` only supports floating point kinds, so sample integer token ids instead
/// let input_tensor =
///     Tensor::randint(config.vocab_size, &[batch_size, sequence_length], (Int64, device));
/// let attention_mask = Tensor::ones(&[batch_size, sequence_length], (Int64, device));
/// let target_mapping = Tensor::zeros(&[batch_size, 1, sequence_length], (Kind::Float, device));
/// // predict the token at position 3
/// let _ = target_mapping.narrow(2, 3, 1).fill_(1.0);
///
/// let model_output = no_grad(|| {
///     xlnet_model.forward_t(
///         Some(&input_tensor),
///         Some(&attention_mask),
///         None, // perm_mask
///         Some(&target_mapping),
///         None, // token_type_ids
///         None, // input_embeds
///         None, // old_layer_states
///         false, // train
///     )
/// });
/// ```
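///
/// A `perm_mask` for permutation pre-training can be built the same way. A minimal sketch, assuming
/// the Transformers convention that a value of 1 at position (i, j) prevents token i from attending
/// to token j (here, hiding the last token from all other positions):
///
/// ```no_run
/// # use tch::{Tensor, Kind, Device};
/// let (batch_size, sequence_length) = (64, 128);
/// let device = Device::Cpu;
/// let perm_mask = Tensor::zeros(
///     &[batch_size, sequence_length, sequence_length],
///     (Kind::Float, device),
/// );
/// // prevent all tokens from attending to the last position
/// let _ = perm_mask.narrow(2, sequence_length - 1, 1).fill_(1.0);
/// ```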
pub fn forward_t(
&self,
input_ids: Option<&Tensor>,