mirror of
https://github.com/guillaume-be/rust-bert.git
synced 2024-11-09 17:05:51 +03:00
Completed documentation for GPT2
This commit is contained in:
parent
c1e0823ee1
commit
2263b9fb86
@ -152,7 +152,7 @@ impl<T: BertEmbedding> BertModel<T> {
|
||||
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
||||
/// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true);
|
||||
///
|
||||
/// let (output, _, _, _) = no_grad(|| {
|
||||
/// let (output, pooled_output, all_hidden_states, all_attentions) = no_grad(|| {
|
||||
/// bert_model
|
||||
/// .forward_t(Some(input_tensor),
|
||||
/// Some(mask),
|
||||
@ -367,7 +367,7 @@ impl BertForMaskedLM {
|
||||
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
||||
/// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true);
|
||||
///
|
||||
/// let (output, _, _) = no_grad(|| {
|
||||
/// let (output, all_hidden_states, all_attentions) = no_grad(|| {
|
||||
/// bert_model
|
||||
/// .forward_t(Some(input_tensor),
|
||||
/// Some(mask),
|
||||
@ -477,7 +477,7 @@ impl BertForSequenceClassification {
|
||||
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
||||
/// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true);
|
||||
///
|
||||
/// let (labels, _, _) = no_grad(|| {
|
||||
/// let (labels, all_hidden_states, all_attentions) = no_grad(|| {
|
||||
/// bert_model
|
||||
/// .forward_t(Some(input_tensor),
|
||||
/// Some(mask),
|
||||
@ -583,7 +583,7 @@ impl BertForMultipleChoice {
|
||||
/// let token_type_ids = Tensor::zeros(&[num_choices, sequence_length], (Int64, device));
|
||||
/// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[num_choices, sequence_length], true);
|
||||
///
|
||||
/// let (choices, _, _) = no_grad(|| {
|
||||
/// let (choices, all_hidden_states, all_attentions) = no_grad(|| {
|
||||
/// bert_model
|
||||
/// .forward_t(input_tensor,
|
||||
/// Some(mask),
|
||||
@ -705,7 +705,7 @@ impl BertForTokenClassification {
|
||||
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
||||
/// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true);
|
||||
///
|
||||
/// let (token_labels, _, _) = no_grad(|| {
|
||||
/// let (token_labels, all_hidden_states, all_attentions) = no_grad(|| {
|
||||
/// bert_model
|
||||
/// .forward_t(Some(input_tensor),
|
||||
/// Some(mask),
|
||||
@ -812,7 +812,7 @@ impl BertForQuestionAnswering {
|
||||
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
||||
/// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true);
|
||||
///
|
||||
/// let (start_scores, end_scores, _, _) = no_grad(|| {
|
||||
/// let (start_scores, end_scores, all_hidden_states, all_attentions) = no_grad(|| {
|
||||
/// bert_model
|
||||
/// .forward_t(Some(input_tensor),
|
||||
/// Some(mask),
|
||||
|
@ -78,7 +78,7 @@ impl DistilBertModel {
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `p` - Variable store path for the root of the DistilBERT model
|
||||
/// * `config` - `DistilBertConfig` object defining the model architecture and decoder status
|
||||
/// * `config` - `DistilBertConfig` object defining the model architecture
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
@ -92,7 +92,7 @@ impl DistilBertModel {
|
||||
/// let device = Device::Cpu;
|
||||
/// let p = nn::VarStore::new(device);
|
||||
/// let config = DistilBertConfig::from_file(config_path);
|
||||
/// let distil_bert: DistilBertModel = DistilBertModel::new(&(&p.root() / "bert"), &config);
|
||||
/// let distil_bert: DistilBertModel = DistilBertModel::new(&(&p.root() / "distilbert"), &config);
|
||||
/// ```
|
||||
///
|
||||
pub fn new(p: &nn::Path, config: &DistilBertConfig) -> DistilBertModel {
|
||||
@ -135,7 +135,7 @@ impl DistilBertModel {
|
||||
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
|
||||
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
||||
///
|
||||
/// let (output, _, _) = no_grad(|| {
|
||||
/// let (output, all_hidden_states, all_attentions) = no_grad(|| {
|
||||
/// distilbert_model
|
||||
/// .forward_t(Some(input_tensor),
|
||||
/// Some(mask),
|
||||
@ -183,7 +183,7 @@ impl DistilBertModelClassifier {
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `p` - Variable store path for the root of the DistilBertModelClassifier model
|
||||
/// * `config` - `DistilBertConfig` object defining the model architecture and decoder status
|
||||
/// * `config` - `DistilBertConfig` object defining the model architecture
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
@ -197,7 +197,7 @@ impl DistilBertModelClassifier {
|
||||
/// let device = Device::Cpu;
|
||||
/// let p = nn::VarStore::new(device);
|
||||
/// let config = DistilBertConfig::from_file(config_path);
|
||||
/// let distil_bert: DistilBertModelClassifier = DistilBertModelClassifier::new(&(&p.root() / "bert"), &config);
|
||||
/// let distil_bert: DistilBertModelClassifier = DistilBertModelClassifier::new(&(&p.root() / "distilbert"), &config);
|
||||
/// ```
|
||||
///
|
||||
pub fn new(p: &nn::Path, config: &DistilBertConfig) -> DistilBertModelClassifier {
|
||||
@ -242,7 +242,7 @@ impl DistilBertModelClassifier {
|
||||
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
|
||||
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
||||
///
|
||||
/// let (output, _, _) = no_grad(|| {
|
||||
/// let (output, all_hidden_states, all_attentions) = no_grad(|| {
|
||||
/// distilbert_model
|
||||
/// .forward_t(Some(input_tensor),
|
||||
/// Some(mask),
|
||||
@ -291,7 +291,7 @@ impl DistilBertModelMaskedLM {
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `p` - Variable store path for the root of the DistilBertModelMaskedLM model
|
||||
/// * `config` - `DistilBertConfig` object defining the model architecture and decoder status
|
||||
/// * `config` - `DistilBertConfig` object defining the model architecture
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
@ -305,7 +305,7 @@ impl DistilBertModelMaskedLM {
|
||||
/// let device = Device::Cpu;
|
||||
/// let p = nn::VarStore::new(device);
|
||||
/// let config = DistilBertConfig::from_file(config_path);
|
||||
/// let distil_bert = DistilBertModelMaskedLM::new(&(&p.root() / "bert"), &config);
|
||||
/// let distil_bert = DistilBertModelMaskedLM::new(&(&p.root() / "distilbert"), &config);
|
||||
/// ```
|
||||
///
|
||||
pub fn new(p: &nn::Path, config: &DistilBertConfig) -> DistilBertModelMaskedLM {
|
||||
@ -351,7 +351,7 @@ impl DistilBertModelMaskedLM {
|
||||
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
|
||||
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
||||
///
|
||||
/// let (output, _, _) = no_grad(|| {
|
||||
/// let (output, all_hidden_states, all_attentions) = no_grad(|| {
|
||||
/// distilbert_model
|
||||
/// .forward_t(Some(input_tensor),
|
||||
/// Some(mask),
|
||||
@ -397,7 +397,7 @@ impl DistilBertForQuestionAnswering {
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `p` - Variable store path for the root of the DistilBertForQuestionAnswering model
|
||||
/// * `config` - `DistilBertConfig` object defining the model architecture and decoder status
|
||||
/// * `config` - `DistilBertConfig` object defining the model architecture
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
@ -411,7 +411,7 @@ impl DistilBertForQuestionAnswering {
|
||||
/// let device = Device::Cpu;
|
||||
/// let p = nn::VarStore::new(device);
|
||||
/// let config = DistilBertConfig::from_file(config_path);
|
||||
/// let distil_bert = DistilBertForQuestionAnswering::new(&(&p.root() / "bert"), &config);
|
||||
/// let distil_bert = DistilBertForQuestionAnswering::new(&(&p.root() / "distilbert"), &config);
|
||||
/// ```
|
||||
///
|
||||
pub fn new(p: &nn::Path, config: &DistilBertConfig) -> DistilBertForQuestionAnswering {
|
||||
@ -457,7 +457,7 @@ impl DistilBertForQuestionAnswering {
|
||||
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
|
||||
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
||||
///
|
||||
/// let (start_scores, end_score, _, _) = no_grad(|| {
|
||||
/// let (start_scores, end_score, all_hidden_states, all_attentions) = no_grad(|| {
|
||||
/// distilbert_model
|
||||
/// .forward_t(Some(input_tensor),
|
||||
/// Some(mask),
|
||||
@ -510,7 +510,7 @@ impl DistilBertForTokenClassification {
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `p` - Variable store path for the root of the DistilBertForTokenClassification model
|
||||
/// * `config` - `DistilBertConfig` object defining the model architecture and decoder status
|
||||
/// * `config` - `DistilBertConfig` object defining the model architecture
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
@ -524,7 +524,7 @@ impl DistilBertForTokenClassification {
|
||||
/// let device = Device::Cpu;
|
||||
/// let p = nn::VarStore::new(device);
|
||||
/// let config = DistilBertConfig::from_file(config_path);
|
||||
/// let distil_bert = DistilBertForTokenClassification::new(&(&p.root() / "bert"), &config);
|
||||
/// let distil_bert = DistilBertForTokenClassification::new(&(&p.root() / "distilbert"), &config);
|
||||
/// ```
|
||||
///
|
||||
pub fn new(p: &nn::Path, config: &DistilBertConfig) -> DistilBertForTokenClassification {
|
||||
@ -568,7 +568,7 @@ impl DistilBertForTokenClassification {
|
||||
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
|
||||
/// let mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
||||
///
|
||||
/// let (output, _, _) = no_grad(|| {
|
||||
/// let (output, all_hidden_states, all_attentions) = no_grad(|| {
|
||||
/// distilbert_model
|
||||
/// .forward_t(Some(input_tensor),
|
||||
/// Some(mask),
|
||||
|
@ -1,6 +1,6 @@
|
||||
//! # DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter (Sanh et al.)
|
||||
//!
|
||||
//! Implementation of the DilstilBERT language model ([https://arxiv.org/abs/1910.01108](https://arxiv.org/abs/1910.01108) Sanh, Debut, Chaumond, Wolf, 2019).
|
||||
//! Implementation of the DistilBERT language model ([https://arxiv.org/abs/1910.01108](https://arxiv.org/abs/1910.01108) Sanh, Debut, Chaumond, Wolf, 2019).
|
||||
//! The base model is implemented in the `distilbert::DistilBertModel` struct. Several language model heads have also been implemented, including:
|
||||
//! - Masked language model: `distilbert::DistilBertForMaskedLM`
|
||||
//! - Question answering: `distilbert::DistilBertForQuestionAnswering`
|
||||
|
180
src/gpt2/gpt2.rs
180
src/gpt2/gpt2.rs
@ -24,13 +24,20 @@ use crate::Config;
|
||||
|
||||
#[allow(non_camel_case_types)]
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
/// # Activation function used in the fully connected layers of the transformer block
|
||||
pub enum GptActivation {
|
||||
/// Gaussian Error Linear Unit ([Hendrycks et al., 2016,](https://arxiv.org/abs/1606.08415))
|
||||
gelu,
|
||||
/// Rectified Linear Unit
|
||||
relu,
|
||||
/// Swish: a Self-Gated Activation Function ([Ramachandran et al., 2017](https://arxiv.org/pdf/1710.05941v1.pdf))
|
||||
swish,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
/// # GPT2 model configuration
|
||||
/// Defines the GPT2 model architecture (e.g. number of layers, hidden layer size, vocab size...).
|
||||
/// Shared between GPT and GPT2 models
|
||||
pub struct Gpt2Config {
|
||||
pub attn_pdrop: Option<f64>,
|
||||
pub embd_pdrop: Option<f64>,
|
||||
@ -53,6 +60,15 @@ pub struct Gpt2Config {
|
||||
|
||||
impl Config<Gpt2Config> for Gpt2Config {}
|
||||
|
||||
/// # GPT2 Base model
|
||||
/// Base architecture for GPT2 model. Usually complemented with a task-specific head, such as a language model head.
|
||||
/// It is made of the following blocks:
|
||||
/// - `wte`: `token` embeddings
|
||||
/// - `wpe`: `position` embeddings
|
||||
/// - `h`: Encoder (transformer) made of a vector of layers. Each layer is made of a multi-head attention layer, layer-normalization layers and a MLP made of linear layers.
|
||||
/// - `output_past`: flag indicating if the model should return a past state. This can be fed back to the model to improve the quality of text generated.
|
||||
/// - `output_hidden_states`: flag indicating if the model should return all hidden states (as opposed to only the last layer)
|
||||
/// - `output_attentions`: flag indicating if the model should return activation weights
|
||||
pub struct Gpt2Model {
|
||||
wte: nn::Embedding,
|
||||
wpe: nn::Embedding,
|
||||
@ -65,6 +81,28 @@ pub struct Gpt2Model {
|
||||
}
|
||||
|
||||
impl Gpt2Model {
|
||||
/// Build a new `Gpt2Model`
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `p` - Variable store path for the root of the BERT model
|
||||
/// * `config` - `Gpt2Config` object defining the model architecture
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```no_run
|
||||
/// use tch::{nn, Device};
|
||||
/// use rust_bert::Config;
|
||||
/// use std::path::Path;
|
||||
/// use rust_bert::gpt2::{Gpt2Config, Gpt2Model};
|
||||
///
|
||||
/// let config_path = Path::new("path/to/config.json");
|
||||
/// let device = Device::Cpu;
|
||||
/// let p = nn::VarStore::new(device);
|
||||
/// let config = Gpt2Config::from_file(config_path);
|
||||
/// let gpt2: Gpt2Model = Gpt2Model::new(&(&p.root() / "gpt2"), &config);
|
||||
/// ```
|
||||
///
|
||||
pub fn new(p: &nn::Path, config: &Gpt2Config) -> Gpt2Model {
|
||||
let p = &(p / "transformer");
|
||||
let wte = embedding(&(p / "wte"), config.vocab_size, config.n_embd, Default::default());
|
||||
@ -97,6 +135,62 @@ impl Gpt2Model {
|
||||
Gpt2Model { wte, wpe, drop, ln_f, h, output_past, output_hidden_states, output_attentions }
|
||||
}
|
||||
|
||||
/// Forward pass through the model
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
|
||||
/// * `layer_past` - Optional vector of length *n_layer* containing the past keys and values of each layer of shape (*2*, *batch size*, *number of heads*, *past_sequence_length*, *hidden size per head*). When provided, these are concatenated with the current input keys and values.
|
||||
/// * `attention_mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1
|
||||
/// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
|
||||
/// * `token_type_ids` - Optional token type ids used to indicate the portion of the input the token belongs to. If not None, token type embeddings will be added to the token and position embeddings.
|
||||
/// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented starting from the length of the past input.
|
||||
/// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// * `output` - `Tensor` of shape (*batch size*, *sequence_length*, *hidden_size*) representing the activations of the last hidden state
|
||||
/// * `past` - `Option<Vec<Tensor>>` of length *n_layer* containing the past keys and values of each layer of shape (*2*, *batch size*, *number of heads*, *past_sequence_length*, *hidden size per head*)
|
||||
/// * `hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
|
||||
/// * `attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```no_run
|
||||
///# use tch::{nn, Device, Tensor, no_grad};
|
||||
///# use rust_bert::Config;
|
||||
///# use std::path::Path;
|
||||
///# use tch::kind::Kind::{Int64, Double};
|
||||
/// use rust_bert::gpt2::{Gpt2Model, Gpt2Config};
|
||||
///# let config_path = Path::new("path/to/config.json");
|
||||
///# let vocab_path = Path::new("path/to/vocab.txt");
|
||||
///# let device = Device::Cpu;
|
||||
///# let vs = nn::VarStore::new(device);
|
||||
///# let config = Gpt2Config::from_file(config_path);
|
||||
///# let gpt2_model: Gpt2Model = Gpt2Model::new(&vs.root(), &config);
|
||||
/// let (batch_size, sequence_length, past_sequence_length) = (64, 128, 56);
|
||||
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
|
||||
/// let mut past: Vec<Tensor> = Vec::with_capacity(config.n_layer as usize);
|
||||
/// for _ in 0..config.n_layer as usize {
|
||||
/// past.push(Tensor::rand(&[2, batch_size, config.n_head, past_sequence_length, config.n_embd / config.n_head], (Double, device)))
|
||||
/// }
|
||||
/// let attention_mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
||||
/// let token_type_ids = Tensor::ones(&[batch_size, sequence_length], (Int64, device));
|
||||
/// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true);
|
||||
///
|
||||
/// let (output, past, hidden_states, attentions) = no_grad(|| {
|
||||
/// gpt2_model
|
||||
/// .forward_t(&Some(input_tensor),
|
||||
/// &Some(past),
|
||||
/// &Some(attention_mask),
|
||||
/// &Some(token_type_ids),
|
||||
/// &Some(position_ids),
|
||||
/// &None,
|
||||
/// false).unwrap()
|
||||
/// });
|
||||
///
|
||||
/// ```
|
||||
///
|
||||
pub fn forward_t(&self,
|
||||
input_ids: &Option<Tensor>,
|
||||
layer_past: &Option<Vec<Tensor>>,
|
||||
@ -182,13 +276,39 @@ impl Gpt2Model {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/// # GPT2 Language Modeling head
|
||||
/// GPT2 model with a decoding head (linear layer without bias). The weights of the linear layer are tied to the word embeddings
|
||||
/// It is made of the following blocks:
|
||||
/// - `transformer`: Base Gpt2Model
|
||||
/// - `lm_head`: Linear layer without bias tied to the weights of the token id embeddings
|
||||
pub struct GPT2LMHeadModel {
|
||||
transformer: Gpt2Model,
|
||||
lm_head: LinearNoBias,
|
||||
}
|
||||
|
||||
impl GPT2LMHeadModel {
|
||||
/// Build a new `GPT2LMHeadModel`
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `p` - Variable store path for the root of the BERT model
|
||||
/// * `config` - `Gpt2Config` object defining the model architecture
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```no_run
|
||||
/// use tch::{nn, Device};
|
||||
/// use rust_bert::Config;
|
||||
/// use std::path::Path;
|
||||
/// use rust_bert::gpt2::{Gpt2Config, GPT2LMHeadModel};
|
||||
///
|
||||
/// let config_path = Path::new("path/to/config.json");
|
||||
/// let device = Device::Cpu;
|
||||
/// let p = nn::VarStore::new(device);
|
||||
/// let config = Gpt2Config::from_file(config_path);
|
||||
/// let gpt2: GPT2LMHeadModel = GPT2LMHeadModel::new(&(&p.root() / "gpt2"), &config);
|
||||
/// ```
|
||||
///
|
||||
pub fn new(p: &nn::Path, config: &Gpt2Config) -> GPT2LMHeadModel {
|
||||
let transformer = Gpt2Model::new(&p, config);
|
||||
let lm_head = linear_no_bias(&(p / "lm_head"), config.n_embd, config.vocab_size, Default::default());
|
||||
@ -196,6 +316,8 @@ impl GPT2LMHeadModel {
|
||||
}
|
||||
}
|
||||
|
||||
/// # Language Model trait
|
||||
/// Shared trait between language generation models (e.g. GPT2 and GPT) used in language generation pipelines.
|
||||
pub trait LMHeadModel {
|
||||
fn forward_t(&self,
|
||||
input_ids: &Option<Tensor>,
|
||||
@ -208,6 +330,62 @@ pub trait LMHeadModel {
|
||||
}
|
||||
|
||||
impl LMHeadModel for GPT2LMHeadModel {
|
||||
/// Forward pass through the model
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
|
||||
/// * `layer_past` - Optional vector of size *n_layer* containing the past keys and values of each layer of shape (*2*, *batch size*, *number of heads*, *past_sequence_length*, *hidden size per head*). When provided, these are concatenated with the current input keys and values.
|
||||
/// * `attention_mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked position have value 0, non-masked value 1. If None set to 1
|
||||
/// * `input_embeds` - Optional pre-computed input embeddings of shape (*batch size*, *sequence_length*, *hidden_size*). If None, input ids must be provided (see `input_ids`)
|
||||
/// * `token_type_ids` - Optional token type ids used to indicate the portion of the input the token belongs to. If not None, token type embeddings will be added to the token and position embeddings.
|
||||
/// * `position_ids` - Optional position ids of shape (*batch size*, *sequence_length*). If None, will be incremented starting from the length of the past input.
|
||||
/// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// * `output` - `Tensor` of shape (*batch size*, *sequence_length*, *vocab_size*) representing the logits for each vocab item and position
|
||||
/// * `past` - `Option<Vec<Tensor>>` of length *n_layer* containing the past keys and values of each layer of shape (*2*, *batch size*, *number of heads*, *past_sequence_length*, *hidden size per head*)
|
||||
/// * `hidden_states` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
|
||||
/// * `attentions` - `Option<Vec<Tensor>>` of length *num_hidden_layers* with shape (*batch size*, *sequence_length*, *hidden_size*)
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```no_run
|
||||
///# use tch::{nn, Device, Tensor, no_grad};
|
||||
///# use rust_bert::Config;
|
||||
///# use std::path::Path;
|
||||
///# use tch::kind::Kind::{Int64, Double};
|
||||
/// use rust_bert::gpt2::{Gpt2Model, Gpt2Config};
|
||||
///# let config_path = Path::new("path/to/config.json");
|
||||
///# let vocab_path = Path::new("path/to/vocab.txt");
|
||||
///# let device = Device::Cpu;
|
||||
///# let vs = nn::VarStore::new(device);
|
||||
///# let config = Gpt2Config::from_file(config_path);
|
||||
///# let gpt2_model: Gpt2Model = Gpt2Model::new(&vs.root(), &config);
|
||||
/// let (batch_size, sequence_length, past_sequence_length) = (64, 128, 56);
|
||||
/// let input_tensor = Tensor::rand(&[batch_size, sequence_length], (Int64, device));
|
||||
/// let mut past: Vec<Tensor> = Vec::with_capacity(config.n_layer as usize);
|
||||
/// for _ in 0..config.n_layer as usize {
|
||||
/// past.push(Tensor::rand(&[2, batch_size, config.n_head, past_sequence_length, config.n_embd / config.n_head], (Double, device)))
|
||||
/// }
|
||||
/// let attention_mask = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
||||
/// let token_type_ids = Tensor::ones(&[batch_size, sequence_length], (Int64, device));
|
||||
/// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true);
|
||||
///
|
||||
/// let (output, past, hidden_states, attentions) = no_grad(|| {
|
||||
/// gpt2_model
|
||||
/// .forward_t(&Some(input_tensor),
|
||||
/// &Some(past),
|
||||
/// &Some(attention_mask),
|
||||
/// &Some(token_type_ids),
|
||||
/// &Some(position_ids),
|
||||
/// &None,
|
||||
/// false).unwrap()
|
||||
/// });
|
||||
///
|
||||
/// ```
|
||||
///
|
||||
fn forward_t(&self,
|
||||
input_ids: &Option<Tensor>,
|
||||
layer_past: &Option<Vec<Tensor>>,
|
||||
|
@ -1,5 +1,46 @@
|
||||
//! # GPT2 (Radford et al.)
|
||||
//!
|
||||
//! Implementation of the GPT2 language model ([Language Models are Unsupervised Multitask Learners](https://d4mucfpksywv.cloudfront.net/better-language-models/language_models_are_unsupervised_multitask_learners.pdf) Radford, Wu, Child, Luan, Amodei, Sutskever 2019).
|
||||
//! The base model is implemented in the `gpt2::Gpt2Model` struct. The model also includes a language model head: `gpt2::GPT2LMHeadModel`
|
||||
//! implementing the common `gpt2::LMHeadModel` trait shared between the models used fro generation (see `pipelines` for more information).
|
||||
//!
|
||||
//! # Model set-up and pre-trained weights loading
|
||||
//!
|
||||
//! A full working example is provided in `examples/gpt2.rs`, run with `cargo run --example gpt2`.
|
||||
//! All models expect the following resources:
|
||||
//! - Configuration file expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers)
|
||||
//! - Model weights are expected to have a structure and parameter names following the [Transformers library](https://github.com/huggingface/transformers). A conversion using the Python utility scripts is required to convert the `.bin` weights to the `.ot` format.
|
||||
//! - `Gpt2Tokenizer` using a `vocab.txt` vocabulary and `merges.txt` 2-gram merges
|
||||
//!
|
||||
//! ```no_run
|
||||
//!# fn main() -> failure::Fallible<()> {
|
||||
//!#
|
||||
//!# let mut home: PathBuf = dirs::home_dir().unwrap();
|
||||
//!# home.push("rustbert");
|
||||
//!# home.push("gpt2");
|
||||
//!# let config_path = &home.as_path().join("config.json");
|
||||
//!# let vocab_path = &home.as_path().join("vocab.txt");
|
||||
//!# let merges_path = &home.as_path().join("merges.txt");
|
||||
//!# let weights_path = &home.as_path().join("model.ot");
|
||||
//! use rust_tokenizers::Gpt2Tokenizer;
|
||||
//! use tch::{nn, Device};
|
||||
//!# use std::path::PathBuf;
|
||||
//! use rust_bert::Config;
|
||||
//! use rust_bert::gpt2::{Gpt2Config, GPT2LMHeadModel};
|
||||
//!
|
||||
//! let device = Device::cuda_if_available();
|
||||
//! let mut vs = nn::VarStore::new(device);
|
||||
//! let tokenizer: Gpt2Tokenizer = Gpt2Tokenizer::from_file(vocab_path.to_str().unwrap(), merges_path.to_str().unwrap(), true);
|
||||
//! let config = Gpt2Config::from_file(config_path);
|
||||
//! let bert_model = GPT2LMHeadModel::new(&vs.root(), &config);
|
||||
//! vs.load(weights_path)?;
|
||||
//!
|
||||
//!# Ok(())
|
||||
//!# }
|
||||
//! ```
|
||||
|
||||
mod gpt2;
|
||||
pub(crate) mod attention;
|
||||
pub(crate) mod transformer;
|
||||
|
||||
pub use gpt2::{Gpt2Config, Gpt2Model, GPT2LMHeadModel, LMHeadModel};
|
||||
pub use gpt2::{Gpt2Config, Gpt2Model, GptActivation, GPT2LMHeadModel, LMHeadModel};
|
@ -73,7 +73,7 @@ impl RobertaForMaskedLM {
|
||||
/// let device = Device::Cpu;
|
||||
/// let p = nn::VarStore::new(device);
|
||||
/// let config = BertConfig::from_file(config_path);
|
||||
/// let roberta = RobertaForMaskedLM::new(&(&p.root() / "bert"), &config);
|
||||
/// let roberta = RobertaForMaskedLM::new(&(&p.root() / "roberta"), &config);
|
||||
/// ```
|
||||
///
|
||||
pub fn new(p: &nn::Path, config: &BertConfig) -> RobertaForMaskedLM {
|
||||
@ -123,7 +123,7 @@ impl RobertaForMaskedLM {
|
||||
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
||||
/// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true);
|
||||
///
|
||||
/// let (output, _, _) = no_grad(|| {
|
||||
/// let (output, all_hidden_states, all_attentions) = no_grad(|| {
|
||||
/// roberta_model
|
||||
/// .forward_t(Some(input_tensor),
|
||||
/// Some(mask),
|
||||
@ -212,7 +212,7 @@ impl RobertaForSequenceClassification {
|
||||
/// let device = Device::Cpu;
|
||||
/// let p = nn::VarStore::new(device);
|
||||
/// let config = BertConfig::from_file(config_path);
|
||||
/// let roberta = RobertaForSequenceClassification::new(&(&p.root() / "bert"), &config);
|
||||
/// let roberta = RobertaForSequenceClassification::new(&(&p.root() / "roberta"), &config);
|
||||
/// ```
|
||||
///
|
||||
pub fn new(p: &nn::Path, config: &BertConfig) -> RobertaForSequenceClassification {
|
||||
@ -260,7 +260,7 @@ impl RobertaForSequenceClassification {
|
||||
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
||||
/// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true);
|
||||
///
|
||||
/// let (labels, _, _) = no_grad(|| {
|
||||
/// let (labels, all_hidden_states, all_attentions) = no_grad(|| {
|
||||
/// roberta_model
|
||||
/// .forward_t(Some(input_tensor),
|
||||
/// Some(mask),
|
||||
@ -321,7 +321,7 @@ impl RobertaForMultipleChoice {
|
||||
/// let device = Device::Cpu;
|
||||
/// let p = nn::VarStore::new(device);
|
||||
/// let config = BertConfig::from_file(config_path);
|
||||
/// let roberta = RobertaForMultipleChoice::new(&(&p.root() / "bert"), &config);
|
||||
/// let roberta = RobertaForMultipleChoice::new(&(&p.root() / "roberta"), &config);
|
||||
/// ```
|
||||
///
|
||||
pub fn new(p: &nn::Path, config: &BertConfig) -> RobertaForMultipleChoice {
|
||||
@ -369,7 +369,7 @@ impl RobertaForMultipleChoice {
|
||||
/// let token_type_ids = Tensor::zeros(&[num_choices, sequence_length], (Int64, device));
|
||||
/// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[num_choices, sequence_length], true);
|
||||
///
|
||||
/// let (choices, _, _) = no_grad(|| {
|
||||
/// let (choices, all_hidden_states, all_attentions) = no_grad(|| {
|
||||
/// roberta_model
|
||||
/// .forward_t(input_tensor,
|
||||
/// Some(mask),
|
||||
@ -443,7 +443,7 @@ impl RobertaForTokenClassification {
|
||||
/// let device = Device::Cpu;
|
||||
/// let p = nn::VarStore::new(device);
|
||||
/// let config = BertConfig::from_file(config_path);
|
||||
/// let roberta = RobertaForTokenClassification::new(&(&p.root() / "bert"), &config);
|
||||
/// let roberta = RobertaForTokenClassification::new(&(&p.root() / "roberta"), &config);
|
||||
/// ```
|
||||
///
|
||||
pub fn new(p: &nn::Path, config: &BertConfig) -> RobertaForTokenClassification {
|
||||
@ -493,7 +493,7 @@ impl RobertaForTokenClassification {
|
||||
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
||||
/// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true);
|
||||
///
|
||||
/// let (token_labels, _, _) = no_grad(|| {
|
||||
/// let (token_labels, all_hidden_states, all_attentions) = no_grad(|| {
|
||||
/// roberta_model
|
||||
/// .forward_t(Some(input_tensor),
|
||||
/// Some(mask),
|
||||
@ -553,7 +553,7 @@ impl RobertaForQuestionAnswering {
|
||||
/// let device = Device::Cpu;
|
||||
/// let p = nn::VarStore::new(device);
|
||||
/// let config = BertConfig::from_file(config_path);
|
||||
/// let roberta = RobertaForQuestionAnswering::new(&(&p.root() / "bert"), &config);
|
||||
/// let roberta = RobertaForQuestionAnswering::new(&(&p.root() / "roberta"), &config);
|
||||
/// ```
|
||||
///
|
||||
pub fn new(p: &nn::Path, config: &BertConfig) -> RobertaForQuestionAnswering {
|
||||
@ -603,7 +603,7 @@ impl RobertaForQuestionAnswering {
|
||||
/// let token_type_ids = Tensor::zeros(&[batch_size, sequence_length], (Int64, device));
|
||||
/// let position_ids = Tensor::arange(sequence_length, (Int64, device)).expand(&[batch_size, sequence_length], true);
|
||||
///
|
||||
/// let (start_scores, end_scores, _, _) = no_grad(|| {
|
||||
/// let (start_scores, end_scores, all_hidden_states, all_attentions) = no_grad(|| {
|
||||
/// roberta_model
|
||||
/// .forward_t(Some(input_tensor),
|
||||
/// Some(mask),
|
||||
|
Loading…
Reference in New Issue
Block a user