LongT5 implementation (#333)

* LongT5 config implementation

* LongT5 WiP: utility functions 1

* LongT5 WiP: utility functions (2)

* LongT5 WiP: utility functions (3)

* LongT5 WiP: utility functions (4)

* made T5 FF activations generic, expose T5 modules to crate

* LongT5 local attention WIP

* LongT5 local attention

* LongT5 global attention WIP

* LongT5 global attention

* LongT5 attention modules (WIP)

* align LongT5 position bias with T5

* Addition of LongT5Block

* LongT5Stack WiP

* LongT5Stack implementation

* LongT5Model implementation

* LongT5ForConditionalGeneration implementation

* Addition of LongT5Generator, inclusion in pipelines

* LongT5 attention fixes

* Fix MIN/MAX dtype computation, mask for longt5

* Updated min/max and infinity computation across models

* GlobalTransient attention fixes

* Updated changelog, readme, tests, clippy
guillaume-be 2023-02-12 16:18:20 +00:00 committed by GitHub
parent 84561ec82b
commit d7e9c03694
23 changed files with 2444 additions and 76 deletions

View File

@ -130,6 +130,7 @@ jobs:
command: test
args: --package rust-bert
--test sentence_embeddings
--test longt5
convert-model:
name: Model conversion test

View File

@ -2,9 +2,15 @@
All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
## [Unreleased]
## Added
- Addition of the [LongT5](https://arxiv.org/abs/2112.07916) model architecture and pretrained weights.
## Changed
- Bumped the tokenizers dependency from 7.x to 8.x, exposing additional options for special token mapping and adding the NLLBTokenizer.
## Fixed
- MIN/MAX computation for float-like types (was set to infinity instead of the type's min/max value)
## [0.20.0] - 2023-01-21
## Added
- Addition of All-MiniLM-L6-V2 model weights

View File

@ -58,6 +58,7 @@ M2M100| | | |✅ | | | | |
Electra | |✅| | | | |✅| |
ALBERT |✅|✅|✅| | | |✅| ✅ |
T5 | | | |✅ |✅|✅| | ✅ |
LongT5 | | | |✅ |✅| | | |
XLNet|✅|✅|✅|✅ | | |✅| |
Reformer|✅| |✅|✅ | | |✅| |
ProphetNet| | | |✅ |✅ | | | |

View File

@ -16,7 +16,7 @@ use crate::bart::decoder::BartDecoder;
use crate::bart::encoder::BartEncoder;
use crate::common::activations::Activation;
use crate::common::dropout::Dropout;
use crate::common::kind::get_negative_infinity;
use crate::common::kind::get_min;
use crate::pipelines::common::{ModelType, TokenizerOption};
use crate::pipelines::generation_utils::private_generation_utils::{
PreparedInput, PrivateLanguageGenerator,
@ -273,7 +273,7 @@ pub(crate) fn _make_causal_mask(
let mut mask = Tensor::full(
&[target_length, target_length],
get_negative_infinity(dtype).unwrap(),
get_min(dtype).unwrap(),
(dtype, device),
);
let mask_cond = Tensor::arange(target_length, (dtype, device));
@ -311,10 +311,7 @@ pub(crate) fn _expand_mask(mask: &Tensor, target_length: Option<i64>, dtype: Kin
.expand(&[batch_size, 1, target_length, source_length], true)
.totype(dtype);
let inverted_mask: Tensor = 1 - expanded_mask;
inverted_mask.masked_fill(
&inverted_mask.to_kind(Kind::Bool),
get_negative_infinity(dtype).unwrap(),
)
inverted_mask.masked_fill(&inverted_mask.to_kind(Kind::Bool), get_min(dtype).unwrap())
}
pub(crate) fn _prepare_decoder_attention_mask(

View File

@ -38,3 +38,22 @@ pub(crate) fn get_negative_infinity(kind: Kind) -> Result<Scalar, RustBertError>
}
})
}
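/// Returns the minimum representable value for the given tensor `Kind`, used as a large negative
/// additive mask value in place of negative infinity.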
pub(crate) fn get_min(kind: Kind) -> Result<Scalar, RustBertError> {
Ok(match kind {
Kind::Uint8 => Scalar::int(u8::MIN.into()),
Kind::Int8 => Scalar::int(i8::MIN.into()),
Kind::Int16 => Scalar::int(i16::MIN.into()),
Kind::Int => Scalar::int(i32::MIN.into()),
Kind::Int64 => Scalar::int(i64::MIN),
Kind::Half => Scalar::float(half::f16::MIN.into()),
Kind::Float => Scalar::float(f32::MIN.into()),
Kind::BFloat16 => Scalar::float(half::bf16::MIN.into()),
Kind::Double => Scalar::float(f64::MIN),
_ => {
return Err(RustBertError::ValueError(format!(
"Type not supported: attempted to get min for {kind:?}",
)))
}
})
}

View File

@ -16,7 +16,7 @@ use crate::bert::{
use crate::common::activations::TensorFunction;
use crate::common::dropout::{Dropout, XDropout};
use crate::common::embeddings::get_shape_and_device_from_ids_embeddings_pair;
use crate::common::kind::get_negative_infinity;
use crate::common::kind::get_min;
use crate::deberta::embeddings::DebertaEmbeddings;
use crate::deberta::encoder::{DebertaEncoder, DebertaEncoderOutput};
use crate::{Activation, Config, RustBertError};
@ -264,7 +264,7 @@ impl Config for DebertaConfig {}
pub fn x_softmax(input: &Tensor, mask: &Tensor, dim: i64) -> Tensor {
let inverse_mask = ((1 - mask) as Tensor).to_kind(Kind::Bool);
input
.masked_fill(&inverse_mask, get_negative_infinity(input.kind()).unwrap())
.masked_fill(&inverse_mask, get_min(input.kind()).unwrap())
.softmax(dim, input.kind())
.masked_fill(&inverse_mask, 0.0)
}

View File

@ -68,6 +68,7 @@
//!Electra | |✅| | | | |✅| |
//!ALBERT |✅|✅|✅| | | |✅| ✅ |
//!T5 | | | |✅ |✅|✅| | ✅ |
//!LongT5 | | | |✅ |✅| | | |
//!XLNet|✅|✅|✅|✅ | | |✅| |
//!Reformer|✅| |✅|✅ | | |✅| |
//!ProphetNet| | | |✅ |✅ | | | |
@ -695,6 +696,8 @@
// These are used abundantly in this code
#![allow(clippy::assign_op_pattern, clippy::upper_case_acronyms)]
extern crate core;
pub mod albert;
pub mod bart;
pub mod bert;
@ -707,6 +710,7 @@ pub mod fnet;
pub mod gpt2;
pub mod gpt_neo;
pub mod longformer;
pub mod longt5;
pub mod m2m_100;
pub mod marian;
pub mod mbart;

807
src/longt5/attention.rs Normal file
View File

@ -0,0 +1,807 @@
// Copyright 2022 Google LLC., LongT5 Authors and HuggingFace Inc. team.
// Copyright 2022 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use crate::common::dropout::Dropout;
use crate::longt5::layer_norm::LongT5LayerNorm;
use crate::longt5::LongT5Config;
use crate::t5::{
get_relative_position_bucket, LayerState as T5layerState, T5Attention, T5LayerCrossAttention,
};
use std::borrow::Borrow;
use tch::nn::LinearConfig;
use tch::{nn, Device, IndexOp, Kind, Tensor};
pub type LongT5Attention = T5Attention;
pub type LongT5LayerCrossAttention = T5LayerCrossAttention;
pub type LayerState = T5layerState;
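/// Pads `x` along `dim` with `pad_value` so that its length becomes a multiple of `block_length`.
/// If the input contains an empty dimension, a zero tensor with the padded shape is returned instead.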
fn pad_to_multiple(x: &Tensor, block_length: i64, dim: usize, pad_value: f64) -> Tensor {
let mut x_size = x.size();
let pad_length = (-x_size[dim]).rem_euclid(block_length);
if x_size.iter().any(|&el| el == 0) {
x_size[dim] += pad_length;
Tensor::zeros(x_size.as_slice(), (x.kind(), x.device()))
} else {
let mut pad = vec![0i64; 2 * x.dim()];
pad[2 * dim] = pad_length;
pad.reverse();
x.pad(pad.as_slice(), "constant", pad_value)
}
}
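/// Splits dimension `dim` into `(num_blocks, block_length)`, padding the sequence with zeros when
/// its length is not a multiple of `block_length`.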
fn split_into_blocks(x: &Tensor, block_length: i64, dim: usize) -> Tensor {
let x_size = x.size();
let padded_x = if x_size[dim] % block_length != 0 {
Some(pad_to_multiple(x, block_length, dim, 0f64))
} else {
None
};
let x = padded_x.as_ref().unwrap_or(x);
let mut x_size = x.size();
let num_blocks = x_size[dim] / block_length;
x_size.remove(dim);
x_size.insert(dim, block_length);
x_size.insert(dim, num_blocks);
if x_size.iter().any(|&el| el == 0) {
Tensor::empty(x_size.as_slice(), (x.kind(), x.device()))
} else {
x.reshape(x_size.as_slice())
}
}
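/// Concatenates each block with its left and right neighbours along `sequence_dim` (padding the
/// edges), so that every query block can attend to a window of `3 * block_length` keys/values.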
fn concatenate_3_blocks(
x: &Tensor,
block_dim: usize,
sequence_dim: i64,
pad_value: Option<f64>,
) -> Tensor {
let x_size = x.size();
let num_blocks = x_size[block_dim];
let mut pad = vec![0i64; 2 * x.dim()];
pad[2 * block_dim] = 1;
pad[2 * block_dim + 1] = 1;
pad.reverse();
let x = x.pad(pad.as_slice(), "constant", pad_value.unwrap_or(0f64));
let mut block_list: Vec<Tensor> = Vec::with_capacity(3);
for i in 0..3 {
block_list.push(x.narrow(block_dim as i64, i, num_blocks));
}
Tensor::cat(block_list.as_slice(), sequence_dim)
}
fn make_3blocks_relative_position_ids(block_length: i64, device: Device) -> Tensor {
let position_ids = Tensor::arange(3 * block_length, (Kind::Int, device));
let center_position_ids = position_ids.i(block_length..2 * block_length);
position_ids.unsqueeze(0) - center_position_ids.unsqueeze(1)
}
fn mask_local_attention_mask(local_attention_mask: &Tensor, block_length: i64) -> Tensor {
let relative_position_ids =
make_3blocks_relative_position_ids(block_length, local_attention_mask.device());
let locality_mask = relative_position_ids
.abs()
.lt(block_length)
.unsqueeze(0)
.unsqueeze(0);
local_attention_mask.logical_and(&locality_mask)
}
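/// Builds the blocked local attention mask: tokens may only attend to positions in their own block
/// and the two neighbouring blocks, within `block_length` of their own position, combined with the
/// padding mask.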
pub(crate) fn get_local_attention_mask(attention_mask: &Tensor, block_length: i64) -> Tensor {
let blocked_attention_mask = split_into_blocks(attention_mask, block_length, 1);
let three_blocked_attention_mask = concatenate_3_blocks(&blocked_attention_mask, 1, 2, None);
let blocked_attention_mask = blocked_attention_mask.unsqueeze(-1);
let three_blocked_attention_mask = three_blocked_attention_mask.unsqueeze(-2);
let local_attention_mask = mask_local_attention_mask(
&blocked_attention_mask.logical_and(&three_blocked_attention_mask),
block_length,
);
local_attention_mask.unsqueeze(1)
}
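/// Assigns every input position to a fixed global block of size `global_block_size` and returns
/// `(block_ids, global_segment_ids)`. Orphan tokens at the end of the sequence are folded into the
/// last full block, and padding positions receive a block id of -1.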
fn make_global_fixed_block_ids(
attention_mask: &Tensor,
global_block_size: i64,
) -> (Tensor, Tensor) {
let &[batch_size, seq_length, ..] = attention_mask.size().as_slice() else {unreachable!()};
let handle_orphan_tokens = |block_ids: Tensor| -> Tensor {
let block_ends = Tensor::arange(seq_length, (Kind::Int64, block_ids.device()))
.remainder(global_block_size)
.eq(global_block_size - 1);
let true_block_ends = block_ends.logical_and(&block_ids.ge(0));
let full_blocks = true_block_ends
.sum_dim_intlist([-1].as_slice(), false, block_ids.kind())
.unsqueeze(-1)
- 1;
block_ids.where_self(&block_ids.lt_tensor(&full_blocks), &full_blocks)
};
let fixed_block_mask = attention_mask.ones_like() / global_block_size;
let fixed_block_mask = fixed_block_mask.cumsum(1, fixed_block_mask.kind()) - fixed_block_mask;
let mask = attention_mask
.ones_like()
.where_scalarother(&attention_mask.not_equal(0.0), -1000.0);
let mut global_block_ids = (mask + fixed_block_mask - 1.0).floor();
global_block_ids = global_block_ids.where_scalarother(&global_block_ids.gt(-1.0), -1.0);
global_block_ids = global_block_ids * attention_mask + attention_mask - 1;
global_block_ids = handle_orphan_tokens(global_block_ids);
let num_globals = seq_length / global_block_size;
let sequence_block_ids_max = if num_globals > 0 {
global_block_ids
.max_dim(-1, false)
.0
.repeat(&[num_globals, 1])
.transpose(0, 1)
} else {
Tensor::zeros(
&[batch_size, 0],
(global_block_ids.kind(), global_block_ids.device()),
)
};
let global_segment_ids = Tensor::ones(
&[batch_size, num_globals],
(attention_mask.kind(), attention_mask.device()),
)
.cumsum(-1, attention_mask.kind())
- 1;
let global_segment_ids = global_segment_ids
.ones_like()
.where_scalarother(&global_segment_ids.le_tensor(&sequence_block_ids_max), 0.0);
(
global_block_ids.to_kind(Kind::Int),
global_segment_ids.to_kind(Kind::Int),
)
}
fn make_side_relative_position_ids(attention_mask: &Tensor, global_block_size: i64) -> Tensor {
let (block_ids, global_segment_ids) =
make_global_fixed_block_ids(attention_mask, global_block_size);
let global_seq_length = *global_segment_ids.size().last().unwrap();
let global_positions = Tensor::arange(global_seq_length, (Kind::Int64, block_ids.device()));
global_positions - block_ids.unsqueeze(-1)
}
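/// Sums the hidden states belonging to each global block (via a one-hot einsum over the block ids)
/// to produce the transient global token representations.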
fn create_global_aggregates(
hidden_states: &Tensor,
block_ids: &Tensor,
global_seq_length: i64,
) -> Tensor {
let block_ids = block_ids.where_scalarother(&block_ids.ge(0), global_seq_length);
let one_hot_block_ids = block_ids
.to_kind(Kind::Int64)
.one_hot(global_seq_length + 1);
let one_hot_block_ids = one_hot_block_ids.narrow(2, 0, one_hot_block_ids.size()[2] - 1);
Tensor::einsum(
"...nd,...ng->...gd",
&[
hidden_states,
&one_hot_block_ids.to_kind(hidden_states.kind()),
],
None,
)
}
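/// Computes the local relative position bias over a `3 * block_length` window, reusing the T5
/// relative position bucketing.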
fn compute_bias(
block_length: i64,
relative_attention_bias: &nn::Embedding,
is_decoder: bool,
relative_attention_num_buckets: i64,
relative_attention_max_distance: i64,
) -> Tensor {
let device = relative_attention_bias.ws.device();
let memory_position = Tensor::arange(3 * block_length, (Kind::Int64, device));
let context_position = memory_position.narrow(0, block_length, block_length);
let relative_position = memory_position.unsqueeze(0) - context_position.unsqueeze(-1);
let rp_bucket = get_relative_position_bucket(
&relative_position,
!is_decoder,
relative_attention_num_buckets,
relative_attention_max_distance,
);
rp_bucket
.apply(relative_attention_bias)
.permute(&[2, 0, 1])
.unsqueeze(0)
.unsqueeze(0)
}
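/// # LongT5 local self-attention
/// Sliding-window attention where each block of `block_length` tokens attends to itself and its
/// two neighbouring blocks.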
pub struct LongT5LocalAttention {
is_decoder: bool,
has_relative_attention_bias: bool,
relative_attention_num_buckets: i64,
relative_attention_max_distance: i64,
key_value_proj_dim: i64,
n_heads: i64,
block_length: i64,
dropout: Dropout,
inner_dim: i64,
output_attentions: bool,
query: nn::Linear,
key: nn::Linear,
value: nn::Linear,
output: nn::Linear,
relative_attention_bias: Option<nn::Embedding>,
}
impl LongT5LocalAttention {
pub fn new<'p, P>(
p: P,
config: &LongT5Config,
is_decoder: bool,
has_relative_attention_bias: bool,
) -> LongT5LocalAttention
where
P: Borrow<nn::Path<'p>>,
{
let p = p.borrow();
let linear_config = LinearConfig {
bias: false,
..Default::default()
};
let block_length = config.local_radius + 1;
let key_value_proj_dim = config.d_kv;
let inner_dim = config.num_heads * config.d_kv;
let key = nn::linear(p / "k", config.d_model, inner_dim, linear_config);
let value = nn::linear(p / "v", config.d_model, inner_dim, linear_config);
let query = nn::linear(p / "q", config.d_model, inner_dim, linear_config);
let output = nn::linear(p / "o", inner_dim, config.d_model, linear_config);
let dropout = Dropout::new(config.dropout_rate);
let relative_attention_bias = if has_relative_attention_bias {
Some(nn::embedding(
p / "relative_attention_bias",
config.relative_attention_num_buckets,
config.num_heads,
Default::default(),
))
} else {
None
};
LongT5LocalAttention {
is_decoder,
has_relative_attention_bias,
relative_attention_num_buckets: config.relative_attention_num_buckets,
relative_attention_max_distance: config.relative_attention_max_distance.unwrap_or(128),
key_value_proj_dim,
n_heads: config.num_heads,
block_length,
dropout,
inner_dim,
output_attentions: config.output_attentions.unwrap_or(false),
query,
key,
value,
output,
relative_attention_bias,
}
}
pub fn forward_t(
&self,
hidden_states: &Tensor,
mask: Option<&Tensor>,
position_bias: Option<&Tensor>,
train: bool,
) -> (Tensor, Option<Tensor>, Option<Tensor>) {
let input_size = hidden_states.size();
let (batch_size, seq_length) = (input_size[0], input_size[1]);
let shape = |states: &Tensor| -> Tensor {
states.view([batch_size, -1, self.n_heads, self.key_value_proj_dim])
};
let unshape = |states: &Tensor| -> Tensor {
states.contiguous().view([batch_size, -1, self.inner_dim])
};
let query_states = shape(&hidden_states.apply(&self.query));
let key_states = shape(&hidden_states.apply(&self.key));
let value_states = shape(&hidden_states.apply(&self.value));
let query_states = split_into_blocks(&query_states, self.block_length, 1);
let key_states = split_into_blocks(&key_states, self.block_length, 1);
let value_states = split_into_blocks(&value_states, self.block_length, 1);
let key_states = concatenate_3_blocks(&key_states, 1, 2, None);
let value_states = concatenate_3_blocks(&value_states, 1, 2, None);
let mut scores = Tensor::einsum("...qhd,...khd->...hqk", &[query_states, key_states], None);
let calc_position_bias = if position_bias.is_none() {
let mut position_bias = if !self.has_relative_attention_bias {
Tensor::zeros(
&[1, 1, self.n_heads, self.block_length, 3 * self.block_length],
(scores.kind(), scores.device()),
)
} else {
compute_bias(
self.block_length,
self.relative_attention_bias.as_ref().unwrap(),
self.is_decoder,
self.relative_attention_num_buckets,
self.relative_attention_max_distance,
)
};
if let Some(mask) = mask {
let mask = mask.zeros_like().where_scalarother(&mask.gt(0), -1e10);
position_bias = position_bias + mask.transpose(1, 2);
}
Some(position_bias)
} else {
None
};
let position_bias = position_bias.unwrap_or_else(|| calc_position_bias.as_ref().unwrap());
scores += position_bias;
let attention_weights = scores
.to_kind(Kind::Float)
.softmax(-1, scores.kind())
.apply_t(&self.dropout, train)
.to_kind(value_states.kind());
let attention_output = unshape(&Tensor::einsum(
"...hqk,...khd->...qhd",
&[&attention_weights, &value_states],
None,
))
.narrow(1, 0, seq_length)
.apply(&self.output);
let attention_weights = if self.output_attentions {
Some(attention_weights)
} else {
None
};
let position_bias = if self.has_relative_attention_bias {
calc_position_bias
} else {
None
};
(attention_output, position_bias, attention_weights)
}
}
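/// # LongT5 transient-global self-attention
/// Local windowed attention augmented with "transient" global tokens: block-wise aggregates of the
/// input that every position can additionally attend to.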
pub struct LongT5TransientGlobalAttention {
is_decoder: bool,
has_relative_attention_bias: bool,
relative_attention_num_buckets: i64,
relative_attention_max_distance: i64,
key_value_proj_dim: i64,
n_heads: i64,
block_length: i64,
global_block_size: i64,
dropout: Dropout,
inner_dim: i64,
output_attentions: bool,
query: nn::Linear,
key: nn::Linear,
value: nn::Linear,
output: nn::Linear,
relative_attention_bias: Option<nn::Embedding>,
global_relative_attention_bias: Option<nn::Embedding>,
global_input_layer_norm: LongT5LayerNorm,
}
impl LongT5TransientGlobalAttention {
pub fn new<'p, P>(
p: P,
config: &LongT5Config,
is_decoder: bool,
has_relative_attention_bias: bool,
) -> LongT5TransientGlobalAttention
where
P: Borrow<nn::Path<'p>>,
{
let p = p.borrow();
let linear_config = LinearConfig {
bias: false,
..Default::default()
};
let block_length = config.local_radius + 1;
let global_block_size = config.global_block_size;
let key_value_proj_dim = config.d_kv;
let inner_dim = config.num_heads * config.d_kv;
let key = nn::linear(p / "k", config.d_model, inner_dim, linear_config);
let value = nn::linear(p / "v", config.d_model, inner_dim, linear_config);
let query = nn::linear(p / "q", config.d_model, inner_dim, linear_config);
let output = nn::linear(p / "o", inner_dim, config.d_model, linear_config);
let dropout = Dropout::new(config.dropout_rate);
let global_relative_attention_bias = if has_relative_attention_bias {
Some(nn::embedding(
p / "global_relative_attention_bias",
config.relative_attention_num_buckets,
config.num_heads,
Default::default(),
))
} else {
None
};
let relative_attention_bias = if has_relative_attention_bias {
Some(nn::embedding(
p / "relative_attention_bias",
config.relative_attention_num_buckets,
config.num_heads,
Default::default(),
))
} else {
None
};
let global_input_layer_norm = LongT5LayerNorm::new(
p / "global_input_layer_norm",
config.d_model,
config.layer_norm_epsilon,
);
LongT5TransientGlobalAttention {
is_decoder,
has_relative_attention_bias,
relative_attention_num_buckets: config.relative_attention_num_buckets,
relative_attention_max_distance: config.relative_attention_max_distance.unwrap_or(128),
key_value_proj_dim,
n_heads: config.num_heads,
block_length,
global_block_size,
dropout,
inner_dim,
output_attentions: config.output_attentions.unwrap_or(false),
query,
key,
value,
output,
relative_attention_bias,
global_relative_attention_bias,
global_input_layer_norm,
}
}
fn compute_side_bias(&self, mask: &Tensor, global_segment_ids: &Tensor) -> Tensor {
let side_attention_mask = mask
.unsqueeze(-1)
.eq_tensor(&global_segment_ids.unsqueeze(1))
.unsqueeze(1);
let attention_side_bias = side_attention_mask
.zeros_like()
.where_scalarother(&side_attention_mask.gt(0), -1e10);
let side_relative_position = make_side_relative_position_ids(mask, self.global_block_size);
let side_relative_position_bucket = get_relative_position_bucket(
&side_relative_position,
!self.is_decoder,
self.relative_attention_num_buckets,
self.relative_attention_max_distance,
);
let side_bias = side_relative_position_bucket
.apply(self.global_relative_attention_bias.as_ref().unwrap())
.permute(&[0, 3, 1, 2]);
attention_side_bias + side_bias
}
pub fn forward_t(
&self,
hidden_states: &Tensor,
mask: Option<&Tensor>,
position_bias: Option<&Tensor>,
train: bool,
) -> (Tensor, Option<Tensor>, Option<Tensor>) {
let input_size = hidden_states.size();
let (batch_size, seq_length) = (input_size[0], input_size[1]);
let shape = |states: &Tensor| -> Tensor {
states.view([batch_size, -1, self.n_heads, self.key_value_proj_dim])
};
let unshape = |states: &Tensor| -> Tensor {
states.contiguous().view([batch_size, -1, self.inner_dim])
};
let calc_mask = if mask.is_none() {
let mut mask_size = input_size;
let _ = mask_size.pop();
Some(Tensor::ones(
mask_size.as_slice(),
(Kind::Bool, hidden_states.device()),
))
} else {
None
};
let (block_ids, global_segment_ids) = make_global_fixed_block_ids(
mask.unwrap_or_else(|| calc_mask.as_ref().unwrap()),
self.global_block_size,
);
let global_seq_length = *global_segment_ids.size().last().unwrap();
let global_inputs = create_global_aggregates(hidden_states, &block_ids, global_seq_length)
.apply(&self.global_input_layer_norm);
let query_states = shape(&hidden_states.apply(&self.query));
let key_states = shape(&hidden_states.apply(&self.key));
let value_states = shape(&hidden_states.apply(&self.value));
let side_key_states = shape(&global_inputs.apply(&self.key));
let side_value_states = shape(&global_inputs.apply(&self.value));
let query_states = split_into_blocks(&query_states, self.block_length, 1);
let key_states = split_into_blocks(&key_states, self.block_length, 1);
let value_states = split_into_blocks(&value_states, self.block_length, 1);
let key_states = concatenate_3_blocks(&key_states, 1, 2, None);
let value_states = concatenate_3_blocks(&value_states, 1, 2, None);
let mut reps = vec![1; side_key_states.dim() + 1];
reps[1] = key_states.size()[1];
let side_key_states = side_key_states.unsqueeze(1).repeat(reps.as_slice());
let side_value_states = side_value_states.unsqueeze(1).repeat(reps.as_slice());
let key_states = Tensor::cat(&[key_states, side_key_states], 2);
let value_states = Tensor::cat(&[value_states, side_value_states], 2);
let mut scores = Tensor::einsum("...qhd,...khd->...hqk", &[query_states, key_states], None);
let local_attention_mask = mask.map(|mask| {
let local_attention_mask = get_local_attention_mask(mask, self.block_length);
local_attention_mask
.zeros_like()
.where_scalarother(&local_attention_mask.gt(0), -1e10)
});
let calc_position_bias = if position_bias.is_none() {
let mut position_bias = if !self.has_relative_attention_bias {
Tensor::zeros(
&[1, 1, self.n_heads, self.block_length, 3 * self.block_length],
(scores.kind(), scores.device()),
)
} else {
compute_bias(
self.block_length,
self.relative_attention_bias.as_ref().unwrap(),
self.is_decoder,
self.relative_attention_num_buckets,
self.relative_attention_max_distance,
)
};
if let Some(local_attention_mask) = local_attention_mask {
position_bias = position_bias + local_attention_mask.transpose(1, 2);
}
let calc_mask = if mask.is_none() {
Some(Tensor::ones(
&[batch_size, seq_length],
(global_segment_ids.kind(), global_segment_ids.device()),
))
} else {
None
};
let mask = mask.unwrap_or_else(|| calc_mask.as_ref().unwrap());
let side_position_bias = self.compute_side_bias(mask, &global_segment_ids);
let side_position_bias = split_into_blocks(
&side_position_bias,
self.block_length,
side_position_bias.dim() - 2,
)
.transpose(1, 2);
let position_bias = Tensor::cat(&[position_bias, side_position_bias], -1);
Some(position_bias)
} else {
None
};
let position_bias = position_bias.unwrap_or_else(|| calc_position_bias.as_ref().unwrap());
scores += position_bias;
let attention_weights = scores
.to_kind(Kind::Float)
.softmax(-1, scores.kind())
.apply_t(&self.dropout, train);
let attention_output = unshape(&Tensor::einsum(
"...hqk,...khd->...qhd",
&[&attention_weights, &value_states],
None,
))
.narrow(1, 0, seq_length)
.apply(&self.output);
let attention_weights = if self.output_attentions {
Some(attention_weights)
} else {
None
};
let position_bias = if self.has_relative_attention_bias {
calc_position_bias
} else {
None
};
(attention_output, position_bias, attention_weights)
}
}
pub struct LongT5LayerSelfAttention {
self_attention: LongT5Attention,
layer_norm: LongT5LayerNorm,
dropout: Dropout,
}
impl LongT5LayerSelfAttention {
pub fn new<'p, P>(
p: P,
config: &LongT5Config,
has_relative_attention_bias: bool,
is_decoder: bool,
store_cache: bool,
output_attentions: bool,
) -> LongT5LayerSelfAttention
where
P: Borrow<nn::Path<'p>>,
{
let p = p.borrow();
let self_attention = LongT5Attention::new(
p / "SelfAttention",
&config.into(),
is_decoder,
!is_decoder,
store_cache,
output_attentions,
has_relative_attention_bias,
);
let layer_norm =
LongT5LayerNorm::new(p / "layer_norm", config.d_model, config.layer_norm_epsilon);
let dropout = Dropout::new(config.dropout_rate);
LongT5LayerSelfAttention {
self_attention,
layer_norm,
dropout,
}
}
pub fn forward_t(
&self,
hidden_states: &Tensor,
position_bias: Option<&Tensor>,
attention_mask: Option<&Tensor>,
layer_state: Option<LayerState>,
train: bool,
) -> (Tensor, Option<Tensor>, Option<Tensor>, Option<LayerState>) {
let norm_x = hidden_states.apply(&self.layer_norm);
let (y, attention_weights, position_bias, layer_state) = self.self_attention.forward_t(
&norm_x,
None,
position_bias,
attention_mask,
layer_state,
None,
train,
);
let output = hidden_states + y.apply_t(&self.dropout, train);
(output, attention_weights, position_bias, layer_state)
}
}
pub struct LongT5LayerLocalSelfAttention {
local_self_attention: LongT5LocalAttention,
layer_norm: LongT5LayerNorm,
dropout: Dropout,
}
impl LongT5LayerLocalSelfAttention {
pub fn new<'p, P>(
p: P,
config: &LongT5Config,
has_relative_attention_bias: bool,
is_decoder: bool,
) -> LongT5LayerLocalSelfAttention
where
P: Borrow<nn::Path<'p>>,
{
let p = p.borrow();
let local_self_attention = LongT5LocalAttention::new(
p / "LocalSelfAttention",
config,
is_decoder,
has_relative_attention_bias,
);
let layer_norm =
LongT5LayerNorm::new(p / "layer_norm", config.d_model, config.layer_norm_epsilon);
let dropout = Dropout::new(config.dropout_rate);
LongT5LayerLocalSelfAttention {
local_self_attention,
layer_norm,
dropout,
}
}
pub fn forward_t(
&self,
hidden_states: &Tensor,
attention_mask: Option<&Tensor>,
position_bias: Option<&Tensor>,
train: bool,
) -> (Tensor, Option<Tensor>, Option<Tensor>) {
let normed_hidden_states = hidden_states.apply(&self.layer_norm);
let (attention_output, position_bias, attention_weights) = self
.local_self_attention
.forward_t(&normed_hidden_states, attention_mask, position_bias, train);
let output = hidden_states + attention_output.apply_t(&self.dropout, train);
(output, position_bias, attention_weights)
}
}
pub struct LongT5LayerTransientGlobalSelfAttention {
transient_global_self_attention: LongT5TransientGlobalAttention,
layer_norm: LongT5LayerNorm,
dropout: Dropout,
}
impl LongT5LayerTransientGlobalSelfAttention {
pub fn new<'p, P>(
p: P,
config: &LongT5Config,
has_relative_attention_bias: bool,
is_decoder: bool,
) -> LongT5LayerTransientGlobalSelfAttention
where
P: Borrow<nn::Path<'p>>,
{
let p = p.borrow();
let transient_global_self_attention = LongT5TransientGlobalAttention::new(
p / "TransientGlobalSelfAttention",
config,
is_decoder,
has_relative_attention_bias,
);
let layer_norm =
LongT5LayerNorm::new(p / "layer_norm", config.d_model, config.layer_norm_epsilon);
let dropout = Dropout::new(config.dropout_rate);
LongT5LayerTransientGlobalSelfAttention {
transient_global_self_attention,
layer_norm,
dropout,
}
}
pub fn forward_t(
&self,
hidden_states: &Tensor,
attention_mask: Option<&Tensor>,
position_bias: Option<&Tensor>,
train: bool,
) -> (Tensor, Option<Tensor>, Option<Tensor>) {
let normed_hidden_states = hidden_states.apply(&self.layer_norm);
let (attention_output, position_bias, attention_weights) = self
.transient_global_self_attention
.forward_t(&normed_hidden_states, attention_mask, position_bias, train);
let output = hidden_states + attention_output.apply_t(&self.dropout, train);
(output, position_bias, attention_weights)
}
}

457
src/longt5/encoder.rs Normal file
View File

@ -0,0 +1,457 @@
// Copyright 2022 Google LLC., LongT5 Authors and HuggingFace Inc. team.
// Copyright 2022 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use crate::common::dropout::Dropout;
use crate::common::embeddings::process_ids_embeddings_pair;
use crate::common::kind::get_min;
use crate::longt5::attention::{
get_local_attention_mask, LayerState, LongT5LayerCrossAttention, LongT5LayerLocalSelfAttention,
LongT5LayerSelfAttention, LongT5LayerTransientGlobalSelfAttention,
};
use crate::longt5::layer_norm::LongT5LayerNorm;
use crate::longt5::longt5_model::EncoderAttentionType;
use crate::longt5::LongT5Config;
use crate::t5::{T5Block, T5BlockOutput, T5LayerFF, T5StackOutput};
use crate::RustBertError;
use std::borrow::{Borrow, BorrowMut};
use tch::{nn, Kind, Tensor};
pub type LongT5LayerFF = T5LayerFF;
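/// Self-attention variant used by a `LongT5Block`: full attention (decoder), local attention or
/// transient-global attention (encoder).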
enum LongT5AttentionLayer {
SelfAttention(LongT5LayerSelfAttention),
Local(LongT5LayerLocalSelfAttention),
Global(LongT5LayerTransientGlobalSelfAttention),
}
impl LongT5AttentionLayer {
pub fn forward_t(
&self,
hidden_states: &Tensor,
position_bias: Option<&Tensor>,
attention_mask: Option<&Tensor>,
layer_state: Option<LayerState>,
train: bool,
) -> (Tensor, Option<Tensor>, Option<Tensor>, Option<LayerState>) {
match self {
LongT5AttentionLayer::SelfAttention(ref layer) => layer.forward_t(
hidden_states,
position_bias,
attention_mask,
layer_state,
train,
),
LongT5AttentionLayer::Local(ref layer) => {
let (output, position_bias, attention_weights) =
layer.forward_t(hidden_states, attention_mask, position_bias, train);
(output, attention_weights, position_bias, None)
}
LongT5AttentionLayer::Global(ref layer) => {
let (output, position_bias, attention_weights) =
layer.forward_t(hidden_states, attention_mask, position_bias, train);
(output, attention_weights, position_bias, None)
}
}
}
}
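/// A single LongT5 block: self-attention (full, local or transient-global), optional
/// cross-attention (decoder only) and a feed-forward layer, each with a pre-layer-norm residual
/// connection.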
pub struct LongT5Block {
attention_layer: LongT5AttentionLayer,
cross_attention: Option<LongT5LayerCrossAttention>,
ff_layer: LongT5LayerFF,
}
impl LongT5Block {
pub fn new<'p, P>(
p: P,
config: &LongT5Config,
has_relative_attention_bias: bool,
is_decoder: bool,
store_cache: bool,
output_attentions: bool,
) -> LongT5Block
where
P: Borrow<nn::Path<'p>>,
{
let p = p.borrow() / "layer";
let mut module_index = 0;
let attention_layer = if is_decoder {
LongT5AttentionLayer::SelfAttention(LongT5LayerSelfAttention::new(
&p / module_index,
config,
has_relative_attention_bias,
is_decoder,
store_cache,
output_attentions,
))
} else {
match config.encoder_attention_type {
Some(EncoderAttentionType::Local) | None => {
LongT5AttentionLayer::Local(LongT5LayerLocalSelfAttention::new(
&p / module_index,
config,
has_relative_attention_bias,
is_decoder,
))
}
Some(EncoderAttentionType::TransientGlobal) => {
LongT5AttentionLayer::Global(LongT5LayerTransientGlobalSelfAttention::new(
&p / module_index,
config,
has_relative_attention_bias,
is_decoder,
))
}
}
};
let cross_attention = if is_decoder {
module_index += 1;
Some(LongT5LayerCrossAttention::new(
&p / module_index,
&config.into(),
false,
is_decoder,
store_cache,
output_attentions,
))
} else {
None
};
module_index += 1;
let ff_layer = LongT5LayerFF::new(&p / module_index, &config.into());
LongT5Block {
attention_layer,
cross_attention,
ff_layer,
}
}
pub fn forward_t(
&self,
hidden_states: &Tensor,
attention_mask: Option<&Tensor>,
position_bias: Option<&Tensor>,
encoder_hidden_states: Option<&Tensor>,
encoder_attention_mask: Option<&Tensor>,
encoder_decoder_position_bias: Option<&Tensor>,
mut layer_states: (Option<LayerState>, Option<LayerState>),
train: bool,
) -> LongT5BlockOutput {
let (
mut hidden_states,
self_attention_weights,
self_attention_position_bias,
self_attention_layer_past,
) = self.attention_layer.forward_t(
hidden_states,
position_bias,
attention_mask,
layer_states.0,
train,
);
hidden_states = T5Block::clamp_hidden_states(hidden_states);
let (
mut hidden_states,
cross_attention_weights,
cross_attention_position_bias,
cross_attention_layer_past,
) = if self.cross_attention.is_some() & encoder_hidden_states.is_some() {
let query_length = self_attention_layer_past
.as_ref()
.map(|value| value.prev_key.size()[2]);
self.cross_attention.as_ref().unwrap().forward_t(
&hidden_states,
encoder_hidden_states,
encoder_decoder_position_bias,
encoder_attention_mask,
layer_states.1,
query_length,
train,
)
} else {
(hidden_states, None, None, None)
};
hidden_states = T5Block::clamp_hidden_states(hidden_states);
layer_states = (self_attention_layer_past, cross_attention_layer_past);
let mut hidden_states = self.ff_layer.forward_t(&hidden_states, train);
hidden_states = T5Block::clamp_hidden_states(hidden_states);
LongT5BlockOutput {
hidden_states,
self_attention_weights,
cross_attention_weights,
self_attention_position_bias,
cross_attention_position_bias,
cache: layer_states,
}
}
}
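/// Stack of `LongT5Block`s with a final layer norm. The decoder builds an extended causal mask
/// from the attention mask, while the encoder builds a blocked local attention mask when local
/// attention is used.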
pub struct LongT5Stack {
blocks: Vec<LongT5Block>,
final_layer_norm: LongT5LayerNorm,
dropout: Dropout,
output_attentions: bool,
output_hidden_states: bool,
is_decoder: bool,
store_cache: bool,
encoder_attention_type: EncoderAttentionType,
block_length: i64,
}
impl LongT5Stack {
pub fn new<'p, P>(
p: P,
config: &LongT5Config,
is_decoder: bool,
store_cache: bool,
output_attentions: bool,
output_hidden_states: bool,
) -> LongT5Stack
where
P: Borrow<nn::Path<'p>>,
{
let p = p.borrow();
let dropout = Dropout::new(config.dropout_rate);
let mut blocks: Vec<LongT5Block> = vec![];
let p_layers = p / "block";
for layer_index in 0..config.num_layers {
blocks.push(LongT5Block::new(
&p_layers / layer_index,
config,
layer_index == 0,
is_decoder,
store_cache,
output_attentions,
));
}
let final_layer_norm = LongT5LayerNorm::new(
p / "final_layer_norm",
config.d_model,
config.layer_norm_epsilon,
);
let encoder_attention_type = config
.encoder_attention_type
.unwrap_or(EncoderAttentionType::Local);
let block_length = config.local_radius + 1;
LongT5Stack {
blocks,
final_layer_norm,
dropout,
output_attentions,
output_hidden_states,
is_decoder,
store_cache,
encoder_attention_type,
block_length,
}
}
pub fn forward_t(
&self,
input_ids: Option<&Tensor>,
attention_mask: Option<&Tensor>,
encoder_hidden_states: Option<&Tensor>,
encoder_attention_mask: Option<&Tensor>,
input_embeds: Option<&Tensor>,
embeddings: &nn::Embedding,
old_layer_states: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
train: bool,
) -> Result<LongT5StackOutput, RustBertError> {
let (calc_input_embeddings, input_shape, _) =
process_ids_embeddings_pair(input_ids, input_embeds, embeddings)?;
let input_embeddings =
input_embeds.unwrap_or_else(|| calc_input_embeddings.as_ref().unwrap());
let (batch_size, sequence_length) = (input_shape[0], input_shape[1]);
let mask_seq_length = if old_layer_states.is_some() {
if old_layer_states.as_ref().unwrap()[0].0.is_some() {
old_layer_states.as_ref().unwrap()[0]
.0
.as_ref()
.unwrap()
.prev_key
.size()[2]
+ sequence_length
} else {
sequence_length
}
} else {
sequence_length
};
let calculated_attention_mask = if attention_mask.is_none() {
Some(Tensor::ones(
&[batch_size, mask_seq_length],
(Kind::Int64, input_embeddings.device()),
))
} else {
None
};
let attention_mask =
attention_mask.unwrap_or_else(|| calculated_attention_mask.as_ref().unwrap());
let extended_attention_mask = if self.is_decoder {
let extended_attention_mask = match attention_mask.dim() {
3 => attention_mask.unsqueeze(1),
2 => {
if self.is_decoder {
let seq_ids = Tensor::arange(
sequence_length,
(input_embeddings.kind(), input_embeddings.device()),
);
let causal_mask = seq_ids.unsqueeze(0).unsqueeze(0).repeat(&[
batch_size,
sequence_length,
1,
]);
let causal_mask =
causal_mask.le_tensor(&seq_ids.unsqueeze(0).unsqueeze(-1));
causal_mask.unsqueeze(1) * attention_mask.unsqueeze(1).unsqueeze(1)
} else {
attention_mask.unsqueeze(1).unsqueeze(1)
}
}
_ => {
return Err(RustBertError::ValueError(
"Invalid attention mask dimension, must be 2 or 3".into(),
));
}
};
Some(
(extended_attention_mask.ones_like() - extended_attention_mask)
.to_kind(input_embeddings.kind())
* get_min(input_embeddings.kind()).unwrap(),
)
} else if let EncoderAttentionType::Local = self.encoder_attention_type {
Some(get_local_attention_mask(attention_mask, self.block_length))
} else {
None
};
let extended_attention_mask = extended_attention_mask.as_ref().unwrap_or(attention_mask);
let encoder_extended_attention_mask = if self.is_decoder & encoder_hidden_states.is_some() {
let new_shape = &encoder_hidden_states.as_ref().unwrap().size()[..2];
let calculated_encoder_attention_mask = if encoder_attention_mask.is_none() {
Some(Tensor::ones(
&[batch_size, new_shape[1]],
(Kind::Int64, input_embeddings.device()),
))
} else {
None
};
let encoder_attention_mask = encoder_attention_mask
.unwrap_or_else(|| calculated_encoder_attention_mask.as_ref().unwrap());
let mut encoder_extended_attention_mask =
encoder_attention_mask.to_kind(input_embeddings.kind());
if encoder_extended_attention_mask.dim() == 3 {
encoder_extended_attention_mask = encoder_extended_attention_mask.unsqueeze_(1);
} else if encoder_extended_attention_mask.dim() == 2 {
encoder_extended_attention_mask =
encoder_extended_attention_mask.unsqueeze_(1).unsqueeze_(1);
};
Some(
(encoder_extended_attention_mask.ones_like() - encoder_extended_attention_mask)
* get_min(input_embeddings.kind()).unwrap(),
)
} else {
None
};
let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states {
Some(Vec::with_capacity(self.blocks.len()))
} else {
None
};
let mut all_attentions: Option<Vec<Tensor>> = if self.output_attentions {
Some(Vec::with_capacity(self.blocks.len()))
} else {
None
};
let mut next_cache: Option<Vec<(Option<LayerState>, Option<LayerState>)>> =
if self.store_cache {
if old_layer_states.is_some() {
old_layer_states
} else {
Some(vec![(None, None); self.blocks.len()])
}
} else {
None
};
let mut position_bias = None;
let mut encoder_decoder_position_bias = None;
let mut attention_weights: Option<Tensor>;
let mut hidden_state = input_embeddings.apply_t(&self.dropout, train);
for (layer_idx, layer) in self.blocks.iter().enumerate() {
let layer_state = match &mut next_cache {
Some(values) => std::mem::take(&mut values[layer_idx]),
None => (None, None),
};
let block_output = layer.forward_t(
&hidden_state,
Some(extended_attention_mask),
position_bias.as_ref(),
encoder_hidden_states,
encoder_extended_attention_mask.as_ref(),
encoder_decoder_position_bias.as_ref(),
layer_state,
train,
);
if layer_idx == 0 {
position_bias = block_output.self_attention_position_bias;
encoder_decoder_position_bias = block_output.cross_attention_position_bias;
}
hidden_state = block_output.hidden_states;
attention_weights = block_output.cross_attention_weights;
if let Some(hidden_states) = all_hidden_states.borrow_mut() {
hidden_states.push(hidden_state.as_ref().copy().transpose(0, 1));
};
if let Some(attentions) = all_attentions.borrow_mut() {
attentions.push(std::mem::take(&mut attention_weights.unwrap()));
};
if let Some(value) = &mut next_cache {
value[layer_idx] = block_output.cache
};
}
let hidden_state = hidden_state
.apply(&self.final_layer_norm)
.apply_t(&self.dropout, train);
Ok(LongT5StackOutput {
hidden_state,
all_hidden_states,
all_attentions,
next_cache,
})
}
}
pub type LongT5BlockOutput = T5BlockOutput;
pub type LongT5StackOutput = T5StackOutput;

15
src/longt5/layer_norm.rs Normal file
View File

@ -0,0 +1,15 @@
// Copyright 2022 Google LLC., LongT5 Authors and HuggingFace Inc. team.
// Copyright 2022 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use crate::t5::T5LayerNorm;
pub type LongT5LayerNorm = T5LayerNorm;

894
src/longt5/longt5_model.rs Normal file
View File

@ -0,0 +1,894 @@
// Copyright 2022 Google LLC., LongT5 Authors and HuggingFace Inc. team.
// Copyright 2022 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use crate::longt5::encoder::LongT5Stack;
use crate::longt5::LayerState;
use crate::pipelines::common::{ModelType, TokenizerOption};
use crate::pipelines::generation_utils::private_generation_utils::{
PreparedInput, PrivateLanguageGenerator,
};
use crate::pipelines::generation_utils::{
Cache, GenerateConfig, LMHeadModel, LMModelOutput, LanguageGenerator,
};
use crate::t5::{FeedForwardProj, T5Config, T5ModelOutput, TaskSpecificParams};
use crate::{Config, RustBertError};
use rust_tokenizers::tokenizer::{T5Tokenizer, TruncationStrategy};
use rust_tokenizers::vocab::T5Vocab;
use serde::{Deserialize, Serialize};
use std::borrow::Borrow;
use tch::nn::{embedding, LinearConfig};
use tch::{nn, Tensor};
/// # LongT5 Pretrained model weight files
pub struct LongT5ModelResources;
/// # LongT5 Pretrained model config files
pub struct LongT5ConfigResources;
/// # LongT5 Pretrained model vocab files
pub struct LongT5VocabResources;
impl LongT5ModelResources {
/// Shared under Apache 2.0 license at <https://huggingface.co/pszemraj/long-t5-tglobal-base-16384-book-summary>. Modified with conversion to C-array format.
pub const TGLOBAL_BASE_BOOK_SUMMARY: (&'static str, &'static str) = (
"longt5-tglobal-base-book-summary/model",
"https://huggingface.co/pszemraj/long-t5-tglobal-base-16384-book-summary/resolve/main/rust_model.ot",
);
}
impl LongT5ConfigResources {
/// Shared under Apache 2.0 license at <https://huggingface.co/pszemraj/long-t5-tglobal-base-16384-book-summary>. Modified with conversion to C-array format.
pub const TGLOBAL_BASE_BOOK_SUMMARY: (&'static str, &'static str) = (
"longt5-tglobal-base-book-summary/config",
"https://huggingface.co/pszemraj/long-t5-tglobal-base-16384-book-summary/resolve/main/config.json",
);
}
impl LongT5VocabResources {
/// Shared under Apache 2.0 license at <https://huggingface.co/pszemraj/long-t5-tglobal-base-16384-book-summary>. Modified with conversion to C-array format.
pub const TGLOBAL_BASE_BOOK_SUMMARY: (&'static str, &'static str) = (
"longt5-tglobal-base-book-summary/spiece",
"https://huggingface.co/pszemraj/long-t5-tglobal-base-16384-book-summary/resolve/main/spiece.model",
);
}
#[derive(Clone, Debug, Serialize, Deserialize, Copy)]
#[serde(rename_all = "kebab-case")]
/// # Options for LongT5 encoder attention type
pub enum EncoderAttentionType {
/// Local
Local,
/// Transient Global
TransientGlobal,
}
#[derive(Debug, Serialize, Deserialize, Clone)]
/// # LongT5 model configuration
/// Defines the LongT5 model architecture (e.g. number of layers, hidden layer size, label mapping...)
pub struct LongT5Config {
pub dropout_rate: f64,
pub d_model: i64,
pub d_ff: i64,
pub d_kv: i64,
pub decoder_start_token_id: Option<i64>,
pub bos_token_id: Option<i64>,
pub eos_token_id: Option<i64>,
pub initializer_factor: f64,
pub is_encoder_decoder: Option<bool>,
pub layer_norm_epsilon: f64,
pub num_heads: i64,
pub num_layers: i64,
pub num_decoder_layers: Option<i64>,
pub local_radius: i64,
pub global_block_size: i64,
pub output_past: Option<bool>,
pub pad_token_id: Option<i64>,
pub relative_attention_num_buckets: i64,
pub relative_attention_max_distance: Option<i64>,
pub encoder_attention_type: Option<EncoderAttentionType>,
pub vocab_size: i64,
pub feed_forward_proj: Option<FeedForwardProj>,
pub tie_word_embeddings: Option<bool>,
pub task_specific_params: Option<TaskSpecificParams>,
pub output_attentions: Option<bool>,
pub output_hidden_states: Option<bool>,
}
impl Config for LongT5Config {}
impl Default for LongT5Config {
fn default() -> Self {
LongT5Config {
dropout_rate: 0.1,
d_model: 512,
d_ff: 2048,
d_kv: 64,
decoder_start_token_id: None,
bos_token_id: None,
eos_token_id: Some(1),
initializer_factor: 1.0,
is_encoder_decoder: None,
layer_norm_epsilon: 1e-6,
num_heads: 8,
num_layers: 6,
num_decoder_layers: None,
local_radius: 127,
global_block_size: 16,
output_past: None,
pad_token_id: Some(0),
relative_attention_num_buckets: 32,
relative_attention_max_distance: Some(128),
encoder_attention_type: Some(EncoderAttentionType::Local),
vocab_size: 32128,
feed_forward_proj: Some(FeedForwardProj::Relu),
tie_word_embeddings: None,
task_specific_params: None,
output_attentions: None,
output_hidden_states: None,
}
}
}
impl From<&LongT5Config> for T5Config {
fn from(val: &LongT5Config) -> T5Config {
T5Config {
dropout_rate: val.dropout_rate,
d_model: val.d_model,
d_ff: val.d_ff,
d_kv: val.d_kv,
decoder_start_token_id: val.decoder_start_token_id,
bos_token_id: None,
eos_token_id: val.eos_token_id,
initializer_factor: val.initializer_factor,
is_encoder_decoder: val.is_encoder_decoder,
layer_norm_epsilon: val.layer_norm_epsilon,
num_heads: val.num_heads,
num_layers: val.num_layers,
output_past: val.output_past,
pad_token_id: val.pad_token_id,
relative_attention_num_buckets: val.relative_attention_num_buckets,
relative_attention_max_distance: val.relative_attention_max_distance,
vocab_size: val.vocab_size,
feed_forward_proj: val.feed_forward_proj,
tie_word_embeddings: val.tie_word_embeddings,
task_specific_params: val.task_specific_params.clone(),
output_attentions: val.output_attentions,
output_hidden_states: val.output_hidden_states,
}
}
}
/// # LongT5 Base model
/// Base architecture for LongT5 model. Usually complemented with a task-specific head, such as a language model head.
/// It is made of the following blocks:
/// - `encoder`: `LongT5Stack` (transformer) made of a vector of encoding layers
/// - `decoder`: `LongT5Stack` (transformer) made of a vector of decoding layers with self attention and encoder cross-attention.
/// caching is implemented for the decoder to avoid recalculating static states (encoder key/values and previously calculated decoder key/values)
/// - `embeddings`: `nn::Embedding` Shared embeddings for the encoder and decoder.
pub struct LongT5Model {
pub(crate) encoder: LongT5Stack,
decoder: LongT5Stack,
pub(crate) embeddings: nn::Embedding,
}
impl LongT5Model {
/// Build a new `LongT5Model`
///
/// # Arguments
///
/// * `p` - Variable store path for the root of the LongT5 model
/// * `config` - `LongT5Config` object defining the model architecture
///
/// # Example
///
/// ```no_run
/// use rust_bert::longt5::{LongT5Config, LongT5Model};
/// use rust_bert::Config;
/// use std::path::Path;
/// use tch::{nn, Device};
///
/// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu;
/// let p = nn::VarStore::new(device);
/// let config = LongT5Config::from_file(config_path);
/// let long_t5: LongT5Model = LongT5Model::new(&p.root() / "longt5", &config);
/// ```
pub fn new<'p, P>(p: P, config: &LongT5Config) -> LongT5Model
where
P: Borrow<nn::Path<'p>>,
{
let p = p.borrow();
let embeddings: nn::Embedding = embedding(
p / "shared",
config.vocab_size,
config.d_model,
Default::default(),
);
let encoder = LongT5Stack::new(
p / "encoder",
config,
false,
false,
config.output_attentions.unwrap_or(false),
config.output_hidden_states.unwrap_or(false),
);
let decoder = LongT5Stack::new(
p / "decoder",
config,
true,
true,
config.output_attentions.unwrap_or(false),
config.output_hidden_states.unwrap_or(false),
);
LongT5Model {
encoder,
decoder,
embeddings,
}
}
/// Forward pass through the model
///
/// # Arguments
///
/// * `input_ids` - Optional input tensor of shape (*batch size*, *source_sequence_length*). This or `input_embeds` must be provided.
/// * `attention_mask` - Optional attention mask of shape (*batch size*, *source_sequence_length*) for the encoder positions. Positions with a mask with value 0 will be masked.
/// * `encoder_outputs` - Optional tuple made of a tensor of shape (*batch size*, *source_sequence_length*, *encoder_hidden_dim*) and optional vectors of tensors of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*).
/// These correspond to the encoder last hidden state and optional hidden states/attention weights for encoder layers. When provided, the encoder hidden state will not be recalculated. Useful for generation tasks.
/// * `decoder_input_ids` - Optional input tensor of shape (*batch size*, *target_sequence_length*). This or `decoder_input_embeds` must be provided.
/// * `decoder_attention_mask` - Optional attention mask of shape (*batch size*, *target_sequence_length*) for the decoder positions. Positions with a mask with value 0 will be masked.
/// * `input_embeds` - Optional input tensor of shape (*batch size*, *source_sequence_length*, *embeddings dimension*). This or `input_ids` must be provided.
/// * `decoder_input_embeds` - Optional input tensor of shape (*batch size*, *target_sequence_length*, *embeddings dimension*). This or `decoder_input_ids` must be provided.
/// * `old_layer_states` - Optional vector of length `num_layers` containing tuples of optional `LayerStates` containing the last calculated key and value pairs for the decoder. This avoids recomputing attention weights at past positions and speeds up decoding.
/// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
///
/// # Returns
///
/// * `LongT5ModelOutput` containing:
/// - `decoder_output` - `Tensor` of shape (*batch size*, *target_sequence_length*, *hidden_size*) representing the activations of the last decoder hidden state
/// - `encoder_hidden_states` - `Tensor` of shape (*batch size*, *source_sequence_length*, *hidden_size*) representing the activations of the last encoder hidden state
/// - `cache` - `Option<Vec<(Option<Vec<LayerState, LayerState>>)>>` of length *n_layer* containing the encoder padding mask and past keys and values for both the self attention and the encoder cross attention of each layer of the decoder.
/// - `all_encoder_hidden_states` - `Option<Vec<Tensor>>` of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*)
/// - `all_encoder_attentions` - `Option<Vec<Tensor>>` of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*)
/// - `all_decoder_hidden_states` - `Option<Vec<Tensor>>` of length *num_decoder_layers* with shape (*batch size*, *target_sequence_length*, *hidden_size*)
/// - `all_decoder_attentions` - `Option<Vec<Tensor>>` of length *num_decoder_layers* with shape (*batch size*, *target_sequence_length*, *hidden_size*)
///
/// # Example
///
/// ```no_run
/// # use tch::{nn, Device, Tensor, no_grad};
/// # use rust_bert::Config;
/// # use std::path::Path;
/// # use tch::kind::Kind::{Int64, Double};
/// use rust_bert::longt5::{LongT5Config, LongT5Model};
/// # let config_path = Path::new("path/to/config.json");
/// # let vocab_path = Path::new("path/to/vocab.txt");
/// # let device = Device::Cpu;
/// # let vs = nn::VarStore::new(device);
/// # let config = LongT5Config::from_file(config_path);
/// # let longt5_model: LongT5Model = LongT5Model::new(&vs.root(), &config);
/// let (batch_size, source_sequence_length, target_sequence_length) = (64, 128, 56);
/// let input_tensor = Tensor::rand(&[batch_size, source_sequence_length], (Int64, device));
/// let target_tensor = Tensor::rand(&[batch_size, target_sequence_length], (Int64, device));
/// let encoder_attention_mask =
/// Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
/// let decoder_attention_mask =
/// Tensor::ones(&[batch_size, target_sequence_length], (Int64, device));
///
/// let model_output = no_grad(|| {
/// longt5_model.forward_t(
/// Some(&input_tensor),
/// Some(&encoder_attention_mask),
/// None,
/// Some(&target_tensor),
/// Some(&decoder_attention_mask),
/// None,
/// None,
/// None,
/// false,
/// )
/// });
/// ```
pub fn forward_t(
&self,
input_ids: Option<&Tensor>,
attention_mask: Option<&Tensor>,
encoder_outputs: Option<&Tensor>,
decoder_input_ids: Option<&Tensor>,
decoder_attention_mask: Option<&Tensor>,
input_embeds: Option<&Tensor>,
decoder_input_embeds: Option<&Tensor>,
old_layer_states: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
train: bool,
) -> Result<LongT5ModelOutput, RustBertError> {
let (calc_hidden_states, all_encoder_hidden_states, all_encoder_attentions) =
if encoder_outputs.is_none() {
let encoder_output = self.encoder.forward_t(
input_ids,
attention_mask,
None,
None,
input_embeds,
&self.embeddings,
None,
train,
)?;
(
Some(encoder_output.hidden_state),
encoder_output.all_hidden_states,
encoder_output.all_attentions,
)
} else {
(None, None, None)
};
let encoder_output =
encoder_outputs.unwrap_or_else(|| calc_hidden_states.as_ref().unwrap());
let decoder_output = self
.decoder
.forward_t(
decoder_input_ids,
decoder_attention_mask,
Some(encoder_output),
attention_mask,
decoder_input_embeds,
&self.embeddings,
old_layer_states,
train,
)
.unwrap();
Ok(LongT5ModelOutput {
decoder_output: decoder_output.hidden_state,
encoder_hidden_state: calc_hidden_states,
next_cache: decoder_output.next_cache,
all_decoder_hidden_states: decoder_output.all_hidden_states,
all_decoder_attentions: decoder_output.all_attentions,
all_encoder_hidden_states,
all_encoder_attentions,
})
}
}
/// # LongT5 Model for conditional generation
/// LongT5 model with a vocabulary decoding head
/// It is made of the following blocks:
/// - `base_model`: `LongT5Model` Base LongT5 model
/// - `model_dim`: `f64` representation of the model dimension for scaling of the generated logits
pub struct LongT5ForConditionalGeneration {
base_model: LongT5Model,
model_dim: f64,
tie_word_embeddings: bool,
lm_head: Option<nn::Linear>,
}
impl LongT5ForConditionalGeneration {
/// Build a new `LongT5ForConditionalGeneration`
///
/// # Arguments
///
/// * `p` - Variable store path for the root of the LongT5 model
/// * `config` - `LongT5Config` object defining the model architecture
///
/// # Example
///
/// ```no_run
/// use rust_bert::longt5::{LongT5Config, LongT5ForConditionalGeneration};
/// use rust_bert::Config;
/// use std::path::Path;
/// use tch::{nn, Device};
///
/// let config_path = Path::new("path/to/config.json");
/// let device = Device::Cpu;
/// let p = nn::VarStore::new(device);
/// let config = LongT5Config::from_file(config_path);
/// let longt5 = LongT5ForConditionalGeneration::new(&p.root() / "longt5", &config);
/// ```
pub fn new<'p, P>(p: P, config: &LongT5Config) -> LongT5ForConditionalGeneration
where
P: Borrow<nn::Path<'p>>,
{
let p = p.borrow();
let base_model = LongT5Model::new(p, config);
let tie_word_embeddings = config.tie_word_embeddings.unwrap_or(true);
let lm_head = if !tie_word_embeddings {
Some(nn::linear(
p / "lm_head",
config.d_model,
config.vocab_size,
LinearConfig {
bias: false,
..Default::default()
},
))
} else {
None
};
LongT5ForConditionalGeneration {
base_model,
model_dim: config.d_model as f64,
tie_word_embeddings,
lm_head,
}
}
/// Forward pass through the model
///
/// # Arguments
///
/// * `input_ids` - Optional input tensor of shape (*batch size*, *source_sequence_length*). This or `input_embeds` must be provided.
/// * `attention_mask` - Optional attention mask of shape (*batch size*, *source_sequence_length*) for the encoder positions. Positions with a mask with value 0 will be masked.
/// * `encoder_outputs` - Optional tuple made of a tensor of shape (*batch size*, *source_sequence_length*, *encoder_hidden_dim*) and optional vectors of tensors of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*).
/// These correspond to the encoder last hidden state and optional hidden states/attention weights for encoder layers. When provided, the encoder hidden state will not be recalculated. Useful for generation tasks.
/// * `decoder_input_ids` - Optional input tensor of shape (*batch size*, *target_sequence_length*). This or `decoder_input_embeds` must be provided.
/// * `decoder_attention_mask` - Optional attention mask of shape (*batch size*, *target_sequence_length*) for the decoder positions. Positions with a mask with value 0 will be masked.
/// * `input_embeds` - Optional input tensor of shape (*batch size*, *source_sequence_length*, *embeddings dimension*). This or `input_ids` must be provided.
/// * `decoder_input_embeds` - Optional input tensor of shape (*batch size*, *target_sequence_length*, *embeddings dimension*). This or `decoder_input_ids` must be provided.
/// * `old_layer_states` - Optional vector of length `num_layers` containing tuples of optional `LayerStates` containing the last calculated key and value pairs for the decoder. This avoids recomputing attention weights at past positions and speeds up decoding.
/// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
///
/// # Returns
///
    /// * `LongT5ModelOutput` containing:
/// - `decoder_output` - `Tensor` of shape (*batch size*, *target_sequence_length*, *vocab_size*) representing the logits for each sequence position and vocabulary item
    ///   - `encoder_hidden_state` - `Tensor` of shape (*batch size*, *source_sequence_length*, *hidden_size*) representing the activations of the last encoder hidden state
    ///   - `next_cache` - `Option<Vec<(Option<LayerState>, Option<LayerState>)>>` of length *n_layer* containing the past keys and values for both the self attention and the encoder cross attention of each layer of the decoder.
/// - `all_encoder_hidden_states` - `Option<Vec<Tensor>>` of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*)
/// - `all_encoder_attentions` - `Option<Vec<Tensor>>` of length *num_encoder_layers* with shape (*batch size*, *source_sequence_length*, *hidden_size*)
/// - `all_decoder_hidden_states` - `Option<Vec<Tensor>>` of length *num_decoder_layers* with shape (*batch size*, *target_sequence_length*, *hidden_size*)
/// - `all_decoder_attentions` - `Option<Vec<Tensor>>` of length *num_decoder_layers* with shape (*batch size*, *target_sequence_length*, *hidden_size*)
///
/// # Example
///
/// ```no_run
/// # use tch::{nn, Device, Tensor, no_grad};
/// # use rust_bert::Config;
/// # use std::path::Path;
/// # use tch::kind::Kind::{Int64, Double};
/// use rust_bert::longt5::{LongT5Config, LongT5ForConditionalGeneration};
/// # let config_path = Path::new("path/to/config.json");
/// # let vocab_path = Path::new("path/to/vocab.txt");
/// # let device = Device::Cpu;
/// # let vs = nn::VarStore::new(device);
/// # let config = LongT5Config::from_file(config_path);
/// # let longt5_model: LongT5ForConditionalGeneration = LongT5ForConditionalGeneration::new(&vs.root(), &config);
/// let (batch_size, source_sequence_length, target_sequence_length) = (64, 128, 56);
/// let input_tensor = Tensor::rand(&[batch_size, source_sequence_length], (Int64, device));
/// let target_tensor = Tensor::rand(&[batch_size, target_sequence_length], (Int64, device));
/// let encoder_attention_mask =
/// Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
/// let decoder_attention_mask =
    ///     Tensor::ones(&[batch_size, target_sequence_length], (Int64, device));
///
/// let model_output = no_grad(|| {
/// longt5_model.forward_t(
/// Some(&input_tensor),
/// Some(&encoder_attention_mask),
/// None,
/// Some(&target_tensor),
/// Some(&decoder_attention_mask),
/// None,
/// None,
/// None,
/// false,
/// )
/// });
/// ```
pub fn forward_t(
&self,
input_ids: Option<&Tensor>,
attention_mask: Option<&Tensor>,
encoder_outputs: Option<&Tensor>,
decoder_input_ids: Option<&Tensor>,
decoder_attention_mask: Option<&Tensor>,
input_embeds: Option<&Tensor>,
decoder_input_embeds: Option<&Tensor>,
old_layer_states: Option<Vec<(Option<LayerState>, Option<LayerState>)>>,
train: bool,
) -> Result<LongT5ModelOutput, RustBertError> {
let base_model_output = self.base_model.forward_t(
input_ids,
attention_mask,
encoder_outputs,
decoder_input_ids,
decoder_attention_mask,
input_embeds,
decoder_input_embeds,
old_layer_states,
train,
)?;
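        // With tied word embeddings the shared embedding matrix doubles as the output projection and the logits are rescaled by d_model^-0.5 (as in the reference T5 implementation); otherwise the dedicated `lm_head` is applied.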
let lm_logits = if self.tie_word_embeddings {
base_model_output
.decoder_output
.linear::<Tensor>(&self.base_model.embeddings.ws, None)
* (self.model_dim.powf(-0.5))
} else {
base_model_output
.decoder_output
.apply(self.lm_head.as_ref().unwrap())
};
Ok(T5ModelOutput {
decoder_output: lm_logits,
..base_model_output
})
}
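    /// Runs only the encoder stack over the provided input, returning the last encoder hidden state
    /// of shape (*batch size*, *source_sequence_length*, *hidden_size*). This is used by the
    /// generation utilities to compute the encoder output once and reuse it at every decoding step.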
pub fn encode(&self, input_ids: &Tensor, attention_mask: Option<&Tensor>) -> Tensor {
self.base_model
.encoder
.forward_t(
Some(input_ids),
attention_mask,
None,
None,
None,
&self.base_model.embeddings,
None,
false,
)
.unwrap()
.hidden_state
}
}
impl LMHeadModel for LongT5ForConditionalGeneration {
/// Forward pass through the model
///
/// # Arguments
///
/// * `input_ids` - Optional input tensor of shape (*batch size*, *sequence_length*). If None, pre-computed embeddings must be provided (see `input_embeds`)
/// * `cache` - `Cache` object containing tuples of optional `LayerStates` containing the last calculated key and value pairs for the decoder. This avoids recomputing attention weights at past positions and speeds up decoding.
    /// * `attention_mask` - Optional mask of shape (*batch size*, *sequence_length*). Masked positions have value 0, non-masked positions have value 1. If None, set to 1
/// * `input_embeds` - Unused for LongT5
/// * `token_type_ids` - Unused for LongT5
/// * `position_ids` - Unused for LongT5
/// * `encoder_outputs` - Optional tensor of shape (*batch size*, *source_sequence_length*, *hidden_size*). When provided, the encoder hidden state will not be recalculated. Useful for generation tasks.
/// * `decoder_input_ids` - Optional input tensor of shape (*batch size*, *target_sequence_length*).
/// * `train` - boolean flag to turn on/off the dropout layers in the model. Should be set to false for inference.
///
/// # Returns
///
/// * `LMModelOutput` containing:
/// - `lm_logits` - `Tensor` of shape (*batch size*, *sequence_length*, *vocab_size*) representing the logits for each vocab item and position
    ///   - `cache` - `LongT5Cache` made of `Option<Vec<(Option<LayerState>, Option<LayerState>)>>` of length *n_layer* containing the past keys and values for
    ///     both the self attention and the encoder cross attention of each layer of the decoder.
///
/// # Example
///
/// ```no_run
/// # use tch::{nn, Device, Tensor, no_grad};
/// # use rust_bert::Config;
/// # use std::path::Path;
/// # use tch::kind::Kind::{Int64, Double};
/// use rust_bert::longt5::{LongT5Config, LongT5ForConditionalGeneration};
/// # let config_path = Path::new("path/to/config.json");
/// # let vocab_path = Path::new("path/to/vocab.txt");
/// # let device = Device::Cpu;
/// # let vs = nn::VarStore::new(device);
/// # let config = LongT5Config::from_file(config_path);
/// # let longt5_model: LongT5ForConditionalGeneration = LongT5ForConditionalGeneration::new(&vs.root(), &config);
/// let (batch_size, source_sequence_length, target_sequence_length) = (64, 128, 56);
/// let input_tensor = Tensor::rand(&[batch_size, source_sequence_length], (Int64, device));
/// let target_tensor = Tensor::rand(&[batch_size, target_sequence_length], (Int64, device));
/// let encoder_attention_mask =
/// Tensor::ones(&[batch_size, source_sequence_length], (Int64, device));
/// let decoder_attention_mask =
    ///     Tensor::ones(&[batch_size, target_sequence_length], (Int64, device));
///
/// let model_output = no_grad(|| {
/// longt5_model.forward_t(
/// Some(&input_tensor),
/// Some(&encoder_attention_mask),
/// None,
/// Some(&target_tensor),
/// Some(&decoder_attention_mask),
/// None,
/// None,
/// None,
/// false,
/// )
/// });
/// ```
fn forward_t(
&self,
input_ids: Option<&Tensor>,
cache: Cache,
attention_mask: Option<&Tensor>,
_token_type_ids: Option<&Tensor>,
_position_ids: Option<&Tensor>,
_input_embeds: Option<&Tensor>,
encoder_outputs: Option<&Tensor>,
decoder_input_ids: Option<&Tensor>,
train: bool,
) -> Result<LMModelOutput, RustBertError> {
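        // LongT5 reuses the T5 layer state layout, so only a `LongT5Cache` (or no cache at all) is a valid input here.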
let base_model_output = match cache {
Cache::LongT5Cache(cached_layer_states) => self.base_model.forward_t(
input_ids,
attention_mask,
encoder_outputs,
decoder_input_ids,
None,
None,
None,
cached_layer_states,
train,
)?,
Cache::None => self.base_model.forward_t(
input_ids,
attention_mask,
encoder_outputs,
decoder_input_ids,
None,
None,
None,
None,
train,
)?,
_ => {
return Err(RustBertError::ValueError(
"Cache not compatible with LongT5 Model".into(),
));
}
};
let lm_logits = if self.tie_word_embeddings {
base_model_output
.decoder_output
.linear::<Tensor>(&self.base_model.embeddings.ws, None)
* (self.model_dim.powf(-0.5))
} else {
base_model_output
.decoder_output
.apply(self.lm_head.as_ref().unwrap())
};
Ok(LMModelOutput {
lm_logits,
cache: Cache::LongT5Cache(base_model_output.next_cache),
})
}
}
/// Container holding a LongT5 model output.
pub type LongT5ModelOutput = T5ModelOutput;
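/// # Language generation model based on the LongT5 architecture
/// Wraps a `LongT5ForConditionalGeneration` together with its tokenizer and generation settings,
/// and serves as the backend for generation pipelines such as summarization.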
pub struct LongT5Generator {
model: LongT5ForConditionalGeneration,
tokenizer: TokenizerOption,
var_store: nn::VarStore,
generate_config: GenerateConfig,
bos_token_id: Option<i64>,
eos_token_ids: Option<Vec<i64>>,
pad_token_id: Option<i64>,
is_encoder_decoder: bool,
vocab_size: i64,
decoder_start_id: Option<i64>,
max_position_embeddings: i64,
}
impl LongT5Generator {
pub fn new(generate_config: GenerateConfig) -> Result<LongT5Generator, RustBertError> {
let vocab_path = generate_config.vocab_resource.get_local_path()?;
let tokenizer = TokenizerOption::from_file(
ModelType::LongT5,
vocab_path.to_str().unwrap(),
None,
false,
None,
None,
)?;
Self::new_with_tokenizer(generate_config, tokenizer)
}
pub fn new_with_tokenizer(
generate_config: GenerateConfig,
tokenizer: TokenizerOption,
) -> Result<LongT5Generator, RustBertError> {
let config_path = generate_config.config_resource.get_local_path()?;
let weights_path = generate_config.model_resource.get_local_path()?;
let device = generate_config.device;
generate_config.validate();
let mut var_store = nn::VarStore::new(device);
let config = LongT5Config::from_file(config_path);
let model = LongT5ForConditionalGeneration::new(var_store.root(), &config);
var_store.load(weights_path)?;
let bos_token_id = config.bos_token_id;
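        // T5-family SentencePiece vocabularies use id 1 (`</s>`) as the end-of-sequence token and id 0 as padding when these are not specified in the configuration.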
let eos_token_ids = Some(match config.eos_token_id {
Some(value) => vec![value],
None => vec![1],
});
let pad_token_id = Some(config.pad_token_id.unwrap_or(0));
let vocab_size = config.vocab_size;
let is_encoder_decoder = true;
let decoder_start_id = pad_token_id;
        // LongT5 does not have an embedding matrix for position IDs and relies on relative positions instead
let max_position_embeddings = i64::MAX;
Ok(LongT5Generator {
model,
tokenizer,
var_store,
generate_config,
bos_token_id,
eos_token_ids,
pad_token_id,
is_encoder_decoder,
vocab_size,
decoder_start_id,
max_position_embeddings,
})
}
}
impl PrivateLanguageGenerator<LongT5ForConditionalGeneration, T5Vocab, T5Tokenizer>
for LongT5Generator
{
fn get_model(&self) -> &LongT5ForConditionalGeneration {
&self.model
}
fn _get_tokenizer(&self) -> &TokenizerOption {
&self.tokenizer
}
fn get_var_store(&self) -> &nn::VarStore {
&self.var_store
}
fn get_var_store_mut(&mut self) -> &mut nn::VarStore {
&mut self.var_store
}
fn get_config(&self) -> &GenerateConfig {
&self.generate_config
}
fn get_bos_id(&self) -> Option<i64> {
self.bos_token_id
}
fn get_eos_ids(&self) -> Option<&Vec<i64>> {
self.eos_token_ids.as_ref()
}
fn get_pad_id(&self) -> Option<i64> {
self.pad_token_id
}
fn is_encoder_decoder(&self) -> bool {
self.is_encoder_decoder
}
fn get_vocab_size(&self) -> i64 {
self.vocab_size
}
fn get_decoder_start_id(&self) -> Option<i64> {
self.decoder_start_id
}
fn get_max_positions_embeddings(&self) -> i64 {
self.max_position_embeddings
}
fn encode(&self, input_ids: &Tensor, attention_mask: Option<&Tensor>) -> Option<Tensor> {
Some(self.get_model().encode(input_ids, attention_mask))
}
fn prepare_inputs_for_generation<'a>(
&self,
input_ids: Tensor,
encoder_outputs: Option<&'a Tensor>,
past: Cache,
attention_mask: Tensor,
) -> PreparedInput<'a> {
match past {
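            // With a populated cache only the last generated token is passed to the decoder; without a cache the full sequence is used.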
Cache::LongT5Cache(past) => PreparedInput {
prepared_input: None,
prepared_attention_mask: Some(attention_mask),
prepared_encoder_output: encoder_outputs,
prepared_decoder_input: Some(input_ids.narrow(1, -1, 1)),
prepared_position_ids: None,
prepared_past: Cache::LongT5Cache(past),
},
Cache::None => PreparedInput {
prepared_input: None,
prepared_attention_mask: Some(attention_mask),
prepared_encoder_output: encoder_outputs,
prepared_decoder_input: Some(input_ids),
prepared_position_ids: None,
prepared_past: Cache::LongT5Cache(None),
},
            _ => panic!("Cache type incompatible with LongT5"),
}
}
fn encode_prompt_text<S>(
&self,
prompt_text: &[S],
max_len: Option<i64>,
pad_token_id: Option<i64>,
) -> Tensor
where
S: AsRef<str> + Sync,
{
let tokens = self._get_tokenizer().encode_list(
prompt_text,
max_len
.map(|max_len| max_len as usize)
.unwrap_or(usize::MAX),
&TruncationStrategy::LongestFirst,
0,
);
let token_ids = tokens
.into_iter()
.map(|tokenized_input| tokenized_input.token_ids)
.collect::<Vec<Vec<i64>>>();
let max_len = token_ids.iter().map(|input| input.len()).max().unwrap();
let pad_token = match pad_token_id {
Some(value) => value,
None => self._get_tokenizer().get_unk_id(),
};
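        // Pad every prompt to the length of the longest tokenized input before stacking into a single batch tensor.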
let token_ids = token_ids
.into_iter()
.map(|mut input| {
let temp = vec![pad_token; max_len - input.len()];
input.extend(temp);
input
})
.map(|tokens| Tensor::of_slice(&tokens).to(self.get_var_store().device()))
.collect::<Vec<Tensor>>();
Tensor::stack(&token_ids, 0)
}
fn reorder_cache(
&self,
past: &mut Cache,
encoder_outputs: Option<Tensor>,
beam_indices: &Tensor,
) -> Option<Tensor> {
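        // Re-order the cached self- and cross-attention layer states so that they follow the beams selected at this step of beam search.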
match past {
Cache::LongT5Cache(old_cache_option) => match old_cache_option {
Some(old_cache) => {
for (self_layer_state, encoder_layer_state) in old_cache.iter_mut() {
if self_layer_state.is_some() {
self_layer_state
.as_mut()
.unwrap()
.reorder_cache(beam_indices)
};
if encoder_layer_state.is_some() {
encoder_layer_state
.as_mut()
.unwrap()
.reorder_cache(beam_indices)
};
}
}
None => {}
},
Cache::None => {}
_ => {
panic!("Invalid cache for LongT5 model");
}
};
encoder_outputs
}
}
impl LanguageGenerator<LongT5ForConditionalGeneration, T5Vocab, T5Tokenizer> for LongT5Generator {}

src/longt5/mod.rs Normal file
View File

@ -0,0 +1,59 @@
//! # LongT5 (Efficient Text-To-Text Transformer for Long Sequences)
//!
//! Implementation of the LongT5 language model ([LongT5: Efficient Text-To-Text Transformer for Long Sequences](https://arxiv.org/abs/2112.07916) Guo, Ainslie, Uthus, Ontanon, Ni, Sung, Yang, 2021).
//! The base model is implemented in the `longt5_model::LongT5Model` struct. A variant with a language model head is provided by `longt5_model::LongT5ForConditionalGeneration`,
//! implementing the common `generation_utils::LMHeadModel` trait shared between the models used for generation (see `pipelines` for more information).
//!
//! # Model set-up and pre-trained weights loading
//!
//! All models expect the following resources:
//! - Configuration file expected to have a structure following the [Transformers library](https://github.com/huggingface/transformers)
//! - Model weights are expected to have a structure and parameter names following the [Transformers library](https://github.com/huggingface/transformers). A conversion using the Python utility scripts is required to convert the `.bin` weights to the `.ot` format.
//! - `T5Tokenizer` using a `spiece.model` sentence piece model
//!
//! Pretrained models are available and can be downloaded using RemoteResources.
//!
//! ```no_run
//! # fn main() -> anyhow::Result<()> {
//! #
//! use tch::{nn, Device};
//! # use std::path::PathBuf;
//! use rust_bert::longt5::{LongT5Config, LongT5ForConditionalGeneration};
//! use rust_bert::resources::{LocalResource, ResourceProvider};
//! use rust_bert::Config;
//! use rust_tokenizers::tokenizer::T5Tokenizer;
//!
//! let config_resource = LocalResource {
//! local_path: PathBuf::from("path/to/config.json"),
//! };
//! let sentence_piece_resource = LocalResource {
//! local_path: PathBuf::from("path/to/spiece.model"),
//! };
//! let weights_resource = LocalResource {
//! local_path: PathBuf::from("path/to/model.ot"),
//! };
//! let config_path = config_resource.get_local_path()?;
//! let spiece_path = sentence_piece_resource.get_local_path()?;
//! let weights_path = weights_resource.get_local_path()?;
//!
//! let device = Device::cuda_if_available();
//! let mut vs = nn::VarStore::new(device);
//! let tokenizer = T5Tokenizer::from_file(spiece_path.to_str().unwrap(), true);
//! let config = LongT5Config::from_file(config_path);
//! let longt5_model = LongT5ForConditionalGeneration::new(&vs.root(), &config);
//! vs.load(weights_path)?;
//!
//! # Ok(())
//! # }
//! ```
mod attention;
mod encoder;
mod layer_norm;
mod longt5_model;
pub use attention::LayerState;
pub use longt5_model::{
LongT5Config, LongT5ConfigResources, LongT5ForConditionalGeneration, LongT5Generator,
LongT5Model, LongT5ModelResources, LongT5VocabResources,
};

View File

@ -28,6 +28,7 @@ use crate::fnet::FNetConfig;
use crate::gpt2::Gpt2Config;
use crate::gpt_neo::GptNeoConfig;
use crate::longformer::LongformerConfig;
use crate::longt5::LongT5Config;
use crate::m2m_100::M2M100Config;
use crate::marian::MarianConfig;
use crate::mbart::MBartConfig;
@ -71,6 +72,8 @@ pub enum ModelType {
MobileBert,
#[serde(alias = "t5")]
T5,
#[serde(alias = "longt5")]
LongT5,
#[serde(alias = "albert")]
Albert,
XLNet,
@ -108,6 +111,8 @@ pub enum ConfigOption {
OpenAiGpt(OpenAiGptConfig),
/// T5 configuration
T5(T5Config),
/// LongT5 configuration
LongT5(LongT5Config),
/// Albert configuration
Albert(AlbertConfig),
/// XLNet configuration
@ -187,6 +192,7 @@ impl ConfigOption {
ModelType::Marian => ConfigOption::Marian(MarianConfig::from_file(path)),
ModelType::MobileBert => ConfigOption::MobileBert(MobileBertConfig::from_file(path)),
ModelType::T5 => ConfigOption::T5(T5Config::from_file(path)),
ModelType::LongT5 => ConfigOption::LongT5(LongT5Config::from_file(path)),
ModelType::Albert => ConfigOption::Albert(AlbertConfig::from_file(path)),
ModelType::XLNet => ConfigOption::XLNet(XLNetConfig::from_file(path)),
ModelType::GPT2 => ConfigOption::GPT2(Gpt2Config::from_file(path)),
@ -276,6 +282,7 @@ impl ConfigOption {
.as_ref()
.expect("No label dictionary (id2label) provided in configuration file"),
Self::T5(_) => panic!("T5 does not use a label mapping"),
Self::LongT5(_) => panic!("LongT5 does not use a label mapping"),
Self::OpenAiGpt(_) => panic!("OpenAI GPT does not use a label mapping"),
Self::GPT2(_) => panic!("GPT2 does not use a label mapping"),
Self::GPTNeo(_) => panic!("GPT-Neo does not use a label mapping"),
@ -294,6 +301,7 @@ impl ConfigOption {
Self::Marian(config) => Some(config.max_position_embeddings),
Self::MobileBert(config) => Some(config.max_position_embeddings),
Self::T5(_) => None,
Self::LongT5(_) => None,
Self::Albert(config) => Some(config.max_position_embeddings),
Self::XLNet(_) => None,
Self::GPT2(config) => Some(config.n_positions),
@ -473,7 +481,7 @@ impl TokenizerOption {
lower_case,
)?)
}
ModelType::T5 => {
ModelType::T5 | ModelType::LongT5 => {
if strip_accents.is_some() {
return Err(RustBertError::InvalidConfigurationError(format!(
"Optional input `strip_accents` set to value {} but cannot be used by {:?}",

View File

@ -219,6 +219,7 @@ pub enum Cache {
GPT2Cache(Option<Vec<Tensor>>),
BARTCache(Option<Vec<(Option<BartLayerState>, Option<BartLayerState>)>>),
T5Cache(Option<Vec<(Option<T5LayerState>, Option<T5LayerState>)>>),
LongT5Cache(Option<Vec<(Option<T5LayerState>, Option<T5LayerState>)>>),
XLNetCache(Option<Vec<Option<XLNetLayerState>>>),
ReformerCache(Option<Vec<Option<ReformerLayerState>>>),
ProphetNetCache(Option<Vec<(Option<ProphetNetLayerState>, Option<ProphetNetLayerState>)>>),

View File

@ -73,6 +73,7 @@ use crate::prophetnet::ProphetNetConditionalGenerator;
use crate::resources::ResourceProvider;
use crate::t5::T5Generator;
use crate::longt5::LongT5Generator;
#[cfg(feature = "remote")]
use crate::{
bart::{BartConfigResources, BartMergesResources, BartModelResources, BartVocabResources},
@ -219,6 +220,8 @@ pub enum SummarizationOption {
Bart(BartGenerator),
/// Summarizer based on T5 model
T5(T5Generator),
/// Summarizer based on LongT5 model
LongT5(LongT5Generator),
/// Summarizer based on ProphetNet model
ProphetNet(ProphetNetConditionalGenerator),
/// Summarizer based on Pegasus model
@ -232,6 +235,9 @@ impl SummarizationOption {
config.into(),
)?)),
ModelType::T5 => Ok(SummarizationOption::T5(T5Generator::new(config.into())?)),
ModelType::LongT5 => Ok(SummarizationOption::LongT5(LongT5Generator::new(
config.into(),
)?)),
ModelType::ProphetNet => Ok(SummarizationOption::ProphetNet(
ProphetNetConditionalGenerator::new(config.into())?,
)),
@ -250,6 +256,7 @@ impl SummarizationOption {
match *self {
Self::Bart(_) => ModelType::Bart,
Self::T5(_) => ModelType::T5,
Self::LongT5(_) => ModelType::LongT5,
Self::ProphetNet(_) => ModelType::ProphetNet,
Self::Pegasus(_) => ModelType::Pegasus,
}
@ -271,6 +278,11 @@ impl SummarizationOption {
.into_iter()
.map(|output| output.text)
.collect(),
Self::LongT5(ref model) => model
.generate(prompt_texts, None)
.into_iter()
.map(|output| output.text)
.collect(),
Self::ProphetNet(ref model) => model
.generate(prompt_texts, None)
.into_iter()

View File

@ -12,7 +12,7 @@
use crate::common::dropout::Dropout;
use crate::common::embeddings::process_ids_embeddings_pair;
use crate::common::kind::get_negative_infinity;
use crate::common::kind::get_min;
use crate::prophetnet::attention::{
compute_all_stream_relative_buckets, LayerState, ProphetNetAttention, ProphetNetFeedForward,
ProphetNetNgramAttention,
@ -26,7 +26,7 @@ use tch::{nn, Device, Kind, Tensor};
fn ngram_attention_bias(sequence_length: i64, ngram: i64, device: Device, kind: Kind) -> Tensor {
let left_block = Tensor::ones(&[ngram, sequence_length, sequence_length], (kind, device))
* get_negative_infinity(kind).unwrap();
* get_min(kind).unwrap();
let right_block = left_block.copy();
for stream_idx in 0..ngram {
let _ = right_block.get(stream_idx).fill_diagonal_(0, false);
@ -515,7 +515,7 @@ impl ProphetNetDecoder {
let causal_mask = Tensor::full(
&[sequence_length, sequence_length],
get_negative_infinity(hidden_states.kind()).unwrap(),
get_min(hidden_states.kind()).unwrap(),
(hidden_states.kind(), hidden_states.device()),
)
.triu_(1);

View File

@ -43,6 +43,37 @@ impl LayerState {
}
}
pub fn get_relative_position_bucket(
relative_position: &Tensor,
bidirectional: bool,
num_buckets: i64,
max_distance: i64,
) -> Tensor {
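    // Nearby relative positions get exact buckets while larger distances are binned logarithmically
    // up to `max_distance`, following the T5 relative attention bias scheme.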
let n = -relative_position;
let mut num_buckets = num_buckets;
let mut ret = n.zeros_like();
let n = if bidirectional {
num_buckets /= 2;
ret += n.lt(0).to_kind(Kind::Int64) * num_buckets;
n.abs()
} else {
n.max_other(&n.zeros_like())
};
let max_exact = num_buckets / 2;
let is_small = n.lt(max_exact);
let value_if_large: Tensor = ((n.to_kind(Kind::Float) / max_exact as f64).log2()
/ (max_distance as f64 / max_exact as f64).log2()
* (num_buckets - max_exact) as f64)
.to_kind(Kind::Int64)
+ max_exact;
let value_if_large = value_if_large.min_other(&value_if_large.full_like(num_buckets - 1));
ret += n.where_self(&is_small, &value_if_large);
ret
}
#[derive(Debug)]
pub struct T5Attention {
is_decoder: bool,
@ -142,7 +173,7 @@ impl T5Attention {
train: bool,
) -> (Tensor, Option<Tensor>, Option<Tensor>, Option<LayerState>) {
let input_size = hidden_states.size();
let (bs, seq_length, _) = (input_size[0], input_size[1], input_size[2]);
let (bs, seq_length) = (input_size[0], input_size[1]);
let real_seq_length = if layer_state.is_some() {
match query_length {
@ -245,44 +276,12 @@ impl T5Attention {
(context, attention_weights, position_bias, layer_state)
}
fn get_relative_position_bucket(
&self,
relative_position: &Tensor,
bidirectional: bool,
num_buckets: i64,
max_distance: i64,
) -> Tensor {
let n = -relative_position;
let mut num_buckets = num_buckets;
let mut ret = n.zeros_like();
let n = if bidirectional {
num_buckets /= 2;
ret += n.lt(0).to_kind(Kind::Int64) * num_buckets;
n.abs()
} else {
n.max_other(&n.zeros_like())
};
let max_exact = num_buckets / 2;
let is_small = n.lt(max_exact);
let value_if_large: Tensor = ((n.to_kind(Kind::Float) / max_exact as f64).log2()
/ (max_distance as f64 / max_exact as f64).log2()
* (num_buckets - max_exact) as f64)
.to_kind(Kind::Int64)
+ max_exact;
let value_if_large = value_if_large.min_other(&value_if_large.full_like(num_buckets - 1));
ret += n.where_self(&is_small, &value_if_large);
ret
}
fn compute_bias(&self, q_len: i64, k_len: i64, device: Device) -> Tensor {
let context_position = Tensor::arange(q_len, (Kind::Int64, device)).unsqueeze(1);
let memory_position = Tensor::arange(k_len, (Kind::Int64, device)).unsqueeze(0);
let relative_position = memory_position - context_position;
let rp_bucket = self.get_relative_position_bucket(
let rp_bucket = get_relative_position_bucket(
&relative_position,
self.is_bidirectional,
self.relative_attention_num_buckets,

View File

@ -17,20 +17,21 @@ use crate::t5::attention::{LayerState, T5LayerCrossAttention, T5LayerSelfAttenti
use crate::t5::layer_norm::T5LayerNorm;
use crate::t5::t5_model::FeedForwardProj;
use crate::t5::T5Config;
use crate::Activation::gelu_new;
use crate::Activation::{gelu_new, relu};
use crate::RustBertError;
use std::borrow::{Borrow, BorrowMut};
use tch::nn::LinearConfig;
use tch::{nn, Kind, Scalar, Tensor};
pub struct T5DenseReluDense {
pub struct T5DenseActDense {
wi: nn::Linear,
wo: nn::Linear,
dropout: Dropout,
activation_function: TensorFunction,
}
impl T5DenseReluDense {
pub fn new<'p, P>(p: P, config: &T5Config) -> T5DenseReluDense
impl T5DenseActDense {
pub fn new<'p, P>(p: P, config: &T5Config) -> T5DenseActDense
where
P: Borrow<nn::Path<'p>>,
{
@ -42,29 +43,36 @@ impl T5DenseReluDense {
let wi = nn::linear(p / "wi", config.d_model, config.d_ff, linear_config);
let wo = nn::linear(p / "wo", config.d_ff, config.d_model, linear_config);
let dropout = Dropout::new(config.dropout_rate);
let activation_function = match config.feed_forward_proj {
None | Some(FeedForwardProj::Relu) => relu.get_function(),
Some(FeedForwardProj::GatedGelu) => gelu_new.get_function(),
};
T5DenseReluDense { wi, wo, dropout }
T5DenseActDense {
wi,
wo,
dropout,
activation_function,
}
}
pub fn forward_t(&self, hidden_states: &Tensor, train: bool) -> Tensor {
hidden_states
.apply(&self.wi)
.relu()
self.activation_function.get_fn()(&hidden_states.apply(&self.wi))
.apply_t(&self.dropout, train)
.apply(&self.wo)
}
}
pub struct T5DenseGatedGeluDense {
pub struct T5DenseGatedActDense {
wi_0: nn::Linear,
wi_1: nn::Linear,
wo: nn::Linear,
dropout: Dropout,
activation: TensorFunction,
activation_function: TensorFunction,
}
impl T5DenseGatedGeluDense {
pub fn new<'p, P>(p: P, config: &T5Config) -> T5DenseGatedGeluDense
impl T5DenseGatedActDense {
pub fn new<'p, P>(p: P, config: &T5Config) -> T5DenseGatedActDense
where
P: Borrow<nn::Path<'p>>,
{
@ -77,19 +85,22 @@ impl T5DenseGatedGeluDense {
let wi_1 = nn::linear(p / "wi_1", config.d_model, config.d_ff, linear_config);
let wo = nn::linear(p / "wo", config.d_ff, config.d_model, linear_config);
let dropout = Dropout::new(config.dropout_rate);
let activation = gelu_new.get_function();
let activation_function = match config.feed_forward_proj {
None | Some(FeedForwardProj::Relu) => relu.get_function(),
Some(FeedForwardProj::GatedGelu) => gelu_new.get_function(),
};
T5DenseGatedGeluDense {
T5DenseGatedActDense {
wi_0,
wi_1,
wo,
dropout,
activation,
activation_function,
}
}
pub fn forward_t(&self, hidden_states: &Tensor, train: bool) -> Tensor {
let hidden_gelu = self.activation.get_fn()(&hidden_states.apply(&self.wi_0));
let hidden_gelu = self.activation_function.get_fn()(&hidden_states.apply(&self.wi_0));
let hidden_linear = hidden_states.apply(&self.wi_1);
(hidden_gelu * hidden_linear)
.apply_t(&self.dropout, train)
@ -98,8 +109,8 @@ impl T5DenseGatedGeluDense {
}
pub enum T5FeedForwardLayer {
T5DenseReluDense(T5DenseReluDense),
T5DenseGatedGeluDense(T5DenseGatedGeluDense),
T5DenseActDense(T5DenseActDense),
T5DenseGatedActDense(T5DenseGatedActDense),
}
impl T5FeedForwardLayer {
@ -109,20 +120,18 @@ impl T5FeedForwardLayer {
{
match config.feed_forward_proj.unwrap_or(FeedForwardProj::Relu) {
FeedForwardProj::Relu => {
T5FeedForwardLayer::T5DenseReluDense(T5DenseReluDense::new(p, config))
T5FeedForwardLayer::T5DenseActDense(T5DenseActDense::new(p, config))
}
FeedForwardProj::GatedGelu => {
T5FeedForwardLayer::T5DenseGatedGeluDense(T5DenseGatedGeluDense::new(p, config))
T5FeedForwardLayer::T5DenseGatedActDense(T5DenseGatedActDense::new(p, config))
}
}
}
pub fn forward_t(&self, hidden_states: &Tensor, train: bool) -> Tensor {
match self {
T5FeedForwardLayer::T5DenseReluDense(ref model) => {
model.forward_t(hidden_states, train)
}
T5FeedForwardLayer::T5DenseGatedGeluDense(ref model) => {
T5FeedForwardLayer::T5DenseActDense(ref model) => model.forward_t(hidden_states, train),
T5FeedForwardLayer::T5DenseGatedActDense(ref model) => {
model.forward_t(hidden_states, train)
}
}
@ -217,7 +226,7 @@ impl T5Block {
}
}
fn clamp_hidden_states(hidden_states: Tensor) -> Tensor {
pub(crate) fn clamp_hidden_states(hidden_states: Tensor) -> Tensor {
if (hidden_states.kind() != Kind::Float) & bool::from(hidden_states.isinf().any()) {
let clamp_value = match hidden_states.kind() {
Kind::Half => half::f16::MAX.to_f64() - 1000.,

View File

@ -54,6 +54,10 @@ mod layer_norm;
mod t5_model;
pub use attention::LayerState;
pub(crate) use attention::{get_relative_position_bucket, T5Attention, T5LayerCrossAttention};
pub(crate) use encoder::{T5Block, T5BlockOutput, T5LayerFF, T5StackOutput};
pub(crate) use layer_norm::T5LayerNorm;
pub(crate) use t5_model::{FeedForwardProj, TaskSpecificParams};
pub use t5_model::{
T5Config, T5ConfigResources, T5ForConditionalGeneration, T5ForSentenceEmbeddings, T5Generator,
T5Model, T5ModelOutput, T5ModelResources, T5Prefix, T5SourceLanguages, T5TargetLanguages,

View File

@ -147,7 +147,7 @@ pub struct T5Config {
pub vocab_size: i64,
pub feed_forward_proj: Option<FeedForwardProj>,
pub tie_word_embeddings: Option<bool>,
task_specific_params: Option<TaskSpecificParams>,
pub task_specific_params: Option<TaskSpecificParams>,
pub output_attentions: Option<bool>,
pub output_hidden_states: Option<bool>,
}
@ -250,7 +250,7 @@ impl T5Model {
///
/// # Arguments
///
/// * `p` - Variable store path for the root of the BART model
/// * `p` - Variable store path for the root of the T5 model
/// * `config` - `T5Config` object defining the model architecture
///
/// # Example

View File

@ -17,7 +17,7 @@ fn deberta_v2_masked_lm() -> anyhow::Result<()> {
DebertaV2ConfigResources::DEBERTA_V3_BASE,
));
let config_path = config_resource.get_local_path()?;
let device = Device::cuda_if_available();
let device = Device::Cpu;
let vs = nn::VarStore::new(device);
let mut config = DebertaV2Config::from_file(config_path);
config.output_attentions = Some(true);

tests/longt5.rs Normal file
View File

@ -0,0 +1,64 @@
use rust_bert::longt5::{LongT5ConfigResources, LongT5ModelResources, LongT5VocabResources};
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::summarization::{SummarizationConfig, SummarizationModel};
use rust_bert::resources::RemoteResource;
#[test]
fn test_summarization_longt5() -> anyhow::Result<()> {
    // Set-up summarization model
let summarization_config = SummarizationConfig {
model_type: ModelType::LongT5,
model_resource: Box::new(RemoteResource::from_pretrained(
LongT5ModelResources::TGLOBAL_BASE_BOOK_SUMMARY,
)),
config_resource: Box::new(RemoteResource::from_pretrained(
LongT5ConfigResources::TGLOBAL_BASE_BOOK_SUMMARY,
)),
vocab_resource: Box::new(RemoteResource::from_pretrained(
LongT5VocabResources::TGLOBAL_BASE_BOOK_SUMMARY,
)),
merges_resource: None,
min_length: 30,
max_length: Some(200),
early_stopping: true,
num_beams: 2,
length_penalty: 2.0,
..Default::default()
};
let model = SummarizationModel::new(summarization_config)?;
let input = ["In findings published Tuesday in Cornell University's arXiv by a team of scientists \
from the University of Montreal and a separate report published Wednesday in Nature Astronomy by a team \
from University College London (UCL), the presence of water vapour was confirmed in the atmosphere of K2-18b, \
a planet circling a star in the constellation Leo. This is the first such discovery in a planet in its star's \
habitable zone - not too hot and not too cold for liquid water to exist. The Montreal team, led by Björn Benneke, \
used data from the NASA's Hubble telescope to assess changes in the light coming from K2-18b's star as the planet \
passed between it and Earth. They found that certain wavelengths of light, which are usually absorbed by water, \
weakened when the planet was in the way, indicating not only does K2-18b have an atmosphere, but the atmosphere \
contains water in vapour form. The team from UCL then analyzed the Montreal team's data using their own software \
and confirmed their conclusion. This was not the first time scientists have found signs of water on an exoplanet, \
but previous discoveries were made on planets with high temperatures or other pronounced differences from Earth. \
\"This is the first potentially habitable planet where the temperature is right and where we now know there is water,\" \
said UCL astronomer Angelos Tsiaras. \"It's the best candidate for habitability right now.\" \"It's a good sign\", \
said Ryan Cloutier of the Harvard-Smithsonian Center for Astrophysics, who was not one of either study's authors. \
\"Overall,\" he continued, \"the presence of water in its atmosphere certainly improves the prospect of K2-18b being \
a potentially habitable planet, but further observations will be required to say for sure. \" \
K2-18b was first identified in 2015 by the Kepler space telescope. It is about 110 light-years from Earth and larger \
but less dense. Its star, a red dwarf, is cooler than the Sun, but the planet's orbit is much closer, such that a year \
on K2-18b lasts 33 Earth days. According to The Guardian, astronomers were optimistic that NASA's James Webb space \
telescope - scheduled for launch in 2021 - and the European Space Agency's 2028 ARIEL program, could reveal more \
about exoplanets like K2-18b."];
let output = model.summarize(&input);
    assert_eq!(
output[0],
" The first discovery of water on an exoplanet, K2-18b, comes from two different sources: scientists \
from the University of Montreal and a team from University College London. The scientists found that certain \
wavelengths of light absorbed by water weakened when the planet was in the way of Earth, indicating that the \
planet has an atmosphere. The Montreal team analyzed their own results using their own software, and confirmed \
their conclusion. This is the first such discovery in a planet in its habitable zone - not too hot and not too cold for liquid water to exist."
);
Ok(())
}

View File

@ -10,7 +10,8 @@ from torch import Tensor
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("source_file", help="Absolute path to the Pytorch weights file to convert")
parser.add_argument("--skip_embeddings", action="store_true", help="Skip shared embeddings / language model head")
parser.add_argument("--skip_embeddings", action="store_true", help="Skip shared embeddings")
parser.add_argument("--skip_lm_head", action="store_true", help="Skip language model head")
parser.add_argument("--prefix", help="Add a prefix on weight names")
parser.add_argument("--suffix", action="store_true", help="Split weight names on '.' and keep only last part")
args = parser.parse_args()
@ -24,7 +25,17 @@ if __name__ == "__main__":
for k, v in weights.items():
k = k.replace("gamma", "weight").replace("beta", "bias")
if args.skip_embeddings:
if k in {"lm_head.weight", "model.encoder.embed_tokens.weight", "model.decoder.embed_tokens.weight"}:
if k in {
"model.encoder.embed_tokens.weight",
"encoder.embed_tokens.weight",
"model.decoder.embed_tokens.weight",
"decoder.embed_tokens.weight"
}:
continue
if args.skip_lm_head:
if k in {
"lm_head.weight",
}:
continue
if args.prefix:
k = args.prefix + k